diff --git a/clu/metrics.py b/clu/metrics.py index e560179..73ec5a5 100644 --- a/clu/metrics.py +++ b/clu/metrics.py @@ -56,7 +56,7 @@ def evaluate(model, p_variables, test_ds): return ms.unreplicate().compute() """ -from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Type +from collections.abc import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union from absl import logging @@ -141,7 +141,16 @@ def empty(cls) -> "Metric": """Returns an empty instance (i.e. `.merge(Metric.empty())` is a no-op).""" raise NotImplementedError("Must override empty()") - def compute_value(self) -> clu.values.Value: + def compute_value( + self, + ) -> Union[ + clu.values.Value, + List[clu.values.Value], + Tuple[clu.values.Value], + Dict[str, clu.values.Value], + Dict[str, List[clu.values.Value]], + Dict[str, Tuple[clu.values.Value]], + ]: """Wraps compute() and returns a values.Value.""" return clu.values.Scalar(self.compute())