Skip to content

Commit bf4dba0

Browse files
committed
finalize other domain epxressions
1 parent 1c14009 commit bf4dba0

File tree

5 files changed

+147
-16
lines changed

5 files changed

+147
-16
lines changed

_doc/api/light_api.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@ OnnxGraph
3434
BaseVar
3535
+++++++
3636

37+
.. autoclass:: onnx_array_api.light_api.var.BaseVar
38+
:members:
39+
40+
SubDomain
41+
+++++++++
42+
3743
.. autoclass:: onnx_array_api.light_api.var.BaseVar
3844
:members:
3945

_unittests/ut_light_api/test_light_api.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import inspect
12
import unittest
23
from typing import Callable, Optional
34
import numpy as np
@@ -12,6 +13,7 @@
1213
from onnx.reference import ReferenceEvaluator
1314
from onnx_array_api.ext_test_case import ExtTestCase, skipif_ci_windows
1415
from onnx_array_api.light_api import start, OnnxGraph, Var, g
16+
from onnx_array_api.light_api.var import SubDomain
1517
from onnx_array_api.light_api._op_var import OpsVar
1618
from onnx_array_api.light_api._op_vars import OpsVars
1719

@@ -473,21 +475,40 @@ def test_if(self):
473475
self.assertEqualArray(np.array([0], dtype=np.int64), got[0])
474476

475477
def test_domain(self):
476-
onx = (
477-
start()
478-
.vin("X")
479-
.reshape((-1, 1))
480-
.ai.onnx.ml.Normalizer(norm="L1")
481-
.rename("Y")
482-
.vout()
483-
.to_onnx()
484-
)
478+
onx = start(opsets={"ai.onnx.ml": 3}).vin("X").reshape((-1, 1)).rename("USE")
479+
480+
class A:
481+
def g(self):
482+
return True
483+
484+
def ah(self):
485+
return True
486+
487+
setattr(A, "h", ah)
488+
489+
self.assertTrue(A().h())
490+
self.assertIn("(self)", str(inspect.signature(A.h)))
491+
self.assertTrue(issubclass(onx._ai, SubDomain))
492+
self.assertIsInstance(onx.ai, SubDomain)
493+
self.assertIsInstance(onx.ai.parent, Var)
494+
self.assertTrue(issubclass(onx._ai._onnx, SubDomain))
495+
self.assertIsInstance(onx.ai.onnx, SubDomain)
496+
self.assertIsInstance(onx.ai.onnx.parent, Var)
497+
self.assertTrue(issubclass(onx._ai._onnx._ml, SubDomain))
498+
self.assertIsInstance(onx.ai.onnx.ml, SubDomain)
499+
self.assertIsInstance(onx.ai.onnx.ml.parent, Var)
500+
self.assertIn("(self,", str(inspect.signature(onx._ai._onnx._ml.Normalizer)))
501+
onx = onx.ai.onnx.ml.Normalizer(norm="MAX")
502+
onx = onx.rename("Y").vout().to_onnx()
485503
self.assertIsInstance(onx, ModelProto)
486-
self.assertIn("Transpose", str(onx))
504+
self.assertIn("Normalizer", str(onx))
505+
self.assertIn('domain: "ai.onnx.ml"', str(onx))
506+
self.assertIn('input: "USE"', str(onx))
487507
ref = ReferenceEvaluator(onx)
488508
a = np.arange(10).astype(np.float32)
489509
got = ref.run(None, {"X": a})[0]
490-
self.assertEqualArray(a.reshape((-1, 1)).T, got)
510+
expected = (a > 0).astype(int).astype(np.float32).reshape((-1, 1))
511+
self.assertEqualArray(expected, got)
491512

492513

493514
if __name__ == "__main__":

onnx_array_api/light_api/annotations.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,18 @@ def domain(domain: str, op_type: Optional[str] = None) -> Callable:
2626
"""
2727
Registers one operator into a sub domain.
2828
"""
29-
pieces = domain.split(".")
30-
sub = pieces[0]
29+
names = [op_type]
3130

3231
def decorate(op_method: Callable) -> Callable:
32+
if names[0] is None:
33+
names[0] = op_method.__name__
34+
3335
def wrapper(self, *args: List[Any], **kwargs: Dict[str, Any]) -> Any:
34-
if not self.hasattr(sub):
35-
raise RuntimeError(f"Class has not registered subdomain {sub!r}.")
36-
return op_method(self, *args, **kwargs)
36+
return op_method(self.parent, *args, **kwargs)
3737

