Skip to content

Commit 8f6f9d3

Browse files
committed
Add support for matrix multiplication. Fixes ionelmc#66.
1 parent 30e8c5a commit 8f6f9d3

File tree

5 files changed

+89
-2
lines changed

5 files changed

+89
-2
lines changed

src/lazy_object_proxy/cext.c

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,16 @@ static PyObject *Proxy_multiply(PyObject *o1, PyObject *o2)
248248

249249
/* ------------------------------------------------------------------------- */
250250

251+
static PyObject *Proxy_matrix_multiply(PyObject *o1, PyObject *o2)
252+
{
253+
Proxy__WRAPPED_REPLACE_OR_RETURN_NULL(o1);
254+
Proxy__WRAPPED_REPLACE_OR_RETURN_NULL(o2);
255+
256+
return PyNumber_MatrixMultiply(o1, o2);
257+
}
258+
259+
/* ------------------------------------------------------------------------- */
260+
251261
static PyObject *Proxy_remainder(PyObject *o1, PyObject *o2)
252262
{
253263
Proxy__WRAPPED_REPLACE_OR_RETURN_NULL(o1);
@@ -458,6 +468,28 @@ static PyObject *Proxy_inplace_multiply(
458468

459469
/* ------------------------------------------------------------------------- */
460470

471+
static PyObject *Proxy_inplace_matrix_multiply(
472+
ProxyObject *self, PyObject *other)
473+
{
474+
PyObject *object = NULL;
475+
476+
Proxy__ENSURE_WRAPPED_OR_RETURN_NULL(self);
477+
Proxy__WRAPPED_REPLACE_OR_RETURN_NULL(other);
478+
479+
object = PyNumber_InPlaceMatrixMultiply(self->wrapped, other);
480+
481+
if (!object)
482+
return NULL;
483+
484+
Py_DECREF(self->wrapped);
485+
self->wrapped = object;
486+
487+
Py_INCREF(self);
488+
return (PyObject *)self;
489+
}
490+
491+
/* ------------------------------------------------------------------------- */
492+
461493
static PyObject *Proxy_inplace_remainder(
462494
ProxyObject *self, PyObject *other)
463495
{
@@ -1239,6 +1271,8 @@ static PyNumberMethods Proxy_as_number = {
12391271
(binaryfunc)Proxy_inplace_floor_divide, /*nb_inplace_floor_divide*/
12401272
(binaryfunc)Proxy_inplace_true_divide, /*nb_inplace_true_divide*/
12411273
(unaryfunc)Proxy_index, /*nb_index*/
1274+
(binaryfunc)Proxy_matrix_multiply, /*nb_matrix_multiply*/
1275+
(binaryfunc)Proxy_inplace_matrix_multiply, /*nb_inplace_matrix_multiply*/
12421276
};
12431277

12441278
static PySequenceMethods Proxy_as_sequence = {

src/lazy_object_proxy/simple.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ def __delattr__(self, name):
141141
__add__ = make_proxy_method(operator.add)
142142
__sub__ = make_proxy_method(operator.sub)
143143
__mul__ = make_proxy_method(operator.mul)
144+
__matmul__ = make_proxy_method(operator.matmul)
144145
__truediv__ = make_proxy_method(operator.truediv)
145146
__floordiv__ = make_proxy_method(operator.floordiv)
146147
__mod__ = make_proxy_method(operator.mod)
@@ -161,6 +162,9 @@ def __rsub__(self, other):
161162
def __rmul__(self, other):
162163
return other * self.__wrapped__
163164

165+
def __rmatmul__(self, other):
166+
return other @ self.__wrapped__
167+
164168
def __rdiv__(self, other):
165169
return operator.div(other, self.__wrapped__)
166170

@@ -197,6 +201,7 @@ def __ror__(self, other):
197201
__iadd__ = make_proxy_method(operator.iadd)
198202
__isub__ = make_proxy_method(operator.isub)
199203
__imul__ = make_proxy_method(operator.imul)
204+
__imatmul__ = make_proxy_method(operator.imatmul)
200205
__itruediv__ = make_proxy_method(operator.itruediv)
201206
__ifloordiv__ = make_proxy_method(operator.ifloordiv)
202207
__imod__ = make_proxy_method(operator.imod)

src/lazy_object_proxy/slots.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,9 @@ def __sub__(self, other):
226226
def __mul__(self, other):
227227
return self.__wrapped__ * other
228228

229+
def __matmul__(self, other):
230+
return self.__wrapped__ @ other
231+
229232
def __truediv__(self, other):
230233
return operator.truediv(self.__wrapped__, other)
231234

@@ -265,6 +268,9 @@ def __rsub__(self, other):
265268
def __rmul__(self, other):
266269
return other * self.__wrapped__
267270

271+
def __rmatmul__(self, other):
272+
return other @ self.__wrapped__
273+
268274
def __rdiv__(self, other):
269275
return operator.div(other, self.__wrapped__)
270276

@@ -310,8 +316,8 @@ def __imul__(self, other):
310316
self.__wrapped__ *= other
311317
return self
312318

313-
def __idiv__(self, other):
314-
self.__wrapped__ = operator.idiv(self.__wrapped__, other)
319+
def __imatmul__(self, other):
320+
self.__wrapped__ @= other
315321
return self
316322

317323
def __itruediv__(self, other):

tests/test_lazy_object_proxy.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -912,6 +912,26 @@ def test_mul(lop):
912912
assert two * 3 == 2 * 3
913913

914914

915+
def test_matmul(lop):
916+
import numpy
917+
918+
one = numpy.array((1, 2, 3))
919+
two = numpy.array((2, 3, 4))
920+
assert one @ two == 20
921+
922+
one = lop.Proxy(lambda: numpy.array((1, 2, 3)))
923+
two = lop.Proxy(lambda: numpy.array((2, 3, 4)))
924+
assert one @ two == 20
925+
926+
one = lop.Proxy(lambda: numpy.array((1, 2, 3)))
927+
two = numpy.array((2, 3, 4))
928+
assert one @ two == 20
929+
930+
one = numpy.array((1, 2, 3))
931+
two = lop.Proxy(lambda: numpy.array((2, 3, 4)))
932+
assert one @ two == 20
933+
934+
915935
def test_div(lop):
916936
# On Python 2 this will pick up div and on Python
917937
# 3 it will pick up truediv.
@@ -1067,6 +1087,27 @@ def test_imul(lop):
10671087
assert type(value) == lop.Proxy
10681088

10691089

1090+
def test_imatmul(lop):
1091+
class InplaceMatmul:
1092+
value = None
1093+
1094+
def __imatmul__(self, other):
1095+
self.value = other
1096+
return self
1097+
1098+
value = InplaceMatmul()
1099+
assert value.value is None
1100+
value @= 123
1101+
assert value.value == 123
1102+
1103+
value = lop.Proxy(InplaceMatmul)
1104+
value @= 234
1105+
assert value.value == 234
1106+
1107+
if lop.kind != 'simple':
1108+
assert type(value) == lop.Proxy
1109+
1110+
10701111
def test_idiv(lop):
10711112
# On Python 2 this will pick up div and on Python
10721113
# 3 it will pick up truediv.

tox.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ deps =
4242
pytest
4343
pytest-benchmark
4444
Django
45+
numpy
4546
objproxies==0.9.4
4647
hunter
4748
cover: pytest-cov

0 commit comments

Comments
 (0)