Skip to content

Commit 6532733

Browse files
committed
add method check_order
1 parent b73a0cb commit 6532733

File tree

2 files changed

+54
-5
lines changed

2 files changed

+54
-5
lines changed

_unittests/ut_graph_api/test_graph_builder_optim.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import unittest
33
import onnx
4+
from onnx.inliner import inline_local_functions
45
from onnx_array_api.ext_test_case import ExtTestCase
56
from onnx_array_api.graph_api.graph_builder import GraphBuilder
67

@@ -54,7 +55,7 @@ def test_keep_unused_outputs(self):
5455
self.assertEqual(len(onx.graph.node), 2)
5556
self.assertEqual(onx.graph.node[0].op_type, "Split")
5657

57-
def test_check_files(self):
58+
def test_check_afiles(self):
5859
import onnxruntime
5960

6061
data = os.path.join(os.path.dirname(__file__), "data")
@@ -66,8 +67,14 @@ def test_check_files(self):
6667
os.path.join(data, f), providers=["CPUExecutionProvider"]
6768
)
6869
assert sess
69-
g = GraphBuilder(onx)
70-
g.optimize()
70+
onxi = inline_local_functions(onx)
71+
sess = onnxruntime.InferenceSession(
72+
onxi.SerializeToString(), providers=["CPUExecutionProvider"]
73+
)
74+
assert sess
75+
g = GraphBuilder(onxi)
76+
g.optimize(check_order=True)
77+
g.check_order()
7178
onx2 = g.to_onnx()
7279
sess2 = onnxruntime.InferenceSession(
7380
onx2.SerializeToString(), providers=["CPUExecutionProvider"]

onnx_array_api/graph_api/graph_builder.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from functools import partial
2-
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
2+
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
33
import numpy as np
44
import onnx.helper as oh
55
import onnx.numpy_helper as onh
@@ -604,14 +604,56 @@ def to_onnx(
604604
model = oh.make_model(graph, opset_imports=opsets)
605605
return model
606606

607-
def optimize(self):
607+
def _check_order_node(self, ind: int, node: NodeProto, existing: Set[str]):
608+
for i in node.input:
609+
if i not in existing:
610+
raise RuntimeError(
611+
f"Unknown input {i!r} from node {ind}:{node.op_type}:{node.name}. "
612+
f"Known: {existing}."
613+
)
614+
for att in node.attribute:
615+
if att.type == AttributeProto.GRAPH and att.g:
616+
g_existing = existing.copy()
617+
for i in att.g.input:
618+
g_existing.add(i.name)
619+
for ind2, node2 in enumerate(att.g.node):
620+
self._check_order_node((ind, ind2), node2, g_existing)
621+
for o in att.g.output:
622+
if o.name not in g_existing:
623+
raise RuntimeError(
624+
f"Unknown output {o.name!r}. Known: {g_existing}."
625+
)
626+
for o in node.output:
627+
existing.add(o)
628+
629+
def check_order(self):
630+
existing = set(self.initializers_dict)
631+
for i in self.inputs:
632+
existing.add(i.name)
633+
for ind, node in enumerate(self.nodes):
634+
self._check_order_node(ind, node, existing)
635+
for o in self.outputs:
636+
if o.name not in existing:
637+
raise RuntimeError(f"Unknown output {o.name!r}. Known: {existing}.")
638+
639+
def optimize(self, check_order: bool = False):
640+
if check_order:
641+
self.check_order()
608642
self.remove_identity_nodes()
643+
if check_order:
644+
self.check_order()
609645
if self.optimization_options.remove_unused:
610646
self.remove_unused()
647+
if check_order:
648+
self.check_order()
611649
if self.optimization_options.constant_folding:
612650
self.constant_folding()
651+
if check_order:
652+
self.check_order()
613653
if self.optimization_options.remove_unused:
614654
self.remove_unused()
655+
if check_order:
656+
self.check_order()
615657

616658
def remove_unused(self):
617659
"""

0 commit comments

Comments
 (0)