Skip to content

Add semiring(A @ B @ C) that applies semiring to both matmuls #501

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Nov 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions graphblas/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,23 +263,31 @@ def __call__(
)

def __or__(self, other):
from .infix import _ewise_infix_expr
from .infix import _ewise_infix_expr, _ewise_mult_expr_types

if isinstance(other, _ewise_mult_expr_types):
raise TypeError("XXX")
return _ewise_infix_expr(self, other, method="ewise_add", within="__or__")

def __ror__(self, other):
from .infix import _ewise_infix_expr
from .infix import _ewise_infix_expr, _ewise_mult_expr_types

if isinstance(other, _ewise_mult_expr_types):
raise TypeError("XXX")
return _ewise_infix_expr(other, self, method="ewise_add", within="__ror__")

def __and__(self, other):
from .infix import _ewise_infix_expr
from .infix import _ewise_add_expr_types, _ewise_infix_expr

if isinstance(other, _ewise_add_expr_types):
raise TypeError("XXX")
return _ewise_infix_expr(self, other, method="ewise_mult", within="__and__")

def __rand__(self, other):
from .infix import _ewise_infix_expr
from .infix import _ewise_add_expr_types, _ewise_infix_expr

if isinstance(other, _ewise_add_expr_types):
raise TypeError("XXX")
return _ewise_infix_expr(other, self, method="ewise_mult", within="__rand__")

def __matmul__(self, other):
Expand Down
72 changes: 72 additions & 0 deletions graphblas/core/infix.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,19 @@ class ScalarEwiseAddExpr(ScalarInfixExpr):

_to_expr = _ewise_add_to_expr

# Allow e.g. `plus(x | y | z)`
__or__ = Scalar.__or__
__ror__ = Scalar.__ror__
_ewise_add = Scalar._ewise_add
_ewise_union = Scalar._ewise_union

# Don't allow e.g. `plus(x | y & z)`
def __and__(self, other):
raise TypeError("XXX")

def __rand__(self, other):
raise TypeError("XXX")


class ScalarEwiseMultExpr(ScalarInfixExpr):
__slots__ = ()
Expand All @@ -135,6 +148,18 @@ class ScalarEwiseMultExpr(ScalarInfixExpr):

_to_expr = _ewise_mult_to_expr

# Allow e.g. `plus(x & y & z)`
__and__ = Scalar.__and__
__rand__ = Scalar.__rand__
_ewise_mult = Scalar._ewise_mult

# Don't allow e.g. `plus(x | y & z)`
def __or__(self, other):
raise TypeError("XXX")

def __ror__(self, other):
raise TypeError("XXX")


class ScalarMatMulExpr(ScalarInfixExpr):
__slots__ = ()
Expand Down Expand Up @@ -239,6 +264,15 @@ class VectorEwiseAddExpr(VectorInfixExpr):

_to_expr = _ewise_add_to_expr

# Allow e.g. `plus(x | y | z)`
__or__ = Vector.__or__
__ror__ = Vector.__ror__
_ewise_add = Vector._ewise_add
_ewise_union = Vector._ewise_union
# Don't allow e.g. `plus(x | y & z)`
__and__ = ScalarEwiseAddExpr.__and__ # raises
__rand__ = ScalarEwiseAddExpr.__rand__ # raises


class VectorEwiseMultExpr(VectorInfixExpr):
__slots__ = ()
Expand All @@ -248,6 +282,14 @@ class VectorEwiseMultExpr(VectorInfixExpr):

_to_expr = _ewise_mult_to_expr

# Allow e.g. `plus(x & y & z)`
__and__ = Vector.__and__
__rand__ = Vector.__rand__
_ewise_mult = Vector._ewise_mult
# Don't allow e.g. `plus(x | y & z)`
__or__ = ScalarEwiseMultExpr.__or__ # raises
__ror__ = ScalarEwiseMultExpr.__ror__ # raises


class VectorMatMulExpr(VectorInfixExpr):
__slots__ = "method_name"
Expand All @@ -259,6 +301,11 @@ def __init__(self, left, right, *, method_name, size):
self.method_name = method_name
self._size = size

__matmul__ = Vector.__matmul__
__rmatmul__ = Vector.__rmatmul__
_inner = Vector._inner
_vxm = Vector._vxm


utils._output_types[VectorEwiseAddExpr] = Vector
utils._output_types[VectorEwiseMultExpr] = Vector
Expand Down Expand Up @@ -376,6 +423,15 @@ class MatrixEwiseAddExpr(MatrixInfixExpr):

_to_expr = _ewise_add_to_expr

# Allow e.g. `plus(x | y | z)`
__or__ = Matrix.__or__
__ror__ = Matrix.__ror__
_ewise_add = Matrix._ewise_add
_ewise_union = Matrix._ewise_union
# Don't allow e.g. `plus(x | y & z)`
__and__ = VectorEwiseAddExpr.__and__ # raises
__rand__ = VectorEwiseAddExpr.__rand__ # raises


class MatrixEwiseMultExpr(MatrixInfixExpr):
__slots__ = ()
Expand All @@ -385,6 +441,14 @@ class MatrixEwiseMultExpr(MatrixInfixExpr):

_to_expr = _ewise_mult_to_expr

# Allow e.g. `plus(x & y & z)`
__and__ = Matrix.__and__
__rand__ = Matrix.__rand__
_ewise_mult = Matrix._ewise_mult
# Don't allow e.g. `plus(x | y & z)`
__or__ = VectorEwiseMultExpr.__or__ # raises
__ror__ = VectorEwiseMultExpr.__ror__ # raises


class MatrixMatMulExpr(MatrixInfixExpr):
__slots__ = ()
Expand All @@ -397,6 +461,11 @@ def __init__(self, left, right, *, nrows, ncols):
self._nrows = nrows
self._ncols = ncols

__matmul__ = Matrix.__matmul__
__rmatmul__ = Matrix.__rmatmul__
_mxm = Matrix._mxm
_mxv = Matrix._mxv


utils._output_types[MatrixEwiseAddExpr] = Matrix
utils._output_types[MatrixEwiseMultExpr] = Matrix
Expand Down Expand Up @@ -514,5 +583,8 @@ def _matmul_infix_expr(left, right, *, within):
)


_ewise_add_expr_types = (MatrixEwiseAddExpr, VectorEwiseAddExpr, ScalarEwiseAddExpr)
_ewise_mult_expr_types = (MatrixEwiseMultExpr, VectorEwiseMultExpr, ScalarEwiseMultExpr)

# Import infixmethods, which has side effects
from . import infixmethods # noqa: E402, F401 isort:skip
Loading