Skip to content

Autocompute AmbiguousAssignOrExtract as extract #166

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Feb 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions grblas/_agg.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,11 +530,11 @@ def _first_last(agg, updater, expr, *, in_composite, semiring):
# Populate numpy array
vals = np.empty(Is.size, dtype=A.dtype.np_type)
for index, (i, j) in enumerate(zip(Is, Js)):
vals[index] = A[i, j].value
vals[index] = A[i, j].new().value
# or Vector
# v = expr._new_vector(A.dtype, size=A._nrows)
# for i, j in zip(Is, Js):
# v[i] = A[i, j].value
# v[i] = A[i, j].new().value
result = Vector.from_values(Is, vals, size=A._nrows)
updater << result
if in_composite:
Expand All @@ -544,7 +544,7 @@ def _first_last(agg, updater, expr, *, in_composite, semiring):
init = expr._new_matrix(bool, nrows=v._size, ncols=1)
init[...] = False # O(1) dense matrix in SuiteSparse 5
step1 = semiring(v @ init).new()
index = step1[0].value
index = step1[0].new().value
if index is None:
index = 0
if in_composite:
Expand All @@ -558,11 +558,11 @@ def _first_last(agg, updater, expr, *, in_composite, semiring):
init2 = expr._new_vector(bool, size=A._nrows)
init2[...] = False # O(1) dense vector in SuiteSparse 5
step2 = semiring(step1.T @ init2).new()
i = step2[0].value
i = step2[0].new().value
if i is None:
i = j = 0
else:
j = step1[i, 0].value
j = step1[i, 0].new().value
if in_composite:
return A[i, [j]].new()
updater << A[i, j]
Expand Down
11 changes: 11 additions & 0 deletions grblas/_automethods.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,14 @@ def __ror__(self):
return self._get_value("__ror__")


def _as_matrix(self):
return self._get_value("_as_matrix")


def _as_vector(self):
return self._get_value("_as_vector")


def _carg(self):
return self._get_value("_carg")

Expand Down Expand Up @@ -321,6 +329,8 @@ def __ixor__(self, other):
"__int__",
"__invert__",
"__neg__",
"_as_matrix",
"_as_vector",
"_is_empty",
"is_empty",
"value",
Expand All @@ -346,6 +356,7 @@ def __ixor__(self, other):
"to_values",
}
vector = {
"_as_matrix",
"inner",
"outer",
"reduce",
Expand Down
8 changes: 6 additions & 2 deletions grblas/_infixmethods.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from . import binary, unary
from .dtypes import BOOL
from .infix import MatrixInfixExpr, VectorInfixExpr
from .matrix import Matrix, MatrixExpression, TransposedMatrix
from .matrix import Matrix, MatrixExpression, MatrixIndexExpr, TransposedMatrix
from .utils import output_type
from .vector import Vector, VectorExpression
from .vector import Vector, VectorExpression, VectorIndexExpr


def call_op(self, other, method, op, *, outer=False, union=False):
Expand Down Expand Up @@ -262,6 +262,8 @@ def __itruediv__(self, other):
setattr(MatrixExpression, name, val)
setattr(VectorInfixExpr, name, val)
setattr(MatrixInfixExpr, name, val)
setattr(VectorIndexExpr, name, val)
setattr(MatrixIndexExpr, name, val)
# End auto-generated code

if __name__ == "__main__":
Expand Down Expand Up @@ -350,6 +352,8 @@ def __itruediv__(self, other):
" setattr(MatrixExpression, name, val)\n"
" setattr(VectorInfixExpr, name, val)\n"
" setattr(MatrixInfixExpr, name, val)\n"
" setattr(VectorIndexExpr, name, val)\n"
" setattr(MatrixIndexExpr, name, val)\n"
)
from .utils import _autogenerate_code

Expand Down
8 changes: 4 additions & 4 deletions grblas/_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,20 @@ def slice_to_index(index, size):
length = len(range(start, stop, step))
if length == size and step == 1:
# [:] means all indices; use special GrB_ALL indicator
return AxisIndex(size, _ALL_INDICES, _as_scalar(size, _INDEX, is_cscalar=True))
return AxisIndex(size, _ALL_INDICES, _as_scalar(size, _INDEX, is_cscalar=True), size)
# SS, SuiteSparse-specific: slicing.
# For non-SuiteSparse, do: index = list(range(size)[index])
# SuiteSparse indexing is inclusive for both start and stop, and unsigned.
if step < 0:
if start < 0:
start = stop = 0 # Must be empty
return AxisIndex(length, _CArray([start, stop + 1, -step]), gxb_backwards)
return AxisIndex(length, _CArray([start, stop + 1, -step]), gxb_backwards, size)
if stop > 0:
stop -= 1
elif start == 0:
# [0:0] slice should be empty, so change to [1:0]
start += 1
if step == 1:
return AxisIndex(length, _CArray([start, stop]), gxb_range)
return AxisIndex(length, _CArray([start, stop]), gxb_range, size)
else:
return AxisIndex(length, _CArray([start, stop, step]), gxb_stride)
return AxisIndex(length, _CArray([start, stop, step]), gxb_stride, size)
9 changes: 5 additions & 4 deletions grblas/_ss/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,9 +266,10 @@ def normalize_chunks(chunks, shape):
def _concat_mn(tiles, *, is_matrix=None):
"""Argument checking for `Matrix.ss.concat` and returns number of tiles in each dimension"""
from ..matrix import Matrix, TransposedMatrix
from ..scalar import Scalar
from ..vector import Vector

