Skip to content

Commit ac4acc6

Browse files
authored
Fix as_tensor in onnx_text_plot_tree (#101)
* Fix as_tensor * fix issues * lint * fix clean * atol * fix issues
1 parent 96eb50e commit ac4acc6

File tree

9 files changed

+78
-86
lines changed

9 files changed

+78
-86
lines changed

CHANGELOGS.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
Change Logs
22
===========
33

4+
0.3.2
5+
+++++
6+
7+
* :pr:`101`: fix as_tensor in onnx_text_plot_tree
8+
49
0.3.1
510
+++++
611

_unittests/ut_light_api/test_backend_export.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@
1919
make_opsetid,
2020
make_tensor_value_info,
2121
)
22-
from onnx.reference.op_run import to_array_extended
22+
23+
try:
24+
from onnx.reference.op_run import to_array_extended
25+
except ImportError:
26+
from onnx.numpy_helper import to_array as to_array_extended
2327
from onnx.numpy_helper import from_array, to_array
2428
from onnx.backend.base import Device, DeviceType
2529
from onnx_array_api.reference import ExtendedReferenceEvaluator
@@ -240,7 +244,19 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
240244
raise NotImplementedError("Unable to run the model node by node.")
241245

242246

243-
backend_test = onnx.backend.test.BackendTest(ExportBackend, __name__)
247+
dft_atol = 1e-3 if sys.platform != "linux" else 1e-5
248+
backend_test = onnx.backend.test.BackendTest(
249+
ExportBackend,
250+
__name__,
251+
test_kwargs={
252+
"test_dft": {"atol": dft_atol},
253+
"test_dft_axis": {"atol": dft_atol},
254+
"test_dft_axis_opset19": {"atol": dft_atol},
255+
"test_dft_inverse": {"atol": dft_atol},
256+
"test_dft_inverse_opset19": {"atol": dft_atol},
257+
"test_dft_opset19": {"atol": dft_atol},
258+
},
259+
)
244260

