From 858d66f4800d06a86fd1900bd4b807aadcdf5999 Mon Sep 17 00:00:00 2001 From: CLU Authors Date: Mon, 27 Mar 2023 11:13:08 -0700 Subject: [PATCH] Update clu.Metrics.compute_value() return type, to avoid signature-mismatch error. according to here: go/pax/metrics#compute-value-return-types PiperOrigin-RevId: 519776158 --- clu/metrics.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/clu/metrics.py b/clu/metrics.py index e560179..73ec5a5 100644 --- a/clu/metrics.py +++ b/clu/metrics.py @@ -56,7 +56,7 @@ def evaluate(model, p_variables, test_ds): return ms.unreplicate().compute() """ -from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Type +from collections.abc import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union from absl import logging @@ -141,7 +141,16 @@ def empty(cls) -> "Metric": """Returns an empty instance (i.e. `.merge(Metric.empty())` is a no-op).""" raise NotImplementedError("Must override empty()") - def compute_value(self) -> clu.values.Value: + def compute_value( + self, + ) -> Union[ + clu.values.Value, + List[clu.values.Value], + Tuple[clu.values.Value], + Dict[str, clu.values.Value], + Dict[str, List[clu.values.Value]], + Dict[str, Tuple[clu.values.Value]], + ]: """Wraps compute() and returns a values.Value.""" return clu.values.Scalar(self.compute())