Skip to content

Ewise-union #159

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Feb 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions grblas/_automethods.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,10 @@ def ewise_mult(self):
return self._get_value("ewise_mult")


def ewise_union(self):
return self._get_value("ewise_union")


def gb_obj(self):
return self._get_value("gb_obj")

Expand Down Expand Up @@ -332,6 +336,7 @@ def __ixor__(self, other):
"apply",
"ewise_add",
"ewise_mult",
"ewise_union",
"ss",
"to_values",
}
Expand Down
63 changes: 27 additions & 36 deletions grblas/_infixmethods.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .vector import Vector, VectorExpression


def call_op(self, other, method, op, *, scalar_only=False, outer=False):
def call_op(self, other, method, op, *, outer=False, union=False):
type1 = output_type(self)
type2 = output_type(other)
if (
Expand All @@ -16,23 +16,10 @@ def call_op(self, other, method, op, *, scalar_only=False, outer=False):
or type1 is TransposedMatrix
and type2 is Matrix
):
if scalar_only:
raise TypeError(
f"Infix operator {method} between {type1.__name__} and {type2.__name__} is not "
"supported. This infix operation is only allowed if one of the arguments is a "
"scalar. We refuse to guess whether you intend to do ewise_mult or ewise_add."
"\n\nYou must indicate ewise_mult (intersection) or ewise_add (union) explicitly."
"\n\nFor ewise_mult:\n"
f" >>> op.{op.name}(x & y)\n"
"or\n"
f" >>> x.ewise_mult(y, op.{op.name})\n\n"
"For ewise_add:\n"
f" >>> op.{op.name}(x | y)\n"
"or\n"
f" >>> x.ewise_add(y, op.{op.name})\n\n"
)
elif outer:
if outer:
return op(self | other, require_monoid=False)
elif union:
return self.ewise_union(other, op, False, False)
else:
return op(self & other)
return op(self, other)
Expand Down Expand Up @@ -110,21 +97,6 @@ def __iand__(self, other):
return self


def __sub__(self, other):
# TODO: use GxB Union
return call_op(self, other, "__sub__", binary.minus, scalar_only=True)


def __rsub__(self, other):
# TODO: use GxB Union
return call_op(other, self, "__rsub__", binary.minus, scalar_only=True)


def __isub__(self, other):
self << __sub__(self, other)
return self


# Begin auto-generated code
def __eq__(self, other):
return call_op(self, other, "__eq__", binary.eq)
Expand Down Expand Up @@ -215,6 +187,19 @@ def __ipow__(self, other):
return self


def __sub__(self, other):
return call_op(self, other, "__sub__", binary.minus, union=True)


def __rsub__(self, other):
return call_op(other, self, "__rsub__", binary.minus, union=True)


def __isub__(self, other):
self << __sub__(self, other)
return self


def __truediv__(self, other):
return call_op(self, other, "__truediv__", binary.truediv)

