Skip to content

Allow __index__ only for integral dtypes on Scalars #481

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ci:
# See: https://pre-commit.ci/#configuration
autofix_prs: false
autoupdate_schedule: monthly
autoupdate_schedule: quarterly
autoupdate_commit_msg: "chore: update pre-commit hooks"
autofix_commit_msg: "style: pre-commit fixes"
skip: [pylint, no-commit-to-branch]
Expand Down Expand Up @@ -51,7 +51,7 @@ repos:
- id: isort
# Let's keep `pyupgrade` even though `ruff --fix` probably does most of it
- repo: https://github.com/asottile/pyupgrade
rev: v3.8.0
rev: v3.9.0
hooks:
- id: pyupgrade
args: [--py38-plus]
Expand All @@ -61,12 +61,12 @@ repos:
- id: auto-walrus
args: [--line-length, "100"]
- repo: https://github.com/psf/black
rev: 23.3.0
rev: 23.7.0
hooks:
- id: black
- id: black-jupyter
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.277
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.278
hooks:
- id: ruff
args: [--fix-only, --show-fixes]
Expand All @@ -93,8 +93,8 @@ repos:
types_or: [python, rst, markdown]
additional_dependencies: [tomli]
files: ^(graphblas|docs)/
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.277
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.278
hooks:
- id: ruff
- repo: https://github.com/sphinx-contrib/sphinx-lint
Expand Down
8 changes: 6 additions & 2 deletions graphblas/core/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np

from .. import backend, binary, config, monoid
from ..dtypes import _INDEX, FP64, lookup_dtype, unify
from ..dtypes import _INDEX, FP64, _index_dtypes, lookup_dtype, unify
from ..exceptions import EmptyObject, check_status
from . import _has_numba, _supports_udfs, automethods, ffi, lib, utils
from .base import BaseExpression, BaseType, call
Expand Down Expand Up @@ -158,7 +158,11 @@ def __int__(self):
def __complex__(self):
return complex(self.value)

__index__ = __int__
@property
def __index__(self):
if self.dtype in _index_dtypes:
return self.__int__
raise AttributeError("Scalar object only has `__index__` for integral dtypes")

def __array__(self, dtype=None):
if dtype is None:
Expand Down
7 changes: 3 additions & 4 deletions graphblas/core/ss/config.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from collections.abc import MutableMapping
from numbers import Integral

from ...dtypes import lookup_dtype
from ...exceptions import _error_code_lookup, check_status
from .. import NULL, ffi, lib
from ..utils import values_to_numpy_buffer
from ..utils import maybe_integral, values_to_numpy_buffer


class BaseConfig(MutableMapping):
Expand Down Expand Up @@ -147,8 +146,8 @@ def __setitem__(self, key, val):
bitwise = self._bitwise[key]
if isinstance(val, str):
val = bitwise[val.lower()]
elif isinstance(val, Integral):
val = bitwise.get(val, val)
elif (x := maybe_integral(val)) is not None:
val = bitwise.get(x, x)
else:
bits = 0
for x in val:
Expand Down
36 changes: 23 additions & 13 deletions graphblas/core/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from numbers import Integral, Number
from operator import index

import numpy as np

Expand Down Expand Up @@ -158,6 +158,17 @@ def get_order(order):
)


def maybe_integral(val):
"""Ensure ``val`` is an integer or return None if it's not."""
try:
return index(val)
except TypeError:
pass
if isinstance(val, float) and val.is_integer():
return int(val)
return None


def normalize_chunks(chunks, shape):
"""Normalize chunks argument for use by ``Matrix.ss.split``.

Expand All @@ -175,8 +186,8 @@ def normalize_chunks(chunks, shape):
"""
if isinstance(chunks, (list, tuple)):
pass
elif isinstance(chunks, Number):
chunks = (chunks,) * len(shape)
elif (chunk := maybe_integral(chunks)) is not None:
chunks = (chunk,) * len(shape)
elif isinstance(chunks, np.ndarray):
chunks = chunks.tolist()
else:
Expand All @@ -192,22 +203,21 @@ def normalize_chunks(chunks, shape):
for size, chunk in zip(shape, chunks):
if chunk is None:
cur_chunks = [size]
elif isinstance(chunk, Integral) or isinstance(chunk, float) and chunk.is_integer():
chunk = int(chunk)
if chunk < 0:
raise ValueError(f"Chunksize must be greater than 0; got: {chunk}")
div, mod = divmod(size, chunk)
cur_chunks = [chunk] * div
elif (c := maybe_integral(chunk)) is not None:
if c < 0:
raise ValueError(f"Chunksize must be greater than 0; got: {c}")
div, mod = divmod(size, c)
cur_chunks = [c] * div
if mod:
cur_chunks.append(mod)
elif isinstance(chunk, (list, tuple)):
cur_chunks = []
none_index = None
for c in chunk:
if isinstance(c, Integral) or isinstance(c, float) and c.is_integer():
c = int(c)
if c < 0:
raise ValueError(f"Chunksize must be greater than 0; got: {c}")
if (val := maybe_integral(c)) is not None:
if val < 0:
raise ValueError(f"Chunksize must be greater than 0; got: {val}")
c = val
elif c is None:
if none_index is not None:
raise TypeError(
Expand Down
3 changes: 3 additions & 0 deletions graphblas/dtypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,6 @@ def __getattr__(key):
globals()["ss"] = ss
return ss
raise AttributeError(f"module {__name__!r} has no attribute {key!r}")


_index_dtypes = {BOOL, INT8, UINT8, INT16, UINT16, INT32, UINT32, INT64, UINT64, _INDEX}
2 changes: 2 additions & 0 deletions graphblas/tests/test_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ def test_casting(s):
assert float(s) == 5.0
assert type(float(s)) is float
assert range(s) == range(5)
with pytest.raises(AttributeError, match="Scalar .* only .*__index__.*integral"):
range(s.dup(float))
assert complex(s) == complex(5)
assert type(complex(s)) is complex

Expand Down
4 changes: 2 additions & 2 deletions scripts/check_versions.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
# Use, adjust, copy/paste, etc. as necessary to answer your questions.
# This may be helpful when updating dependency versions in CI.
# Tip: add `--json` for more information.
conda search 'numpy[channel=conda-forge]>=1.25.0'
conda search 'numpy[channel=conda-forge]>=1.25.1'
conda search 'pandas[channel=conda-forge]>=2.0.3'
conda search 'scipy[channel=conda-forge]>=1.11.1'
conda search 'networkx[channel=conda-forge]>=3.1'
conda search 'awkward[channel=conda-forge]>=2.3.0'
conda search 'awkward[channel=conda-forge]>=2.3.1'
conda search 'sparse[channel=conda-forge]>=0.14.0'
conda search 'fast_matrix_market[channel=conda-forge]>=1.7.2'
conda search 'numba[channel=conda-forge]>=0.57.1'
Expand Down