Skip to content

Commit a8152eb

Browse files
CLU Authorscopybara-github
authored andcommitted
Update clu to use inspect.get_annotations(cls) over cls.__annotations__
PiperOrigin-RevId: 723511644
1 parent 43acbbd commit a8152eb

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

clu/metrics.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def evaluate(variables_p, test_ds):
5757
"""
5858
from __future__ import annotations
5959
from collections.abc import Mapping, Sequence
60+
import inspect
6061
from typing import Any, TypeVar, Protocol
6162

6263
from absl import logging
@@ -536,7 +537,8 @@ def empty(cls: type[C]) -> C:
536537
_reduction_counter=_ReductionCounter(jnp.array(1, dtype=jnp.int32)),
537538
**{
538539
metric_name: metric.empty()
539-
for metric_name, metric in cls.__annotations__.items()
540+
for metric_name, metric
541+
in inspect.get_annotations(cls, eval_str=True).items()
540542
})
541543

542544
@classmethod
@@ -546,7 +548,8 @@ def _from_model_output(cls: type[C], **kwargs) -> C:
546548
_reduction_counter=_ReductionCounter(jnp.array(1, dtype=jnp.int32)),
547549
**{
548550
metric_name: metric.from_model_output(**kwargs)
549-
for metric_name, metric in cls.__annotations__.items()
551+
for metric_name, metric
552+
in inspect.get_annotations(cls, eval_str=True).items()
550553
})
551554

552555
@classmethod

0 commit comments

Comments
 (0)