245261
# The following tests are too slow with the reference implementation (Conv).
246262
backend_test.exclude(

_unittests/ut_reference/test_backend_extended_reference_evaluator.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import platform
3+
import sys
34
import unittest
45
from typing import Any
56
import numpy
@@ -78,10 +79,21 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
7879
raise NotImplementedError("Unable to run the model node by node.")
7980

8081

82+
dft_atol = 1e-3 if sys.platform != "linux" else 1e-5
8183
backend_test = onnx.backend.test.BackendTest(
82-
ExtendedReferenceEvaluatorBackend, __name__
84+
ExtendedReferenceEvaluatorBackend,
85+
__name__,
86+
test_kwargs={
87+
"test_dft": {"atol": dft_atol},
88+
"test_dft_axis": {"atol": dft_atol},
89+
"test_dft_axis_opset19": {"atol": dft_atol},
90+
"test_dft_inverse": {"atol": dft_atol},
91+
"test_dft_inverse_opset19": {"atol": dft_atol},
92+
"test_dft_opset19": {"atol": dft_atol},
93+
},
8394
)
8495

96+
8597
if os.getenv("APPVEYOR"):
8698
backend_test.exclude("(test_vgg19|test_zfnet)")
8799
if platform.architecture()[0] == "32bit":

azure-pipelines.yml

Lines changed: 0 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -93,63 +93,6 @@ jobs:
9393
python -m pytest
9494
displayName: 'Runs Unit Tests'
9595
96-
- job: 'TestLinuxArrayApi'
97-
pool:
98-
vmImage: 'ubuntu-latest'
99-
strategy:
100-
matrix:
101-
Python310-Linux:
102-
python.version: '3.10'
103-
maxParallel: 3
104-
105-
steps:
106-
- task: UsePythonVersion@0
107-
inputs:
108-
versionSpec: '$(python.version)'
109-
architecture: 'x64'
110-
- script: sudo apt-get update
111-
displayName: 'AptGet Update'
112-
- script: python -m pip install --upgrade pip setuptools wheel
113-
displayName: 'Install tools'
114-
- script: pip install -r requirements.txt
115-
displayName: 'Install Requirements'
116-
- script: pip install onnxruntime
117-
displayName: 'Install onnxruntime'
118-
- script: python setup.py install
119-
displayName: 'Install onnx_array_api'
120-
- script: |
121-
git clone https://github.com/data-apis/array-api-tests.git
122-
displayName: 'clone array-api-tests'
123-
- script: |
124-
cd array-api-tests
125-
git submodule update --init --recursive
126-
cd ..
127-
displayName: 'get submodules for array-api-tests'
128-
- script: pip install -r array-api-tests/requirements.txt
129-
displayName: 'Install Requirements dev'
130-
- script: |
131-
export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
132-
cd array-api-tests
133-
displayName: 'Set API'
134-
- script: |
135-
python -m pip freeze
136-
displayName: 'pip freeze'
137-
- script: |
138-
export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
139-
cd array-api-tests
140-
python -m pytest -x array_api_tests/test_creation_functions.py --skips-file=../_unittests/onnx-numpy-skips.txt --hypothesis-explain
141-
displayName: "numpy test_creation_functions.py"
142-
# - script: |
143-
# export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_ort
144-
# cd array-api-tests
145-
# python -m pytest -x array_api_tests/test_creation_functions.py --skips-file=../_unittests/onnx-ort-skips.txt --hypothesis-explain
146-
# displayName: "ort test_creation_functions.py"
147-
#- script: |
148-
# export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
149-
# cd array-api-tests
150-
# python -m pytest -x array_api_tests
151-
# displayName: "all tests"
152-
15396
- job: 'TestLinux'
15497
pool:
15598
vmImage: 'ubuntu-latest'

onnx_array_api/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
APIs to create ONNX Graphs.
33
"""
44

5-
__version__ = "0.3.1"
5+
__version__ = "0.3.2"
66
__author__ = "Xavier Dupré"

onnx_array_api/plotting/text_plot.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,10 @@ def __init__(self, i, atts):
6464
self.nodes_missing_value_tracks_true = None
6565
for k, v in atts.items():
6666
if k.startswith("nodes"):
67-
setattr(self, k, v[i])
67+
if k.endswith("_as_tensor"):
68+
setattr(self, k.replace("_as_tensor", ""), v[i])
69+
else:
70+
setattr(self, k, v[i])
6871
self.depth = 0
6972
self.true_false = ""
7073
self.targets = []
@@ -120,10 +123,7 @@ def process_tree(atts, treeid):
120123
]
121124
for k, v in atts.items():
122125
if k.startswith(prefix):
123-
if "classlabels" in k:
124-
short[k] = list(v)
125-
else:
126-
short[k] = [v[i] for i in idx]
126+
short[k] = list(v) if "classlabels" in k else [v[i] for i in idx]
127127

128128
nodes = OrderedDict()
129129
for i in range(len(short["nodes_treeids"])):
@@ -132,9 +132,10 @@ def process_tree(atts, treeid):
132132
for i in range(len(short[f"{prefix}_treeids"])):
133133
idn = short[f"{prefix}_nodeids"][i]
134134
node = nodes[idn]
135-
node.append_target(
136-
tid=short[f"{prefix}_ids"][i], weight=short[f"{prefix}_weights"][i]
137-
)
135+
key = f"{prefix}_weights"
136+
if key not in short:
137+
key = f"{prefix}_weights_as_tensor"
138+
node.append_target(tid=short[f"{prefix}_ids"][i], weight=short[key][i])
138139

139140
def iterate(nodes, node, depth=0, true_false=""):
140141
node.depth = depth

onnx_array_api/profiling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ def add_rows(rows, d):
438438
if verbose and fLOG is not None:
439439
fLOG(
440440
"[pstats] %s=%r"
441-
% ((clean_text(k[0].replace("\\", "/")),) + k[1:], v)
441+
% ((clean_text(k[0].replace("\\", "/")), *k[1:]), v)
442442
)
443443
if len(v) < 5:
444444
continue

onnx_array_api/reference/__init__.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,21 @@
22
import numpy as np
33
from onnx import TensorProto
44
from onnx.numpy_helper import from_array as onnx_from_array
5-
from onnx.reference.ops.op_cast import (
6-
bfloat16,
7-
float8e4m3fn,
8-
float8e4m3fnuz,
9-
float8e5m2,
10-
float8e5m2fnuz,
11-
)
12-
from onnx.reference.op_run import to_array_extended
5+
6+
try:
7+
from onnx.reference.ops.op_cast import (
8+
bfloat16,
9+
float8e4m3fn,
10+
float8e4m3fnuz,
11+
float8e5m2,
12+
float8e5m2fnuz,
13+
)
14+
except ImportError:
15+
bfloat16 = None
16+
try:
17+
from onnx.reference.op_run import to_array_extended
18+
except ImportError:
19+
from onnx.numpy_helper import to_array as to_array_extended
1320
from .evaluator import ExtendedReferenceEvaluator
1421
from .evaluator_yield import (
1522
DistanceExecution,
@@ -28,6 +35,8 @@ def from_array_extended(tensor: np.array, name: Optional[str] = None) -> TensorP
2835
:param name: name
2936
:return: TensorProto
3037
"""
38+
if bfloat16 is None:
39+
return onnx_from_array(tensor, name)
3140
dt = tensor.dtype
3241
if dt == float8e4m3fn and dt.descr[0][0] == "e4m3fn":
3342
to = TensorProto.FLOAT8E4M3FN

onnx_array_api/reference/ops/op_cast_like.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,23 @@
11
from onnx.helper import np_dtype_to_tensor_dtype
22
from onnx.onnx_pb import TensorProto
33
from onnx.reference.op_run import OpRun
4-
from onnx.reference.ops.op_cast import (
5-
bfloat16,
6-
cast_to,
7-
float8e4m3fn,
8-
float8e4m3fnuz,
9-
float8e5m2,
10-
float8e5m2fnuz,
11-
)
4+
from onnx.reference.ops.op_cast import cast_to
5+
6+
try:
7+
from onnx.reference.ops.op_cast import (
8+
bfloat16,
9+
float8e4m3fn,
10+
float8e4m3fnuz,
11+
float8e5m2,
12+
float8e5m2fnuz,
13+
)
14+
except ImportError:
15+
bfloat16 = None
1216

1317

1418
def _cast_like(x, y, saturate):
19+
if bfloat16 is None:
20+
return (cast_to(x, y.dtype, saturate),)
1521
if y.dtype == bfloat16 and y.dtype.descr[0][0] == "bfloat16":
1622
# np.uint16 == np.uint16 is True as well as np.uint16 == bfloat16
1723
to = TensorProto.BFLOAT16

0 commit comments

Comments
 (0)