44from types import ModuleType
55from .types import FormatInfo , RoundMode
66import numpy as np
7+ import array_api_compat
78
89
910def _isodd (v : np .ndarray ) -> np .ndarray :
1011 return v & 0x1 == 1
1112
1213
14+ def _ldexp (v : np .ndarray , s : np .ndarray ) -> np .ndarray :
15+ xp = array_api_compat .array_namespace (v , s )
16+ if (
17+ array_api_compat .is_torch_array (v )
18+ or array_api_compat .is_jax_array (v )
19+ or array_api_compat .is_numpy_array (v )
20+ ):
21+ return xp .ldexp (v , s )
22+
23+ # Scale away from subnormal/infinite ranges
24+ offset = 24
25+ vlo = (v * 2.0 ** + offset ) * 2.0 ** xp .astype (s - offset , v .dtype )
26+ vhi = (v * 2.0 ** - offset ) * 2.0 ** xp .astype (s + offset , v .dtype )
27+ return xp .where (v < 1.0 , vlo , vhi )
28+
29+
1330def round_ndarray (
1431 fi : FormatInfo ,
1532 v : np .ndarray ,
1633 rnd : RoundMode = RoundMode .TiesToEven ,
1734 sat : bool = False ,
1835 srbits : Optional [np .ndarray ] = None ,
1936 srnumbits : int = 0 ,
20- np : ModuleType = np ,
2137) -> np .ndarray :
2238 """
2339 Vectorized version of :meth:`round_float`.
@@ -38,8 +54,6 @@ def round_ndarray(
3854 srbits (int array): Bits to use for stochastic rounding if rnd == Stochastic.
3955 srnumbits (int): How many bits are in srbits. Implies srbits < 2**srnumbits.
4056
41- np (Module): May be `numpy`, `jax.numpy` or another module cloning numpy
42-
4357 Returns:
4458 An array of floats which is a subset of the format's value set.
4559
@@ -48,27 +62,38 @@ def round_ndarray(
4862 (e.g. converting a `NaN`, or an `Inf` when the target has no
4963 `NaN` or `Inf`, and :paramref:`sat` is false)
5064 """
65+ xp = array_api_compat .array_namespace (v , srbits )
66+
5167 p = fi .precision
5268 bias = fi .expBias
5369
54- is_negative = np .signbit (v ) & fi .is_signed
55- absv = np .where (is_negative , - v , v )
70+ is_negative = xp .signbit (v ) & fi .is_signed
71+ absv = xp .where (is_negative , - v , v )
5672
57- finite_nonzero = ~ (np .isnan (v ) | np .isinf (v ) | (v == 0 ))
73+ finite_nonzero = ~ (xp .isnan (v ) | xp .isinf (v ) | (v == 0 ))
5874
5975 # Place 1.0 where finite_nonzero is False, to avoid log of {0,inf,nan}
60- absv_masked = np .where (finite_nonzero , absv , 1.0 )
76+ absv_masked = xp .where (finite_nonzero , absv , 1.0 )
6177
62- expval = np .floor (np .log2 (absv_masked )).astype (int )
78+ int_type = xp .int64 if fi .k > 8 or srnumbits > 8 else xp .int16
79+
80+ def to_int (x : np .ndarray ) -> np .ndarray :
81+ return xp .astype (x , int_type )
82+
83+ def to_float (x : np .ndarray ) -> np .ndarray :
84+ return xp .astype (x , v .dtype )
85+
86+ expval = to_int (xp .floor (xp .log2 (absv_masked )))
6387
6488 if fi .has_subnormals :
65- expval = np .maximum (expval , 1 - bias )
89+ expval = xp .maximum (expval , 1 - bias )
6690
6791 expval = expval - p + 1
68- fsignificand = np . ldexp (absv_masked , - expval )
92+ fsignificand = _ldexp (absv_masked , - expval )
6993
70- isignificand = np .floor (fsignificand ).astype (np .int64 )
71- delta = fsignificand - isignificand
94+ floorfsignificand = xp .floor (fsignificand )
95+ isignificand = to_int (floorfsignificand )
96+ delta = fsignificand - floorfsignificand
7297
7398 if fi .precision > 1 :
7499 code_is_odd = _isodd (isignificand )
@@ -77,7 +102,7 @@ def round_ndarray(
77102
78103 match rnd :
79104 case RoundMode .TowardZero :
80- should_round_away = np .zeros_like (delta , dtype = bool )
105+ should_round_away = xp .zeros_like (delta , dtype = xp . bool )
81106
82107 case RoundMode .TowardPositive :
83108 should_round_away = ~ is_negative & (delta > 0 )
@@ -95,38 +120,44 @@ def round_ndarray(
95120 assert srbits is not None
96121 ## RTNE delta to srbits
97122 d = delta * 2.0 ** srnumbits
98- floord = np .floor (d ).astype (np .int64 )
99- dd = d - floord
100- drnd = floord + (dd > 0.5 ) + ((dd == 0.5 ) & _isodd (floord ))
123+ floord = to_int (xp .floor (d ))
124+ dd = d - xp .floor (d )
125+ should_round_away_tne = (dd > 0.5 ) | ((dd == 0.5 ) & _isodd (floord ))
126+ drnd = floord + xp .astype (should_round_away_tne , floord .dtype )
101127
102- should_round_away = drnd + srbits >= 2.0 ** srnumbits
128+ should_round_away = drnd + srbits >= 2 ** srnumbits
103129
104130 case RoundMode .StochasticOdd :
105131 assert srbits is not None
106132 ## RTNO delta to srbits
107133 d = delta * 2.0 ** srnumbits
108- floord = np .floor (d ).astype (np .int64 )
109- dd = d - floord
110- drnd = floord + (dd > 0.5 ) + ((dd == 0.5 ) & ~ _isodd (floord ))
134+ floord = to_int (xp .floor (d ))
135+ dd = d - xp .floor (d )
136+ should_round_away_tno = (dd > 0.5 ) | ((dd == 0.5 ) & ~ _isodd (floord ))
137+ drnd = floord + xp .astype (should_round_away_tno , floord .dtype )
111138
112- should_round_away = drnd + srbits >= 2.0 ** srnumbits
139+ should_round_away = drnd + srbits >= 2 ** srnumbits
113140
114141 case RoundMode .StochasticFast :
115142 assert srbits is not None
116- should_round_away = delta + (2 * srbits + 1 ) * 2.0 ** - (1 + srnumbits ) >= 1.0
143+ should_round_away = (
144+ delta + to_float (2 * srbits + 1 ) * 2.0 ** - (1 + srnumbits ) >= 1.0
145+ )
117146
118147 case RoundMode .StochasticFastest :
119148 assert srbits is not None
120- should_round_away = delta + srbits * 2.0 ** - srnumbits >= 1.0
149+ should_round_away = delta + to_float (srbits ) * 2.0 ** - srnumbits >= 1.0
150+
151+ isignificand = xp .where (should_round_away , isignificand + 1 , isignificand )
121152
122- isignificand = np . where ( should_round_away , isignificand + 1 , isignificand )
153+ fresult = _ldexp ( to_float ( isignificand ), expval )
123154
124- result = np .where (finite_nonzero , np . ldexp ( isignificand , expval ) , absv )
155+ result = xp .where (finite_nonzero , fresult , absv )
125156
126- amax = np .where (is_negative , - fi .min , fi .max )
157+ amax = xp .where (is_negative , - fi .min , fi .max )
127158
128159 if sat :
129- result = np .where (result > amax , amax , result )
160+ result = xp .where (result > amax , amax , result )
130161 else :
131162 match rnd :
132163 case RoundMode .TowardNegative :
@@ -136,25 +167,25 @@ def round_ndarray(
136167 case RoundMode .TowardZero :
137168 put_amax_at = result > amax
138169 case _:
139- put_amax_at = np .zeros_like (result , dtype = bool )
170+ put_amax_at = xp .zeros_like (result , dtype = xp . bool )
140171
141- result = np .where (finite_nonzero & put_amax_at , amax , result )
172+ result = xp .where (finite_nonzero & put_amax_at , amax , result )
142173
143174 # Now anything larger than amax goes to infinity or NaN
144175 if fi .has_infs :
145- result = np .where (result > amax , np .inf , result )
176+ result = xp .where (result > amax , xp .inf , result )
146177 elif fi .num_nans > 0 :
147- result = np .where (result > amax , np .nan , result )
178+ result = xp .where (result > amax , xp .nan , result )
148179 else :
149- if np .any (result > amax ):
180+ if xp .any (result > amax ):
150181 raise ValueError (f"No Infs or NaNs in format { fi } , and sat=False" )
151182
152- result = np .where (is_negative , - result , result )
183+ result = xp .where (is_negative , - result , result )
153184
154185 # Make negative zeros negative if has_nz, else make them not negative.
155186 if fi .has_nz :
156- result = np .where ((result == 0 ) & is_negative , - 0.0 , result )
187+ result = xp .where ((result == 0 ) & is_negative , - 0.0 , result )
157188 else :
158- result = np .where (result == 0 , 0.0 , result )
189+ result = xp .where (result == 0 , 0.0 , result )
159190
160191 return result
0 commit comments