Skip to content

Commit 2fde01f

Browse files
authored
Fix float 8 conversion (#36)
1 parent cd73f71 commit 2fde01f

File tree

6 files changed

+51
-27
lines changed

6 files changed

+51
-27
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ _doc/examples/data/*.optimized.onnx
2222
_doc/examples/*.html
2323
_doc/_static/require.js
2424
_doc/_static/viz.js
25+
_doc/LICENSE.txt
26+
_doc/CHANGELOGS.rst
2527
_unittests/ut__main/*.png
2628
_unittests/ut__main/_cache/*
2729
_unittests/ut__main/*.html

_doc/api/f8.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
Float 8
2+
=======
3+
4+
.. automodule:: onnx_array_api.validation.f8
5+
:members:

_doc/api/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,4 @@ API
2020
reference
2121
tools
2222
profiling
23+
f8

_doc/conf.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,7 @@
122122
"onnxruntime": "https://onnxruntime.ai/",
123123
"numpy": "https://numpy.org/",
124124
"numba": "https://numba.pydata.org/",
125-
"onnx-array-api": (
126-
"http://www.xavierdupre.fr/app/onnx-array-api/helpsphinx/index.html"
127-
),
125+
"onnx-array-api": ("https://sdpython.github.io/doc/onnx-array-api/dev/"),
128126
"pyinstrument": "https://github.com/joerick/pyinstrument",
129127
"python": "https://www.python.org/",
130128
"scikit-learn": "https://scikit-learn.org/stable/",

_unittests/ut_validation/test_f8.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1151,7 +1151,16 @@ def test_float8_e5m2fnuz_negative_nan(self):
11511151
back = fe4m3_to_float32(to, fn=True, uz=True)
11521152
self.assertTrue(numpy.isnan(back))
11531153

1154+
def test_fe4m3fn_to_float32_bug(self):
1155+
cases = [(1.8131605, 1.875)]
1156+
for val, expected in cases:
1157+
with self.subTest(value=val, expected=expected):
1158+
res = fe4m3_to_float32(search_float32_into_fe4m3(val))
1159+
self.assertEqual(expected, res)
1160+
res = fe4m3_to_float32(float32_to_fe4m3(val))
1161+
self.assertEqual(expected, res)
1162+
11541163

11551164
if __name__ == "__main__":
1156-
TestF8().test_search_float32_into_fe4m3fn_simple()
1165+
TestF8().test_fe4m3fn_to_float32_bug()
11571166
unittest.main(verbosity=2)

onnx_array_api/validation/f8.py

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,17 @@ class UndefinedCastError(FloatingPointError):
1212
pass
1313

1414

15-
def display_float32(value, sign=1, exponent=8, mantissa=23):
15+
def display_int(ival, sign=1, exponent=8, mantissa=23):
1616
"""
17-
Displays a float32 into b.
17+
Displays an integer as bits.
1818
19-
:param value: value to display (float32)
19+
:param ival: value to display (float32)
2020
:param sign: number of bits for the sign
2121
:param exponent: number of bits for the exponent
2222
:param mantissa: number of bits for the mantissa
2323
:return: string
2424
"""
2525
t = sign + exponent + mantissa
26-
ival = int.from_bytes(struct.pack("<f", numpy.float32(value)), "little")
2726
s = bin(ival)[2:]
2827
s = "0" * (t - len(s)) + s
2928
s1 = s[:sign]
@@ -32,6 +31,24 @@ def display_float32(value, sign=1, exponent=8, mantissa=23):
3231
return ".".join([s1, s2, s3])
3332

3433

34+
def display_float32(value, sign=1, exponent=8, mantissa=23):
35+
"""
36+
Displays a float32 into b.
37+
38+
:param value: value to display (float32)
39+
:param sign: number of bits for the sign
40+
:param exponent: number of bits for the exponent
41+
:param mantissa: number of bits for the mantissa
42+
:return: string
43+
"""
44+
return display_int(
45+
int.from_bytes(struct.pack("<f", numpy.float32(value)), "little"),
46+
sign=sign,
47+
exponent=exponent,
48+
mantissa=mantissa,
49+
)
50+
51+
3552
def display_float16(value, sign=1, exponent=5, mantissa=10):
3653
"""
3754
Displays a float32 into b.
@@ -42,14 +59,9 @@ def display_float16(value, sign=1, exponent=5, mantissa=10):
4259
:param mantissa: number of bits for the mantissa
4360
:return: string
4461
"""
45-
t = sign + exponent + mantissa
46-
ival = numpy.float16(value).view("H") # pylint: disable=E1121
47-
s = bin(ival)[2:]
48-
s = "0" * (t - len(s)) + s
49-
s1 = s[:sign]
50-
s2 = s[sign : sign + exponent]
51-
s3 = s[sign + exponent :]
52-
return ".".join([s1, s2, s3])
62+
return display_int(
63+
numpy.float16(value).view("H"), sign=sign, exponent=exponent, mantissa=mantissa
64+
)
5365

5466

5567
def display_fexmx(value, sign, exponent, mantissa):
@@ -64,14 +76,7 @@ def display_fexmx(value, sign, exponent, mantissa):
6476
:param mantissa: number of bits for the mantissa
6577
:return: string
6678
"""
67-
t = sign + exponent + mantissa
68-
ival = value
69-
s = bin(ival)[2:]
70-
s = "0" * (t - len(s)) + s
71-
s1 = s[:sign]
72-
s2 = s[sign : sign + exponent]
73-
s3 = s[sign + exponent :]
74-
return ".".join([s1, s2, s3])
79+
return display_int(value, sign=sign, exponent=exponent, mantissa=mantissa)
7580

7681

7782
def display_fe4m3(value, sign=1, exponent=4, mantissa=3):
@@ -534,7 +539,9 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True
534539
else:
535540
ret |= ex << 3
536541
ret |= m >> 20
537-
if m & 0x80000:
542+
if (m & 0x80000) and (
543+
(m & 0x100000) or (m & 0x7FFFF)
544+
): # round to nearest even
538545
if (ret & 0x7F) < 0x7F:
539546
# rounding
540547
ret += 1
@@ -584,7 +591,7 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True
584591
if (ret & 0x7F) == 0x7F:
585592
ret &= 0xFE
586593
if (m & 0x80000) and (
587-
(m & 0x100000) or (m & 0x7C000)
594+
(m & 0x100000) or (m & 0x7FFFF)
588595
): # round to nearest even
589596
if (ret & 0x7F) < 0x7E:
590597
# rounding
@@ -642,7 +649,9 @@ def float32_to_fe5m2(x, fn: bool = False, uz: bool = False, saturate: bool = Tru
642649
ex = e - 111 # 127 - 16
643650
ret |= ex << 2
644651
ret |= m >> 21
645-
if m & 0x100000:
652+
if m & 0x100000 and (
653+
(m & 0xFFFFF) or (m & 0x200000)
654+
): # round to nearest even
646655
if (ret & 0x7F) < 0x7F:
647656
# rounding
648657
ret += 1

0 commit comments

Comments
 (0)