Skip to content

Fix infix expression _value and _expr usage #418

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ repos:
- id: mixed-line-ending
- id: trailing-whitespace
- repo: https://github.com/abravalheri/validate-pyproject
rev: v0.12.1
rev: v0.12.2
hooks:
- id: validate-pyproject
name: Validate pyproject.toml
Expand All @@ -47,7 +47,7 @@ repos:
- id: black
- id: black-jupyter
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.257
rev: v0.0.259
hooks:
- id: ruff
args: [--fix-only]
Expand Down Expand Up @@ -75,7 +75,7 @@ repos:
additional_dependencies: [tomli]
files: ^(graphblas|docs)/
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.257
rev: v0.0.259
hooks:
- id: ruff
- repo: https://github.com/sphinx-contrib/sphinx-lint
Expand Down
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ dependencies:
# - snakeviz
# - sphinx-lint
# - sympy
# - tuna
# - twine
# - vim
# - yesqa
Expand Down
31 changes: 20 additions & 11 deletions graphblas/core/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,33 +478,34 @@ def __bool__(self):


class InfixExprBase:
__slots__ = "left", "right", "_value", "__weakref__"
__slots__ = "left", "right", "_expr", "__weakref__"
_is_scalar = False

def __init__(self, left, right):
self.left = left
self.right = right
self._value = None
self._expr = None

def new(self, dtype=None, *, mask=None, name=None, **opts):
if (
mask is None
and self._value is not None
and (dtype is None or self._value.dtype == dtype)
and self._expr is not None
and self._expr._value is not None
and (dtype is None or self._expr._value.dtype == dtype)
):
rv = self._value
rv = self._expr._value
if name is not None:
rv.name = name
self._value = None
self._expr._value = None
return rv
expr = self._to_expr()
return expr.new(dtype, mask=mask, name=name, **opts)

def _to_expr(self):
if self._value is None:
if self._expr is None:
# Rely on the default operator for `x @ y`
self._value = getattr(self.left, self.method_name)(self.right)
return self._value
self._expr = getattr(self.left, self.method_name)(self.right)
return self._expr

def _get_value(self, attr=None, default=None):
expr = self._to_expr()
Expand Down Expand Up @@ -536,10 +537,18 @@ def __repr__(self):

@property
def dtype(self):
if self._value is not None:
return self._value.dtype
return self._to_expr().dtype

@property
def _value(self):
if self._expr is None:
return None
return self._expr._value

@_value.setter
def _value(self, val):
self._to_expr()._value = val


# Mistakes
utils._output_types[AmbiguousAssignOrExtract] = AmbiguousAssignOrExtract
Expand Down
1 change: 1 addition & 0 deletions graphblas/core/formatting.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# This file imports pandas, so it should only be imported when formatting
import numpy as np

from .. import backend, config, monoid, unary
Expand Down
16 changes: 8 additions & 8 deletions graphblas/core/infix.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@


def _ewise_add_to_expr(self):
if self._value is not None:
return self._value
if self._expr is not None:
return self._expr
if self.left.dtype == BOOL and self.right.dtype == BOOL:
self._value = self.left.ewise_add(self.right, lor)
return self._value
self._expr = self.left.ewise_add(self.right, lor)
return self._expr
raise TypeError(
"Bad dtypes for `x | y`! Automatic computation of `x | y` infix expressions is only valid "
f"for BOOL dtypes. The argument dtypes are {self.left.dtype} and {self.right.dtype}.\n\n"
Expand All @@ -30,11 +30,11 @@ def _ewise_add_to_expr(self):


def _ewise_mult_to_expr(self):
if self._value is not None:
return self._value
if self._expr is not None:
return self._expr
if self.left.dtype == BOOL and self.right.dtype == BOOL:
self._value = self.left.ewise_mult(self.right, land)
return self._value
self._expr = self.left.ewise_mult(self.right, land)
return self._expr
raise TypeError(
"Bad dtypes for `x & y`! Automatic computation of `x & y` infix expressions is only valid "
f"for BOOL dtypes. The argument dtypes are {self.left.dtype} and {self.right.dtype}.\n\n"
Expand Down
3 changes: 2 additions & 1 deletion graphblas/core/recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from ..dtypes import DataType
from . import base, lib
from .base import _recorder
from .formatting import CSS_STYLE
from .mask import Mask
from .matrix import TransposedMatrix
from .operator import TypedOpBase
Expand Down Expand Up @@ -103,6 +102,8 @@ def is_recording(self):
return self._token is not None and _recorder.get(base._prev_recorder) is self

def _repr_base_(self):
from .formatting import CSS_STYLE

status = (
'<div style="'
"height: 12px; "
Expand Down
18 changes: 18 additions & 0 deletions graphblas/tests/test_infix.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,3 +342,21 @@ def test_inplace_infix(s1, v1, v2, A1, A2):
expr @= A
with pytest.raises(TypeError, match="not supported"):
s1 @= v1


@autocompute
def test_infix_expr_value_types():
"""Test bug where `infix_expr._value` was used as MatrixExpression or Matrix"""
from graphblas.core.matrix import MatrixExpression

A = Matrix(int, 3, 3)
A << 1
expr = A @ A.T
assert expr._expr is None
assert expr._value is None
assert type(expr._get_value()) is Matrix
assert type(expr._expr) is MatrixExpression
assert type(expr.new()) is Matrix
assert expr._expr is not None
assert expr._value is None
assert type(expr.new()) is Matrix
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,8 @@ ignore = [
"graphblas/core/ss/matrix.py" = ["NPY002"] # numba doesn't support rng generator yet
"graphblas/core/ss/vector.py" = ["NPY002"] # numba doesn't support rng generator yet
"graphblas/ss/_core.py" = ["N999"] # We want _core.py to be underscopre
"graphblas/tests/*py" = ["S101", "T201", "D103", "D100", "SIM300"] # Allow assert, print, no docstring, and yoda
# Allow assert, pickle, RNG, print, no docstring, and yoda in tests
"graphblas/tests/*py" = ["S101", "S301", "S311", "T201", "D103", "D100", "SIM300"]
"graphblas/tests/test_formatting.py" = ["E501"] # Allow long lines
"graphblas/**/__init__.py" = ["F401"] # Allow unused imports (w/o defining `__all__`)
"scripts/*.py" = ["INP001"] # Not a package
Expand Down