@@ -20,6 +20,10 @@ class BuilderEmitter(BaseEmitter):
20
20
Converts event into proper code.
21
21
"""
22
22
23
+ def __init__ (self , make_model_function : str = "" ):
24
+ super ().__init__ ()
25
+ self .make_model_function = make_model_function
26
+
23
27
def join (self , rows : List [str ], single_line : bool = False ) -> str :
24
28
"Join the rows"
25
29
assert (
@@ -29,6 +33,7 @@ def join(self, rows: List[str], single_line: bool = False) -> str:
29
33
30
34
def _emit_start (self , ** kwargs : Dict [str , Any ]) -> List [str ]:
31
35
self .opsets = kwargs .get ("opsets" , {})
36
+ self .ir_version = kwargs .get ("ir_version" , None )
32
37
return []
33
38
34
39
def _emit_to_onnx_model (self , ** kwargs : Dict [str , Any ]) -> List [str ]:
@@ -43,12 +48,27 @@ def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]:
43
48
)
44
49
rows = [
45
50
"" ,
46
- f"g = GraphBuilder({ self .opsets } )" ,
51
+ (
52
+ f"g = GraphBuilder({ self .opsets } , ir_version={ self .ir_version } )"
53
+ if self .ir_version
54
+ else f"GraphBuilder({ self .opsets } )"
55
+ ),
47
56
* inputs ,
48
57
f"{ self .name } ({ inps } )" ,
49
58
* outputs ,
50
59
"model = g.to_onnx()" ,
51
60
]
61
+ if self .make_model_function :
62
+ rows = [
63
+ "" ,
64
+ "" ,
65
+ f'def { self .make_model_function } () -> "ModelProto":' ,
66
+ * [" " + _ for _ in rows [1 :]],
67
+ " return model" ,
68
+ "" ,
69
+ "" ,
70
+ f"model = { self .make_model_function } ()" ,
71
+ ]
52
72
return rows
53
73
54
74
def _emit_begin_graph (self , ** kwargs : Dict [str , Any ]) -> List [str ]:
@@ -79,12 +99,14 @@ def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]:
79
99
itype = kwargs .get ("elem_type" , 0 )
80
100
shape = kwargs .get ("shape" , None )
81
101
if itype == 0 :
82
- inp = "X"
102
+ inp = name or "X"
83
103
else :
84
104
if shape is None :
85
- inp = f'X : "{ _itype_to_string (itype )} "'
105
+ inp = f'{ name } : "{ _itype_to_string (itype )} "'
86
106
else :
87
- inp = f'X: "{ _itype_to_string (itype )} [{ ", " .join (map (str , shape ))} ]"'
107
+ inp = (
108
+ f'{ name } : "{ _itype_to_string (itype )} [{ ", " .join (map (str , shape ))} ]"'
109
+ )
88
110
self .inputs_full .append (inp )
89
111
self .inputs .append (name )
90
112
self .inputs_full_ .append ((name , _itype_to_string (itype ), shape ))
0 commit comments