Skip to content

Commit e29df50

Browse files
committed
Improves translation to GraphBuilder
1 parent 689cc6f commit e29df50

File tree

4 files changed

+37
-6
lines changed

4 files changed

+37
-6
lines changed

CHANGELOGS.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
Change Logs
22
===========
33

4+
0.3.1
5+
+++++
6+
7+
* :pr:`94`: improves translation to GraphBuilder
8+
49
0.3.0
510
+++++
611

onnx_array_api/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
APIs to create ONNX Graphs.
33
"""
44

5-
__version__ = "0.3.0"
5+
__version__ = "0.3.1"
66
__author__ = "Xavier Dupré"

onnx_array_api/translate_api/builder_emitter.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ class BuilderEmitter(BaseEmitter):
2020
Converts event into proper code.
2121
"""
2222

23+
def __init__(self, make_model_function: str = ""):
24+
super().__init__()
25+
self.make_model_function = make_model_function
26+
2327
def join(self, rows: List[str], single_line: bool = False) -> str:
2428
"Join the rows"
2529
assert (
@@ -29,6 +33,7 @@ def join(self, rows: List[str], single_line: bool = False) -> str:
2933

3034
def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]:
3135
self.opsets = kwargs.get("opsets", {})
36+
self.ir_version = kwargs.get("ir_version", None)
3237
return []
3338

3439
def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]:
@@ -43,12 +48,27 @@ def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]:
4348
)
4449
rows = [
4550
"",
46-
f"g = GraphBuilder({self.opsets})",
51+
(
52+
f"g = GraphBuilder({self.opsets}, ir_version={self.ir_version})"
53+
if self.ir_version
54+
else f"GraphBuilder({self.opsets})"
55+
),
4756
*inputs,
4857
f"{self.name}({inps})",
4958
*outputs,
5059
"model = g.to_onnx()",
5160
]
61+
if self.make_model_function:
62+
rows = [
63+
"",
64+
"",
65+
f'def {self.make_model_function}() -> "ModelProto":',
66+
*[" " + _ for _ in rows[1:]],
67+
" return model",
68+
"",
69+
"",
70+
f"model = {self.make_model_function}()",
71+
]
5272
return rows
5373

5474
def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
@@ -79,12 +99,14 @@ def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]:
7999
itype = kwargs.get("elem_type", 0)
80100
shape = kwargs.get("shape", None)
81101
if itype == 0:
82-
inp = "X"
102+
inp = name or "X"
83103
else:
84104
if shape is None:
85-
inp = f'X: "{_itype_to_string(itype)}"'
105+
inp = f'{name}: "{_itype_to_string(itype)}"'
86106
else:
87-
inp = f'X: "{_itype_to_string(itype)}[{", ".join(map(str, shape))}]"'
107+
inp = (
108+
f'{name}: "{_itype_to_string(itype)}[{", ".join(map(str, shape))}]"'
109+
)
88110
self.inputs_full.append(inp)
89111
self.inputs.append(name)
90112
self.inputs_full_.append((name, _itype_to_string(itype), shape))

onnx_array_api/translate_api/translate.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,11 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
3535
last_event = None
3636
if isinstance(self.proto_, ModelProto):
3737
opsets = {d.domain: d.version for d in self.proto_.opset_import}
38-
rows.extend(self.emitter(EventType.START, opsets=opsets))
38+
rows.extend(
39+
self.emitter(
40+
EventType.START, opsets=opsets, ir_version=self.proto_.ir_version
41+
)
42+
)
3943
inputs = self.proto_.graph.input
4044
outputs = self.proto_.graph.output
4145
nodes = self.proto_.graph.node

0 commit comments

Comments
 (0)