Skip to content

Commit dea8708

Browse files
Merge pull request #32990 from arvindsankar:ensure_arraylike
PiperOrigin-RevId: 826177852
2 parents eca5b59 + 5e72730 commit dea8708

File tree

1 file changed

+14
-30
lines changed

1 file changed

+14
-30
lines changed

jax/_src/nn/functions.py

Lines changed: 14 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,7 @@ def identity(x: ArrayLike) -> Array:
6666
Array([-2. , -1. , -0.5, 0. , 0.5, 1. , 2. ], dtype=float32)
6767
6868
"""
69-
numpy_util.check_arraylike("identity", x)
70-
return jnp.asarray(x)
69+
return numpy_util.ensure_arraylike("identity", x)
7170

7271
@custom_derivatives.custom_jvp
7372
@api.jit
@@ -121,10 +120,7 @@ def squareplus(x: ArrayLike, b: ArrayLike = 4) -> Array:
121120
x : input array
122121
b : smoothness parameter
123122
"""
124-
numpy_util.check_arraylike("squareplus", x)
125-
numpy_util.check_arraylike("squareplus", b)
126-
x = jnp.asarray(x)
127-
b = jnp.asarray(b)
123+
x, b = numpy_util.ensure_arraylike("squareplus", x, b)
128124
y = x + jnp.sqrt(jnp.square(x) + b)
129125
return y / 2
130126

@@ -164,8 +160,7 @@ def sparse_plus(x: ArrayLike) -> Array:
164160
Args:
165161
x: input (float)
166162
"""
167-
numpy_util.check_arraylike("sparse_plus", x)
168-
x = jnp.asarray(x)
163+
x = numpy_util.ensure_arraylike("sparse_plus", x)
169164
return jnp.where(x <= -1.0, 0.0, jnp.where(x >= 1.0, x, (x + 1.0)**2/4))
170165

171166
@api.jit
@@ -180,8 +175,7 @@ def soft_sign(x: ArrayLike) -> Array:
180175
Args:
181176
x : input array
182177
"""
183-
numpy_util.check_arraylike("soft_sign", x)
184-
x_arr = jnp.asarray(x)
178+
x_arr = numpy_util.ensure_arraylike("soft_sign", x)
185179
return x_arr / (jnp.abs(x_arr) + 1)
186180

187181
@api.jit(inline=True)
@@ -257,8 +251,7 @@ def silu(x: ArrayLike) -> Array:
257251
See also:
258252
:func:`sigmoid`
259253
"""
260-
numpy_util.check_arraylike("silu", x)
261-
x_arr = jnp.asarray(x)
254+
x_arr = numpy_util.ensure_arraylike("silu", x)
262255
return x_arr * sigmoid(x_arr)
263256

264257
swish = silu
@@ -282,8 +275,7 @@ def mish(x: ArrayLike) -> Array:
282275
Returns:
283276
An array.
284277
"""
285-
numpy_util.check_arraylike("mish", x)
286-
x_arr = jnp.asarray(x)
278+
x_arr = numpy_util.ensure_arraylike("mish", x)
287279
return x_arr * jnp.tanh(softplus(x_arr))
288280

289281
@api.jit
@@ -304,8 +296,7 @@ def log_sigmoid(x: ArrayLike) -> Array:
304296
See also:
305297
:func:`sigmoid`
306298
"""
307-
numpy_util.check_arraylike("log_sigmoid", x)
308-
x_arr = jnp.asarray(x)
299+
x_arr = numpy_util.ensure_arraylike("log_sigmoid", x)
309300
return -softplus(-x_arr)
310301

311302
@api.jit
@@ -330,8 +321,7 @@ def elu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Array:
330321
See also:
331322
:func:`selu`
332323
"""
333-
numpy_util.check_arraylike("elu", x)
334-
x_arr = jnp.asarray(x)
324+
x_arr = numpy_util.ensure_arraylike("elu", x)
335325
return jnp.where(x_arr > 0,
336326
x_arr,
337327
alpha * jnp.expm1(jnp.where(x_arr > 0, 0., x_arr)))
@@ -360,8 +350,7 @@ def leaky_relu(x: ArrayLike, negative_slope: ArrayLike = 1e-2) -> Array:
360350
See also:
361351
:func:`relu`
362352
"""
363-
numpy_util.check_arraylike("leaky_relu", x)
364-
x_arr = jnp.asarray(x)
353+
x_arr = numpy_util.ensure_arraylike("leaky_relu", x)
365354
return jnp.where(x_arr >= 0, x_arr, negative_slope * x_arr)
366355

367356
@api.jit
@@ -383,8 +372,7 @@ def hard_tanh(x: ArrayLike) -> Array:
383372
Returns:
384373
An array.
385374
"""
386-
numpy_util.check_arraylike("hard_tanh", x)
387-
x_arr = jnp.asarray(x)
375+
x_arr = numpy_util.ensure_arraylike("hard_tanh", x)
388376
return jnp.where(x_arr > 1, 1, jnp.where(x_arr < -1, -1, x_arr))
389377

390378
@api.jit
@@ -504,8 +492,7 @@ def glu(x: ArrayLike, axis: int = -1) -> Array:
504492
See also:
505493
:func:`sigmoid`
506494
"""
507-
numpy_util.check_arraylike("glu", x)
508-
x_arr = jnp.asarray(x)
495+
x_arr = numpy_util.ensure_arraylike("glu", x)
509496
size = x_arr.shape[axis]
510497
assert size % 2 == 0, "axis size must be divisible by 2"
511498
x1, x2 = jnp.split(x_arr, 2, axis)
@@ -575,8 +562,7 @@ def log_softmax(x: ArrayLike,
575562
See also:
576563
:func:`softmax`
577564
"""
578-
numpy_util.check_arraylike("log_softmax", x)
579-
x_arr = jnp.asarray(x)
565+
x_arr = numpy_util.ensure_arraylike("log_softmax", x)
580566
x_max = jnp.max(x_arr, axis, where=where, initial=-np.inf, keepdims=True)
581567
x_safe = x_arr if where is None else jnp.where(where, x_arr, -np.inf)
582568
shifted = x_safe - lax.stop_gradient(x_max)
@@ -849,8 +835,7 @@ def hard_silu(x: ArrayLike) -> Array:
849835
See also:
850836
:func:`hard_sigmoid`
851837
"""
852-
numpy_util.check_arraylike("hard_silu", x)
853-
x_arr = jnp.asarray(x)
838+
x_arr = numpy_util.ensure_arraylike("hard_silu", x)
854839
return x_arr * hard_sigmoid(x_arr)
855840

856841
hard_swish = hard_silu
@@ -1496,8 +1481,7 @@ def log1mexp(x: ArrayLike) -> Array:
14961481
.. [1] Martin Mächler. `Accurately Computing log(1 − exp(−|a|)) Assessed by the Rmpfr package.
14971482
<https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf>`_.
14981483
"""
1499-
numpy_util.check_arraylike("log1mexp", x)
1500-
x = jnp.asarray(x)
1484+
x = numpy_util.ensure_arraylike("log1mexp", x)
15011485
c = jnp.log(2.0)
15021486
return jnp.where(
15031487
x < c,

0 commit comments

Comments
 (0)