valid_types = (Matrix, TransposedMatrix, Vector)
valid_types = (Matrix, TransposedMatrix, Vector, Scalar)
if not isinstance(tiles, (list, tuple)):
raise TypeError(f"tiles argument must be list or tuple; got: {type(tiles)}")
if not tiles:
Expand All @@ -279,10 +280,10 @@ def _concat_mn(tiles, *, is_matrix=None):
new_tiles = []
for i, row_tiles in enumerate(tiles):
if not isinstance(row_tiles, (list, tuple)):
if not is_matrix and output_type(row_tiles) is Vector:
if not is_matrix and output_type(row_tiles) in {Vector, Scalar}:
new_tiles.append(
dummy._expect_type(
row_tiles, Vector, within="ss.concat", argname="tiles"
row_tiles, (Vector, Scalar), within="ss.concat", argname="tiles"
)._as_matrix()
)
is_matrix = False
Expand Down Expand Up @@ -319,7 +320,7 @@ def _concat_mn(tiles, *, is_matrix=None):


def _as_matrix(x):
return x._as_matrix() if type(x) is gb.Vector else x
return x._as_matrix() if hasattr(x, "_as_matrix") else x


class MatrixArray:
Expand Down
8 changes: 1 addition & 7 deletions grblas/_ss/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,7 @@
from ..dtypes import _INDEX, INT64, UINT64, lookup_dtype
from ..exceptions import check_status, check_status_carg
from ..scalar import _as_scalar
from ..utils import (
_CArray,
ints_to_numpy_buffer,
libget,
values_to_numpy_buffer,
wrapdoc,
)
from ..utils import _CArray, ints_to_numpy_buffer, libget, values_to_numpy_buffer, wrapdoc
from .matrix import MatrixArray, _concat_mn, normalize_chunks
from .prefix_scan import prefix_scan
from .utils import get_order
Expand Down
14 changes: 11 additions & 3 deletions grblas/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,10 @@ def _expect_op(self, op, values, **kwargs):
raise TypeError(message) from None


AmbiguousAssignOrExtract._expect_op = _expect_op
AmbiguousAssignOrExtract._expect_type = _expect_type


def _check_mask(mask, output=None):
if not isinstance(mask, Mask):
if isinstance(mask, BaseType):
Expand Down Expand Up @@ -316,8 +320,8 @@ def update(self, expr):
def _update(self, expr, mask=None, accum=None, replace=False, input_mask=None):
# TODO: check expected output type (now included in Expression object)
if not isinstance(expr, BaseExpression):
if type(expr) is AmbiguousAssignOrExtract:
if expr.resolved_indexes.is_single_element and self._is_scalar:
if isinstance(expr, AmbiguousAssignOrExtract):
if expr._is_scalar and self._is_scalar:
# Extract element (s << v[1])
if accum is not None:
raise TypeError(
Expand Down Expand Up @@ -475,6 +479,8 @@ def _update(self, expr, mask=None, accum=None, replace=False, input_mask=None):
self.value = fake_self[0].new(is_cscalar=True, name="")
# SS: this assumes GrB_Scalar was cast to Vector
elif is_temp_scalar:
if temp_scalar._is_cscalar:
temp_scalar._empty = False
self.value = temp_scalar
elif self._is_cscalar:
self._empty = False
Expand Down Expand Up @@ -579,7 +585,9 @@ def _format_expr(self):
return self.expr_repr.format(*self.args, method_name=self.method_name, op=self.op)

def _format_expr_html(self):
expr_repr = self.expr_repr.replace(".name", "._name_html")
expr_repr = self.expr_repr.replace(".name", "._name_html").replace(
"._expr_name", "._expr_name_html"
)
return expr_repr.format(*self.args, method_name=self.method_name, op=self.op)

_expect_type = _expect_type
Expand Down
Loading