|
1 | 1 | from functools import partial
|
2 |
| -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union |
| 2 | +from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union |
3 | 3 | import numpy as np
|
4 | 4 | import onnx.helper as oh
|
5 | 5 | import onnx.numpy_helper as onh
|
@@ -604,14 +604,56 @@ def to_onnx(
|
604 | 604 | model = oh.make_model(graph, opset_imports=opsets)
|
605 | 605 | return model
|
606 | 606 |
|
607 |
| - def optimize(self): |
| 607 | + def _check_order_node(self, ind: int, node: NodeProto, existing: Set[str]): |
| 608 | + for i in node.input: |
| 609 | + if i not in existing: |
| 610 | + raise RuntimeError( |
| 611 | + f"Unknown input {i!r} from node {ind}:{node.op_type}:{node.name}. " |
| 612 | + f"Known: {existing}." |
| 613 | + ) |
| 614 | + for att in node.attribute: |
| 615 | + if att.type == AttributeProto.GRAPH and att.g: |
| 616 | + g_existing = existing.copy() |
| 617 | + for i in att.g.input: |
| 618 | + g_existing.add(i.name) |
| 619 | + for ind2, node2 in enumerate(att.g.node): |
| 620 | + self._check_order_node((ind, ind2), node2, g_existing) |
| 621 | + for o in att.g.output: |
| 622 | + if o.name not in g_existing: |
| 623 | + raise RuntimeError( |
| 624 | + f"Unknown output {o.name!r}. Known: {g_existing}." |
| 625 | + ) |
| 626 | + for o in node.output: |
| 627 | + existing.add(o) |
| 628 | + |
| 629 | + def check_order(self): |
| 630 | + existing = set(self.initializers_dict) |
| 631 | + for i in self.inputs: |
| 632 | + existing.add(i.name) |
| 633 | + for ind, node in enumerate(self.nodes): |
| 634 | + self._check_order_node(ind, node, existing) |
| 635 | + for o in self.outputs: |
| 636 | + if o.name not in existing: |
| 637 | + raise RuntimeError(f"Unknown output {o.name!r}. Known: {existing}.") |
| 638 | + |
| 639 | + def optimize(self, check_order: bool = False): |
| 640 | + if check_order: |
| 641 | + self.check_order() |
608 | 642 | self.remove_identity_nodes()
|
| 643 | + if check_order: |
| 644 | + self.check_order() |
609 | 645 | if self.optimization_options.remove_unused:
|
610 | 646 | self.remove_unused()
|
| 647 | + if check_order: |
| 648 | + self.check_order() |
611 | 649 | if self.optimization_options.constant_folding:
|
612 | 650 | self.constant_folding()
|
| 651 | + if check_order: |
| 652 | + self.check_order() |
613 | 653 | if self.optimization_options.remove_unused:
|
614 | 654 | self.remove_unused()
|
| 655 | + if check_order: |
| 656 | + self.check_order() |
615 | 657 |
|
616 | 658 | def remove_unused(self):
|
617 | 659 | """
|
|
0 commit comments