38+
wrapper.__qual__name__ = f"[{domain}]{names[0]}"
39+
wrapper.__name__ = f"[{domain}]{names[0]}"
40+
wrapper.__domain__ = domain
3841
return wrapper
3942

4043
return decorate

onnx_array_api/light_api/model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,9 @@ def make_node(
248248

249249
node = make_node(op_type, input_names, output_names, domain=domain, **kwargs)
250250
self.nodes.append(node)
251+
if domain != "":
252+
if not self.opsets or domain not in self.opsets:
253+
raise RuntimeError(f"No opset value was given for domain {domain!r}.")
251254
return node
252255

253256
def cst(self, value: np.ndarray, name: Optional[str] = None) -> "Var":

onnx_array_api/light_api/var.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import inspect
12
from typing import Any, Dict, List, Optional, Tuple, Union
23
import numpy as np
34
from onnx import TensorProto
@@ -16,6 +17,26 @@
1617
from ._op_vars import OpsVars
1718

1819

20+
class SubDomain:
21+
"""
22+
Declares a domain or a piece of it (if it contains '.' in its name).
23+
"""
24+
25+
def __init__(self, var: "BaseVar"):
26+
if not isinstance(var, BaseVar):
27+
raise TypeError(f"Unexpected type {type(var)}.")
28+
self.parent = var
29+
30+
31+
def _getclassattr_(self, name):
32+
if not hasattr(self.__class__, name):
33+
raise TypeError(
34+
f"Unable to find {name!r} in class {self.__class__.__name__!r}, "
35+
f"available {dir(self.__class__)}."
36+
)
37+
return getattr(self.__class__, name)
38+
39+
1940
class BaseVar:
2041
"""
2142
Represents an input, an initializer, a node, an output,
@@ -24,6 +45,83 @@ class BaseVar:
2445
:param parent: the graph containing the Variable
2546
"""
2647

48+
def __new__(cls, *args, **kwargs):
49+
res = super().__new__(cls)
50+
res.__init__(*args, **kwargs)
51+
if getattr(cls, "__incomplete", True):
52+
for k in dir(cls):
53+
att = getattr(cls, k, None)
54+
if not att:
55+
continue
56+
name = getattr(att, "__name__", None)
57+
if not name or name[0] != "[":
58+
continue
59+
60+
# A function with a domain name
61+
if not inspect.isfunction(att):
62+
raise RuntimeError(f"{cls.__name__}.{k} is not a function.")
63+
domain, op_type = name[1:].split("]")
64+
if "." in domain:
65+
spl = domain.split(".", maxsplit=1)
66+
dname = f"_{spl[0]}"
67+
if not hasattr(cls, dname):
68+
d = type(
69+
f"{cls.__name__}{dname}", (SubDomain,), {"name": dname[1:]}
70+
)
71+
setattr(cls, dname, d)
72+
setattr(
73+
cls,
74+
spl[0],
75+
property(
76+
lambda self, _name_=dname: _getclassattr_(self, _name_)(
77+
self
78+
)
79+
),
80+
)
81+
else:
82+
d = getattr(cls, dname)
83+
suffix = spl[0]
84+
for p in spl[1].split("."):
85+
dname = f"_{p}"
86+
suffix += dname
87+
if not hasattr(d, dname):
88+
sd = type(
89+
f"{cls.__name__}_{suffix}",
90+
(SubDomain,),
91+
{"name": suffix},
92+
)
93+
setattr(d, dname, sd)
94+
setattr(
95+
d,
96+
p,
97+
property(
98+
lambda self, _name_=dname: _getclassattr_(
99+
self, _name_
100+
)(self.parent)
101+
),
102+
)
103+
d = sd
104+
else:
105+
d = getattr(d, dname)
106+
elif not hasattr(cls, domain):
107+
dname = f"_{domain}"
108+
d = type(f"{cls.__name__}{dname}", (SubDomain,), {"name": domain})
109+
setattr(cls, dname, d)
110+
setattr(
111+
cls,
112+
domain,
113+
property(
114+
lambda self, _name_=dname: _getclassattr_(self, _name_)(
115+
self
116+
)
117+
),
118+
)
119+
120+
setattr(d, op_type, att)
121+
setattr(cls, "__incomplete", False)
122+
123+
return res
124+
27125
def __init__(
28126
self,
29127
parent: OnnxGraph,

0 commit comments

Comments
 (0)