Expand Down Expand Up @@ -296,11 +281,15 @@ def __itruediv__(self, other):
"floordiv": "floordiv",
"mod": "numpy.mod",
"pow": "pow",
"sub": "minus",
}
# monoids with 0 identity use outer (ewise_add): + | ^
outer = {
"add",
}
union = {
"sub",
}
custom = {
"abs",
"divmod",
Expand All @@ -312,9 +301,6 @@ def __itruediv__(self, other):
"ixor",
"ior",
"iand",
"sub",
"isub",
"rsub",
}
# Skipped: rshift, pos
# Already used for syntax: lshift, and, or
Expand All @@ -326,7 +312,12 @@ def __itruediv__(self, other):
f' return call_op(self, other, "__{method}__", binary.{op})\n\n'
)
for method, op in sorted(operations.items()):
out = ", outer=True" if method in outer else ""
if method in outer:
out = ", outer=True"
elif method in union:
out = ", union=True"
else:
out = ""
lines.append(
f"def __{method}__(self, other):\n"
f' return call_op(self, other, "__{method}__", binary.{op}{out})\n\n'
Expand Down
11 changes: 5 additions & 6 deletions grblas/_ss/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ..base import call, record_raw
from ..dtypes import _INDEX, INT64, lookup_dtype
from ..exceptions import check_status, check_status_carg
from ..scalar import Scalar, _CScalar
from ..scalar import Scalar, _CScalar, _GrBScalar
from ..utils import (
_CArray,
_Pointer,
Expand All @@ -23,7 +23,6 @@
values_to_numpy_buffer,
wrapdoc,
)
from .scalar import gxb_scalar
from .utils import get_order

ffi_new = ffi.new
Expand Down Expand Up @@ -539,11 +538,11 @@ def build_scalar(self, rows, columns, value):
raise ValueError(
f"`rows` and `columns` lengths must match: {rows.size}, {columns.size}"
)
scalar = gxb_scalar(self._parent.dtype, value)
status = lib.GxB_Matrix_build_Scalar(
self._parent._carg, _CArray(rows)._carg, _CArray(columns)._carg, scalar[0], rows.size
scalar = _GrBScalar(value, self._parent.dtype)
call(
"GxB_Matrix_build_Scalar",
[self._parent, _CArray(rows), _CArray(columns), scalar, _CScalar(rows.size)],
)
check_status(status, self._parent)

def export(self, format=None, *, sort=False, give_ownership=False, raw=False):
"""
Expand Down
19 changes: 0 additions & 19 deletions grblas/_ss/scalar.py

This file was deleted.

11 changes: 5 additions & 6 deletions grblas/_ss/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ..base import call
from ..dtypes import INT64, UINT64, lookup_dtype
from ..exceptions import check_status, check_status_carg
from ..scalar import _CScalar
from ..scalar import _CScalar, _GrBScalar
from ..utils import (
_CArray,
ints_to_numpy_buffer,
Expand All @@ -18,7 +18,6 @@
)
from .matrix import MatrixArray, _concat_mn, normalize_chunks
from .prefix_scan import prefix_scan
from .scalar import gxb_scalar
from .utils import get_order

ffi_new = ffi.new
Expand Down Expand Up @@ -246,11 +245,11 @@ def build_scalar(self, indices, value):
Vector.from_values
"""
indices = ints_to_numpy_buffer(indices, np.uint64, name="indices")
scalar = gxb_scalar(self._parent.dtype, value)
status = lib.GxB_Vector_build_Scalar(
self._parent._carg, _CArray(indices)._carg, scalar[0], indices.size
scalar = _GrBScalar(value, self._parent.dtype)
call(
"GxB_Vector_build_Scalar",
[self._parent, _CArray(indices), scalar, _CScalar(indices.size)],
)
check_status(status, self._parent)

def export(self, format=None, *, sort=False, give_ownership=False, raw=False):
"""
Expand Down
7 changes: 5 additions & 2 deletions grblas/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .exceptions import check_status
from .expr import AmbiguousAssignOrExtract, Updater
from .mask import Mask
from .operator import UNKNOWN_OPCLASS, find_opclass, get_typed_op
from .operator import UNKNOWN_OPCLASS, binary_from_string, find_opclass, get_typed_op
from .utils import _Pointer, libget, output_type

NULL = ffi.NULL
Expand Down Expand Up @@ -207,7 +207,10 @@ def __call__(
raise TypeError("Got multiple values for argument 'accum'")
accum_arg, opclass = find_opclass(arg)
if opclass == UNKNOWN_OPCLASS:
raise TypeError(f"Invalid item found in output params: {type(arg)}")
if isinstance(accum_arg, str):
accum_arg = binary_from_string(accum_arg)
else:
raise TypeError(f"Invalid item found in output params: {type(arg)}")
# Merge positional and keyword arguments
if mask_arg is not None and mask is not None:
raise TypeError("Got multiple values for argument 'mask'")
Expand Down
2 changes: 2 additions & 0 deletions grblas/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ def check_status(response_code, args):
if type(arg) is _Pointer:
arg = arg.val
type_name = type(arg).__name__
if type_name == "_GrBScalar":
type_name = "Scalar"
carg = arg._carg
return check_status_carg(response_code, type_name, carg)

Expand Down
2 changes: 2 additions & 0 deletions grblas/infix.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def shape(self):
apply = wrapdoc(Vector.apply)(property(_automethods.apply))
ewise_add = wrapdoc(Vector.ewise_add)(property(_automethods.ewise_add))
ewise_mult = wrapdoc(Vector.ewise_mult)(property(_automethods.ewise_mult))
ewise_union = wrapdoc(Vector.ewise_union)(property(_automethods.ewise_union))
gb_obj = wrapdoc(Vector.gb_obj)(property(_automethods.gb_obj))
inner = wrapdoc(Vector.inner)(property(_automethods.inner))
isclose = wrapdoc(Vector.isclose)(property(_automethods.isclose))
Expand Down Expand Up @@ -180,6 +181,7 @@ def shape(self):
apply = wrapdoc(Matrix.apply)(property(_automethods.apply))
ewise_add = wrapdoc(Matrix.ewise_add)(property(_automethods.ewise_add))
ewise_mult = wrapdoc(Matrix.ewise_mult)(property(_automethods.ewise_mult))
ewise_union = wrapdoc(Matrix.ewise_union)(property(_automethods.ewise_union))
gb_obj = wrapdoc(Matrix.gb_obj)(property(_automethods.gb_obj))
isclose = wrapdoc(Matrix.isclose)(property(_automethods.isclose))
isequal = wrapdoc(Matrix.isequal)(property(_automethods.isequal))
Expand Down
38 changes: 37 additions & 1 deletion grblas/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from .expr import AmbiguousAssignOrExtract, IndexerResolver, Updater
from .mask import StructuralMask, ValueMask
from .operator import get_typed_op
from .scalar import Scalar, ScalarExpression, _CScalar
from .scalar import Scalar, ScalarExpression, _CScalar, _GrBScalar
from .utils import (
_CArray,
_Pointer,
Expand Down Expand Up @@ -487,6 +487,40 @@ def ewise_mult(self, other, op=binary.times):
expr.new(name="") # incompatible shape; raise now
return expr

def ewise_union(self, other, op, left_default, right_default):
"""
GxB_Matrix_eWiseUnion

This is similar to `ewise_add` in that result will contain the union of
indices from both Matrices. Unlike `ewise_add`, this will use
``left_default`` for the left value when there is a value on the right
but not the left, and ``right_default`` for the right value when there
is a value on the left but not the right.

``op`` should be a BinaryOp or Monoid.
"""
# SS, SuiteSparse-specific: eWiseUnion
method_name = "ewise_union"
other = self._expect_type(other, Matrix, within=method_name, argname="other", op=op)
left = _GrBScalar(left_default)
right = _GrBScalar(right_default)
scalar_dtype = unify(left.dtype, right.dtype)
nonscalar_dtype = unify(self.dtype, other.dtype)
op = get_typed_op(op, scalar_dtype, nonscalar_dtype, is_left_scalar=True, kind="binary")
self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op")
if op.opclass == "Monoid":
op = op.binaryop
expr = MatrixExpression(
method_name,
"GxB_Matrix_eWiseUnion",
[self, left, other, right],
op=op,
expr_repr="{0.name}.{method_name}({2.name}, {op}, {1.name}, {3.name})",
)
if self.shape != other.shape:
expr.new(name="") # incompatible shape; raise now
return expr

def mxv(self, other, op=semiring.plus_times):
"""
GrB_mxv
Expand Down Expand Up @@ -1244,6 +1278,7 @@ def shape(self):
apply = wrapdoc(Matrix.apply)(property(_automethods.apply))
ewise_add = wrapdoc(Matrix.ewise_add)(property(_automethods.ewise_add))
ewise_mult = wrapdoc(Matrix.ewise_mult)(property(_automethods.ewise_mult))
ewise_union = wrapdoc(Matrix.ewise_union)(property(_automethods.ewise_union))
gb_obj = wrapdoc(Matrix.gb_obj)(property(_automethods.gb_obj))
isclose = wrapdoc(Matrix.isclose)(property(_automethods.isclose))
isequal = wrapdoc(Matrix.isequal)(property(_automethods.isequal))
Expand Down Expand Up @@ -1351,6 +1386,7 @@ def _name_html(self):
# Delayed methods
ewise_add = Matrix.ewise_add
ewise_mult = Matrix.ewise_mult
ewise_union = Matrix.ewise_union
mxv = Matrix.mxv
mxm = Matrix.mxm
kronecker = Matrix.kronecker
Expand Down
32 changes: 29 additions & 3 deletions grblas/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

import numpy as np

from . import _automethods, backend, ffi, utils
from .base import BaseExpression, BaseType
from . import _automethods, backend, ffi, lib, utils
from .base import BaseExpression, BaseType, call
from .binary import isclose
from .dtypes import _INDEX, BOOL, lookup_dtype
from .exceptions import check_status
from .operator import get_typed_op
from .utils import output_type, wrapdoc
from .utils import _Pointer, output_type, wrapdoc

ffi_new = ffi.new

Expand Down Expand Up @@ -380,5 +381,30 @@ def __eq__(self, other):
return self.scalar == other


class _GrBScalar:
"""Wrap scalars as GrB_Scalars for calling into C"""

__slots__ = "gb_obj", "dtype", "name"

def __init__(self, scalar, dtype=None):
cscalar = _CScalar(scalar, dtype)
self.gb_obj = ffi_new("GrB_Scalar*")
self.dtype = cscalar.dtype
self.name = cscalar.name
call("GrB_Scalar_new", [_Pointer(self), self.dtype])
if not cscalar.scalar._is_empty:
call(f"GrB_Scalar_setElement_{self.dtype.name}", [self, cscalar])

def __del__(self):
gb_obj = getattr(self, "gb_obj", None)
if gb_obj is not None:
# it's difficult/dangerous to record the call, b/c `self.name` may not exist
check_status(lib.GrB_Scalar_free(gb_obj), self)

@property
def _carg(self):
return self.gb_obj[0]


utils._output_types[Scalar] = Scalar
utils._output_types[ScalarExpression] = Scalar
4 changes: 2 additions & 2 deletions grblas/tests/test_infix.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,8 +274,8 @@ def test_apply_binary_bad(s1, v1):
def test_infix_nonscalars(v1, v2):
# with raises(TypeError, match="refuse to guess"):
assert (v1 + v2).new().isequal(op.plus(v1 | v2).new())
with raises(TypeError, match="refuse to guess"):
v1 - v2 # Not handled yet
# with raises(TypeError, match="refuse to guess"):
assert (v1 - v2).new().isequal(v1.ewise_union(v2, "-", 0, 0).new())


@autocompute
Expand Down
Loading