File tree Expand file tree Collapse file tree 1 file changed +5
-2
lines changed Expand file tree Collapse file tree 1 file changed +5
-2
lines changed Original file line number Diff line number Diff line change @@ -57,6 +57,7 @@ def evaluate(variables_p, test_ds):
5757"""
5858from __future__ import annotations
5959from collections .abc import Mapping , Sequence
60+ import inspect
6061from typing import Any , TypeVar , Protocol
6162
6263from absl import logging
@@ -536,7 +537,8 @@ def empty(cls: type[C]) -> C:
536537 _reduction_counter = _ReductionCounter (jnp .array (1 , dtype = jnp .int32 )),
537538 ** {
538539 metric_name : metric .empty ()
539- for metric_name , metric in cls .__annotations__ .items ()
540+ for metric_name , metric
541+ in inspect .get_annotations (cls , eval_str = True ).items ()
540542 })
541543
542544 @classmethod
@@ -546,7 +548,8 @@ def _from_model_output(cls: type[C], **kwargs) -> C:
546548 _reduction_counter = _ReductionCounter (jnp .array (1 , dtype = jnp .int32 )),
547549 ** {
548550 metric_name : metric .from_model_output (** kwargs )
549- for metric_name , metric in cls .__annotations__ .items ()
551+ for metric_name , metric
552+ in inspect .get_annotations (cls , eval_str = True ).items ()
550553 })
551554
552555 @classmethod
You can’t perform that action at this time.
0 commit comments