From 95883d69a877e78e5aefb5bc8d4decaf3de1da57 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 10 Dec 2025 11:55:56 -0800 Subject: [PATCH] Use jnp for metrics to make arrays addressable in multi-task setting. PiperOrigin-RevId: 842819214 --- init2winit/model_lib/metrics.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/init2winit/model_lib/metrics.py b/init2winit/model_lib/metrics.py index 5f404ffe..ed525d0d 100644 --- a/init2winit/model_lib/metrics.py +++ b/init2winit/model_lib/metrics.py @@ -464,8 +464,8 @@ def average_ctc_loss(): @flax.struct.dataclass class _Metric(metrics.Metric): """Applies `fun` and computes the average.""" - total: np.float32 - weight: np.float32 + total: jnp.float32 + weight: jnp.float32 @classmethod def from_model_output(cls, normalized_loss, **_): @@ -492,8 +492,8 @@ def compute_wer(decoded, num_words = 0.0 if tokenizer_type == 'SPM': - decoded_lengths = np.sum(decoded_paddings == 0.0, axis=-1) - target_lengths = np.sum(target_paddings == 0.0, axis=-1) + decoded_lengths = jnp.sum(decoded_paddings == 0.0, axis=-1) + target_lengths = jnp.sum(target_paddings == 0.0, axis=-1) batch_size = targets.shape[0]