@@ -66,8 +66,7 @@ def identity(x: ArrayLike) -> Array:
6666 Array([-2. , -1. , -0.5, 0. , 0.5, 1. , 2. ], dtype=float32)
6767
6868 """
69- numpy_util .check_arraylike ("identity" , x )
70- return jnp .asarray (x )
69+ return numpy_util .ensure_arraylike ("identity" , x )
7170
7271@custom_derivatives .custom_jvp
7372@api .jit
@@ -121,10 +120,7 @@ def squareplus(x: ArrayLike, b: ArrayLike = 4) -> Array:
121120 x : input array
122121 b : smoothness parameter
123122 """
124- numpy_util .check_arraylike ("squareplus" , x )
125- numpy_util .check_arraylike ("squareplus" , b )
126- x = jnp .asarray (x )
127- b = jnp .asarray (b )
123+ x , b = numpy_util .ensure_arraylike ("squareplus" , x , b )
128124 y = x + jnp .sqrt (jnp .square (x ) + b )
129125 return y / 2
130126
@@ -164,8 +160,7 @@ def sparse_plus(x: ArrayLike) -> Array:
164160 Args:
165161 x: input (float)
166162 """
167- numpy_util .check_arraylike ("sparse_plus" , x )
168- x = jnp .asarray (x )
163+ x = numpy_util .ensure_arraylike ("sparse_plus" , x )
169164 return jnp .where (x <= - 1.0 , 0.0 , jnp .where (x >= 1.0 , x , (x + 1.0 )** 2 / 4 ))
170165
171166@api .jit
@@ -180,8 +175,7 @@ def soft_sign(x: ArrayLike) -> Array:
180175 Args:
181176 x : input array
182177 """
183- numpy_util .check_arraylike ("soft_sign" , x )
184- x_arr = jnp .asarray (x )
178+ x_arr = numpy_util .ensure_arraylike ("soft_sign" , x )
185179 return x_arr / (jnp .abs (x_arr ) + 1 )
186180
187181@api .jit (inline = True )
@@ -257,8 +251,7 @@ def silu(x: ArrayLike) -> Array:
257251 See also:
258252 :func:`sigmoid`
259253 """
260- numpy_util .check_arraylike ("silu" , x )
261- x_arr = jnp .asarray (x )
254+ x_arr = numpy_util .ensure_arraylike ("silu" , x )
262255 return x_arr * sigmoid (x_arr )
263256
264257swish = silu
@@ -282,8 +275,7 @@ def mish(x: ArrayLike) -> Array:
282275 Returns:
283276 An array.
284277 """
285- numpy_util .check_arraylike ("mish" , x )
286- x_arr = jnp .asarray (x )
278+ x_arr = numpy_util .ensure_arraylike ("mish" , x )
287279 return x_arr * jnp .tanh (softplus (x_arr ))
288280
289281@api .jit
@@ -304,8 +296,7 @@ def log_sigmoid(x: ArrayLike) -> Array:
304296 See also:
305297 :func:`sigmoid`
306298 """
307- numpy_util .check_arraylike ("log_sigmoid" , x )
308- x_arr = jnp .asarray (x )
299+ x_arr = numpy_util .ensure_arraylike ("log_sigmoid" , x )
309300 return - softplus (- x_arr )
310301
311302@api .jit
@@ -330,8 +321,7 @@ def elu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Array:
330321 See also:
331322 :func:`selu`
332323 """
333- numpy_util .check_arraylike ("elu" , x )
334- x_arr = jnp .asarray (x )
324+ x_arr = numpy_util .ensure_arraylike ("elu" , x )
335325 return jnp .where (x_arr > 0 ,
336326 x_arr ,
337327 alpha * jnp .expm1 (jnp .where (x_arr > 0 , 0. , x_arr )))
@@ -360,8 +350,7 @@ def leaky_relu(x: ArrayLike, negative_slope: ArrayLike = 1e-2) -> Array:
360350 See also:
361351 :func:`relu`
362352 """
363- numpy_util .check_arraylike ("leaky_relu" , x )
364- x_arr = jnp .asarray (x )
353+ x_arr = numpy_util .ensure_arraylike ("leaky_relu" , x )
365354 return jnp .where (x_arr >= 0 , x_arr , negative_slope * x_arr )
366355
367356@api .jit
@@ -383,8 +372,7 @@ def hard_tanh(x: ArrayLike) -> Array:
383372 Returns:
384373 An array.
385374 """
386- numpy_util .check_arraylike ("hard_tanh" , x )
387- x_arr = jnp .asarray (x )
375+ x_arr = numpy_util .ensure_arraylike ("hard_tanh" , x )
388376 return jnp .where (x_arr > 1 , 1 , jnp .where (x_arr < - 1 , - 1 , x_arr ))
389377
390378@api .jit
@@ -504,8 +492,7 @@ def glu(x: ArrayLike, axis: int = -1) -> Array:
504492 See also:
505493 :func:`sigmoid`
506494 """
507- numpy_util .check_arraylike ("glu" , x )
508- x_arr = jnp .asarray (x )
495+ x_arr = numpy_util .ensure_arraylike ("glu" , x )
509496 size = x_arr .shape [axis ]
510497 assert size % 2 == 0 , "axis size must be divisible by 2"
511498 x1 , x2 = jnp .split (x_arr , 2 , axis )
@@ -575,8 +562,7 @@ def log_softmax(x: ArrayLike,
575562 See also:
576563 :func:`softmax`
577564 """
578- numpy_util .check_arraylike ("log_softmax" , x )
579- x_arr = jnp .asarray (x )
565+ x_arr = numpy_util .ensure_arraylike ("log_softmax" , x )
580566 x_max = jnp .max (x_arr , axis , where = where , initial = - np .inf , keepdims = True )
581567 x_safe = x_arr if where is None else jnp .where (where , x_arr , - np .inf )
582568 shifted = x_safe - lax .stop_gradient (x_max )
@@ -849,8 +835,7 @@ def hard_silu(x: ArrayLike) -> Array:
849835 See also:
850836 :func:`hard_sigmoid`
851837 """
852- numpy_util .check_arraylike ("hard_silu" , x )
853- x_arr = jnp .asarray (x )
838+ x_arr = numpy_util .ensure_arraylike ("hard_silu" , x )
854839 return x_arr * hard_sigmoid (x_arr )
855840
856841hard_swish = hard_silu
@@ -1496,8 +1481,7 @@ def log1mexp(x: ArrayLike) -> Array:
14961481 .. [1] Martin Mächler. `Accurately Computing log(1 − exp(−|a|)) Assessed by the Rmpfr package.
14971482 <https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf>`_.
14981483 """
1499- numpy_util .check_arraylike ("log1mexp" , x )
1500- x = jnp .asarray (x )
1484+ x = numpy_util .ensure_arraylike ("log1mexp" , x )
15011485 c = jnp .log (2.0 )
15021486 return jnp .where (
15031487 x < c ,
0 commit comments