Skip to content

Commit ccf07e7

Browse files
committed
refactoring
1 parent 4b5934c commit ccf07e7

File tree

6 files changed

+28
-21
lines changed

6 files changed

+28
-21
lines changed

_doc/api/light_api.rst

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -78,12 +78,6 @@ BaseEmitter
7878
.. autoclass:: onnx_array_api.light_api.base_emitter.BaseEmitter
7979
:members:
8080

81-
Emitter
82-
+++++++
83-
84-
.. autoclass:: onnx_array_api.light_api.emitter.Emitter
85-
:members:
86-
8781
EventType
8882
+++++++++
8983

@@ -96,6 +90,12 @@ InnerEmitter
9690
.. autoclass:: onnx_array_api.light_api.inner_emitter.InnerEmitter
9791
:members:
9892

93+
LightEmitter
94+
++++++++++++
95+
96+
.. autoclass:: onnx_array_api.light_api.emitter.LightEmitter
97+
:members:
98+
9999
Translater
100100
++++++++++
101101

_unittests/ut_light_api/test_translate.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from onnx.reference import ReferenceEvaluator
77
from onnx_array_api.ext_test_case import ExtTestCase
88
from onnx_array_api.light_api import start, translate, g
9-
from onnx_array_api.light_api.emitter import EventType
9+
from onnx_array_api.light_api.base_emitter import EventType
1010

1111
OPSET_API = min(19, onnx_opset_version() - 1)
1212

@@ -220,4 +220,5 @@ def test_aionnxml(self):
220220

221221

222222
if __name__ == "__main__":
223+
TestTranslate().test_export_if()
223224
unittest.main(verbosity=2)

onnx_array_api/light_api/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def translate(proto: ModelProto, single_line: bool = False, api: str = "light")
6767
:param single_line: as a single line or not
6868
:param api: API to export into,
6969
default is `"light"` and this is handle by class
70-
:class:`onnx_array_api.light_api.emitter.Emitter`,
70+
:class:`onnx_array_api.light_api.light_emitter.LightEmitter`,
7171
another value is `"onnx"` which is the inner API implemented
7272
in onnx package.
7373
:return: code

onnx_array_api/light_api/inner_emitter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Any, Dict, List, Optional, Tuple
22
from onnx import AttributeProto
33
from .annotations import ELEMENT_TYPE_NAME
4-
from .emitter import BaseEmitter
4+
from .base_emitter import BaseEmitter
55
from .translate import Translater
66

77

onnx_array_api/light_api/emitter.py renamed to onnx_array_api/light_api/light_emitter.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from .base_emitter import BaseEmitter
44

55

6-
class Emitter(BaseEmitter):
6+
class LightEmitter(BaseEmitter):
77
"""
88
Converts event into proper code.
99
"""
@@ -29,6 +29,9 @@ def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]:
2929
def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]:
3030
return ["to_onnx()"]
3131

32+
def _emit_to_onnx_function(self, **kwargs: Dict[str, Any]) -> List[str]:
33+
return []
34+
3235
def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
3336
return []
3437

onnx_array_api/light_api/translate.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from onnx.numpy_helper import to_array
55
from ..reference import to_array_extended
66
from .base_emitter import EventType
7-
from .emitter import Emitter
7+
from .light_emitter import LightEmitter
88

99

1010
class Translater:
@@ -15,10 +15,10 @@ class Translater:
1515
def __init__(
1616
self,
1717
proto: Union[ModelProto, FunctionProto, GraphProto],
18-
emitter: Optional[Emitter] = None,
18+
emitter: Optional[LightEmitter] = None,
1919
):
2020
self.proto_ = proto
21-
self.emitter = emitter or Emitter()
21+
self.emitter = emitter or LightEmitter()
2222

2323
def __repr__(self) -> str:
2424
return f"{self.__class__.__name__}(<{type(self.proto_)})"
@@ -43,6 +43,7 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
4343
sparse_initializers = self.proto_.graph.sparse_initializer
4444
attributes = []
4545
last_event = EventType.TO_ONNX_MODEL
46+
is_function = False
4647
elif isinstance(self.proto_, (FunctionProto, GraphProto)):
4748
inputs = self.proto_.input
4849
outputs = self.proto_.output
@@ -56,14 +57,17 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
5657
attributes = (
5758
self.proto_.attribute if hasattr(self.proto_, "attribute") else []
5859
)
59-
last_event = EventType.TO_ONNX_FUNCTION
60+
is_function = isinstance(self.proto_, FunctionProto)
61+
last_event = (
62+
EventType.TO_ONNX_FUNCTION if is_function else EventType.TO_ONNX_MODEL
63+
)
6064
else:
6165
raise ValueError(f"Unexpected type {type(self.proto_)} for proto.")
6266

6367
if sparse_initializers:
6468
raise NotImplementedError("Sparse initializer not supported yet.")
6569

66-
if isinstance(self.proto_, FunctionProto):
70+
if is_function:
6771
rows.extend(
6872
self.emitter(
6973
EventType.BEGIN_FUNCTION,
@@ -85,7 +89,7 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
8589
)
8690

8791
for i in inputs:
88-
if isinstance(i, str):
92+
if is_function:
8993
rows.extend(self.emitter(EventType.FUNCTION_INPUT, name=i))
9094
else:
9195
rows.extend(
@@ -100,7 +104,7 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
100104
)
101105
)
102106

103-
if attributes:
107+
if is_function and attributes:
104108
rows.extend(
105109
self.emitter(EventType.FUNCTION_ATTRIBUTES, attributes=list(attributes))
106110
)
@@ -119,7 +123,7 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
119123
)
120124

121125
for o in outputs:
122-
if isinstance(o, str):
126+
if is_function:
123127
rows.extend(self.emitter(EventType.FUNCTION_OUTPUT, name=o))
124128
else:
125129
rows.extend(
@@ -137,11 +141,10 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
137141
name = self.proto_.name
138142
else:
139143
name = self.proto_.graph.name
144+
140145
rows.extend(
141146
self.emitter(
142-
EventType.END_FUNCTION
143-
if isinstance(self.proto_, FunctionProto)
144-
else EventType.END_GRAPH,
147+
EventType.END_FUNCTION if is_function else EventType.END_GRAPH,
145148
name=name,
146149
)
147150
)

0 commit comments

Comments
 (0)