From be0d3d70f3de62ed32c10ceb6673fef6158b105f Mon Sep 17 00:00:00 2001 From: Marvin Ritter Date: Fri, 21 Jul 2023 00:55:59 -0700 Subject: [PATCH] Internal change. PiperOrigin-RevId: 549865766 --- clu/metric_writers/summary_writer.py | 87 +-------------- clu/metric_writers/tf/summary_writer.py | 104 ++++++++++++++++++ .../{ => tf}/summary_writer_test.py | 2 +- 3 files changed, 108 insertions(+), 85 deletions(-) create mode 100644 clu/metric_writers/tf/summary_writer.py rename clu/metric_writers/{ => tf}/summary_writer_test.py (99%) diff --git a/clu/metric_writers/summary_writer.py b/clu/metric_writers/summary_writer.py index 72e5ea8..60d6861 100644 --- a/clu/metric_writers/summary_writer.py +++ b/clu/metric_writers/summary_writer.py @@ -12,88 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""MetricWriter for writing to TF summary files. +"""MetricWriter for writing to TF summary files.""" +# pylint: disable=unused-import -Only works in eager mode. Does not work for Pytorch code, please use -TorchTensorboardWriter instead. -""" - -from typing import Any, Mapping, Optional -from absl import logging - - -from clu.internal import utils -from clu.metric_writers import interface -import tensorflow as tf - -from tensorboard.plugins.hparams import api as hparams_api - - -Array = interface.Array -Scalar = interface.Scalar - - -class SummaryWriter(interface.MetricWriter): - """MetricWriter that writes TF summary files.""" - - def __init__(self, logdir: str): - super().__init__() - self._summary_writer = tf.summary.create_file_writer(logdir) - - - def write_summaries( - self, step: int, - values: Mapping[str, Array], - metadata: Optional[Mapping[str, Any]] = None): - with self._summary_writer.as_default(): - for key, value in values.items(): - md = metadata.get(key) if metadata is not None else None - tf.summary.write(key, value, step=step, metadata=md) - - def write_scalars(self, step: int, scalars: Mapping[str, Scalar]): - with self._summary_writer.as_default(): - for key, value in scalars.items(): - tf.summary.scalar(key, value, step=step) - - def write_images(self, step: int, images: Mapping[str, Array]): - with self._summary_writer.as_default(): - for key, value in images.items(): - if len(value.shape) == 3: - value = value[None] - tf.summary.image(key, value, step=step, max_outputs=value.shape[0]) - - def write_videos(self, step: int, videos: Mapping[str, Array]): - logging.log_first_n( - logging.WARNING, - "SummaryWriter does not support writing videos.", 1) - - def write_audios( - self, step: int, audios: Mapping[str, Array], *, sample_rate: int): - with self._summary_writer.as_default(): - for key, value in audios.items(): - tf.summary.audio(key, value, sample_rate=sample_rate, step=step, - max_outputs=value.shape[0]) - - def write_texts(self, step: int, texts: Mapping[str, str]): - with self._summary_writer.as_default(): - for key, value in texts.items(): - tf.summary.text(key, value, step=step) - - def write_histograms(self, - step: int, - arrays: Mapping[str, Array], - num_buckets: Optional[Mapping[str, int]] = None): - with self._summary_writer.as_default(): - for key, value in arrays.items(): - buckets = None if num_buckets is None else num_buckets.get(key) - tf.summary.histogram(key, value, step=step, buckets=buckets) - - def write_hparams(self, hparams: Mapping[str, Any]): - with self._summary_writer.as_default(): - hparams_api.hparams(dict(utils.flatten_dict(hparams))) - - def flush(self): - self._summary_writer.flush() - - def close(self): - self._summary_writer.close() +from clu.metric_writers.summary_writer import SummaryWriter diff --git a/clu/metric_writers/tf/summary_writer.py b/clu/metric_writers/tf/summary_writer.py new file mode 100644 index 0000000..7fc5d91 --- /dev/null +++ b/clu/metric_writers/tf/summary_writer.py @@ -0,0 +1,104 @@ +# Copyright 2023 The CLU Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MetricWriter for writing to TF summary files. + +Only works in eager mode. Does not work for Pytorch code, please use +TorchTensorboardWriter instead. +""" + +from collections.abc import Mapping +from typing import Any + +from absl import logging + +from clu.internal import utils +from clu.metric_writers import interface +import tensorflow as tf + +from tensorboard.plugins.hparams import api as hparams_api + + +Array = interface.Array +Scalar = interface.Scalar + + +class SummaryWriter(interface.MetricWriter): + """MetricWriter that writes TF summary files.""" + + def __init__(self, logdir: str): + super().__init__() + self._summary_writer = tf.summary.create_file_writer(logdir) + + + def write_summaries( + self, + step: int, + values: Mapping[str, Array], + metadata: Mapping[str, Any] | None = None, + ): + with self._summary_writer.as_default(): + for key, value in values.items(): + md = metadata.get(key) if metadata is not None else None + tf.summary.write(key, value, step=step, metadata=md) + + def write_scalars(self, step: int, scalars: Mapping[str, Scalar]): + with self._summary_writer.as_default(): + for key, value in scalars.items(): + tf.summary.scalar(key, value, step=step) + + def write_images(self, step: int, images: Mapping[str, Array]): + with self._summary_writer.as_default(): + for key, value in images.items(): + if len(value.shape) == 3: + value = value[None] + tf.summary.image(key, value, step=step, max_outputs=value.shape[0]) + + def write_videos(self, step: int, videos: Mapping[str, Array]): + logging.log_first_n( + logging.WARNING, + "SummaryWriter does not support writing videos.", 1) + + def write_audios( + self, step: int, audios: Mapping[str, Array], *, sample_rate: int): + with self._summary_writer.as_default(): + for key, value in audios.items(): + tf.summary.audio(key, value, sample_rate=sample_rate, step=step, + max_outputs=value.shape[0]) + + def write_texts(self, step: int, texts: Mapping[str, str]): + with self._summary_writer.as_default(): + for key, value in texts.items(): + tf.summary.text(key, value, step=step) + + def write_histograms( + self, + step: int, + arrays: Mapping[str, Array], + num_buckets: Mapping[str, int] | None = None, + ): + with self._summary_writer.as_default(): + for key, value in arrays.items(): + buckets = None if num_buckets is None else num_buckets.get(key) + tf.summary.histogram(key, value, step=step, buckets=buckets) + + def write_hparams(self, hparams: Mapping[str, Any]): + with self._summary_writer.as_default(): + hparams_api.hparams(dict(utils.flatten_dict(hparams))) + + def flush(self): + self._summary_writer.flush() + + def close(self): + self._summary_writer.close() diff --git a/clu/metric_writers/summary_writer_test.py b/clu/metric_writers/tf/summary_writer_test.py similarity index 99% rename from clu/metric_writers/summary_writer_test.py rename to clu/metric_writers/tf/summary_writer_test.py index 11b88b1..f99e541 100644 --- a/clu/metric_writers/summary_writer_test.py +++ b/clu/metric_writers/tf/summary_writer_test.py @@ -17,7 +17,7 @@ import collections import os -from clu.metric_writers import summary_writer +from clu.metric_writers.tf import summary_writer import numpy as np import tensorflow as tf