diff --git a/init2winit/model_lib/metrics.py b/init2winit/model_lib/metrics.py index 5f404ffe..ed525d0d 100644 --- a/init2winit/model_lib/metrics.py +++ b/init2winit/model_lib/metrics.py @@ -464,8 +464,8 @@ def average_ctc_loss(): @flax.struct.dataclass class _Metric(metrics.Metric): """Applies `fun` and computes the average.""" - total: np.float32 - weight: np.float32 + total: jnp.float32 + weight: jnp.float32 @classmethod def from_model_output(cls, normalized_loss, **_): @@ -492,8 +492,8 @@ def compute_wer(decoded, num_words = 0.0 if tokenizer_type == 'SPM': - decoded_lengths = np.sum(decoded_paddings == 0.0, axis=-1) - target_lengths = np.sum(target_paddings == 0.0, axis=-1) + decoded_lengths = jnp.sum(decoded_paddings == 0.0, axis=-1) + target_lengths = jnp.sum(target_paddings == 0.0, axis=-1) batch_size = targets.shape[0]