From ad63a5f88e61e4a782c0cba2f9422eff4738cc38 Mon Sep 17 00:00:00 2001 From: Sergii Volodko Date: Sun, 28 Jun 2020 21:22:21 +0100 Subject: [PATCH 01/11] Add recurrent_internals.py --- sonnet/src/BUILD | 7 +++++++ sonnet/src/recurrent_internals.py | 0 2 files changed, 7 insertions(+) create mode 100644 sonnet/src/recurrent_internals.py diff --git a/sonnet/src/BUILD b/sonnet/src/BUILD index 9e793e92..1fb6df4d 100644 --- a/sonnet/src/BUILD +++ b/sonnet/src/BUILD @@ -366,6 +366,12 @@ snt_py_test( ], ) +snt_py_library( + name = "recurrent_internals", + srcs = ["recurrent_internals.py"], + deps = [], +) + snt_py_library( name = "recurrent", srcs = ["recurrent.py"], @@ -377,6 +383,7 @@ snt_py_library( ":once", ":types", ":utils", + ":recurrent_internals" # pip: six # pip: tensorflow # pip: tree diff --git a/sonnet/src/recurrent_internals.py b/sonnet/src/recurrent_internals.py new file mode 100644 index 00000000..e69de29b From ce662c376f2bca8d5864bbd6ed0c2304934ee0e1 Mon Sep 17 00:00:00 2001 From: Sergii Volodko Date: Sun, 28 Jun 2020 21:23:41 +0100 Subject: [PATCH 02/11] Move _check_inputs_dtype from recurrent to recurrent_internals --- sonnet/src/recurrent.py | 13 +++++++------ sonnet/src/recurrent_internals.py | 22 ++++++++++++++++++++++ 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/sonnet/src/recurrent.py b/sonnet/src/recurrent.py index d3067150..b02fe546 100644 --- a/sonnet/src/recurrent.py +++ b/sonnet/src/recurrent.py @@ -32,6 +32,7 @@ from sonnet.src import once from sonnet.src import types from sonnet.src import utils +from sonnet.src.recurrent_internals import _check_inputs_dtype import tensorflow.compat.v1 as tf1 import tensorflow as tf @@ -1719,9 +1720,9 @@ def _initialize(self, inputs): name="w_h") self.b = tf.Variable(self._b_init([3 * self._hidden_size], dtype), name="b") - -def _check_inputs_dtype(inputs, expected_dtype): - if inputs.dtype is not expected_dtype: - raise TypeError("inputs must have dtype {!r}, got {!r}".format( - expected_dtype, inputs.dtype)) - return expected_dtype +# +# def _check_inputs_dtype(inputs, expected_dtype): +# if inputs.dtype is not expected_dtype: +# raise TypeError("inputs must have dtype {!r}, got {!r}".format( +# expected_dtype, inputs.dtype)) +# return expected_dtype diff --git a/sonnet/src/recurrent_internals.py b/sonnet/src/recurrent_internals.py index e69de29b..80917c89 100644 --- a/sonnet/src/recurrent_internals.py +++ b/sonnet/src/recurrent_internals.py @@ -0,0 +1,22 @@ +# Copyright 2019 The Sonnet Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Utils for Recurrent Neural Network cores.""" + + +def _check_inputs_dtype(inputs, expected_dtype): + if inputs.dtype is not expected_dtype: + raise TypeError("inputs must have dtype {!r}, got {!r}".format( + expected_dtype, inputs.dtype)) + return expected_dtype \ No newline at end of file From 01fffc92b9aae954cba7045a17b61159f60faaa5 Mon Sep 17 00:00:00 2001 From: Sergii Volodko Date: Sun, 28 Jun 2020 21:44:57 +0100 Subject: [PATCH 03/11] Move _safe_where from recurrent to recurrent_internals --- sonnet/src/BUILD | 4 +++- sonnet/src/recurrent.py | 21 +-------------------- sonnet/src/recurrent_internals.py | 14 +++++++++++++- 3 files changed, 17 insertions(+), 22 deletions(-) diff --git a/sonnet/src/BUILD b/sonnet/src/BUILD index 1fb6df4d..933f99c6 100644 --- a/sonnet/src/BUILD +++ b/sonnet/src/BUILD @@ -369,7 +369,9 @@ snt_py_test( snt_py_library( name = "recurrent_internals", srcs = ["recurrent_internals.py"], - deps = [], + deps = [ + # pip: tensorflow + ], ) snt_py_library( diff --git a/sonnet/src/recurrent.py b/sonnet/src/recurrent.py index b02fe546..f29c9865 100644 --- a/sonnet/src/recurrent.py +++ b/sonnet/src/recurrent.py @@ -32,9 +32,8 @@ from sonnet.src import once from sonnet.src import types from sonnet.src import utils -from sonnet.src.recurrent_internals import _check_inputs_dtype +from sonnet.src.recurrent_internals import _check_inputs_dtype, _safe_where -import tensorflow.compat.v1 as tf1 import tensorflow as tf import tree @@ -435,17 +434,6 @@ def _unstack_input_sequence(input_sequence): lambda i: tf.TensorArray(i.dtype, num_steps).unstack(i), input_sequence) return num_steps, input_tas - -def _safe_where(condition, x, y): # pylint: disable=g-doc-args - """`tf.where` which allows scalar inputs.""" - if x.shape.rank == 0: - # This is to match the `tf.nn.*_rnn` behavior. In general, we might - # want to branch on `tf.reduce_all(condition)`. - return y - # TODO(tomhennigan) Broadcasting with SelectV2 is currently broken. - return tf1.where(condition, x, y) - - def _rnn_step(core, input_tas, sequence_length, t, prev_outputs, prev_state): """Performs a single RNN step optionally accounting for variable length.""" outputs, state = core( @@ -1719,10 +1707,3 @@ def _initialize(self, inputs): self._w_h_init([self._hidden_size, 3 * self._hidden_size], dtype), name="w_h") self.b = tf.Variable(self._b_init([3 * self._hidden_size], dtype), name="b") - -# -# def _check_inputs_dtype(inputs, expected_dtype): -# if inputs.dtype is not expected_dtype: -# raise TypeError("inputs must have dtype {!r}, got {!r}".format( -# expected_dtype, inputs.dtype)) -# return expected_dtype diff --git a/sonnet/src/recurrent_internals.py b/sonnet/src/recurrent_internals.py index 80917c89..17bcb56c 100644 --- a/sonnet/src/recurrent_internals.py +++ b/sonnet/src/recurrent_internals.py @@ -14,9 +14,21 @@ # ============================================================================ """Utils for Recurrent Neural Network cores.""" +import tensorflow.compat.v1 as tf1 + def _check_inputs_dtype(inputs, expected_dtype): if inputs.dtype is not expected_dtype: raise TypeError("inputs must have dtype {!r}, got {!r}".format( expected_dtype, inputs.dtype)) - return expected_dtype \ No newline at end of file + return expected_dtype + + +def _safe_where(condition, x, y): # pylint: disable=g-doc-args + """`tf.where` which allows scalar inputs.""" + if x.shape.rank == 0: + # This is to match the `tf.nn.*_rnn` behavior. 
In general, we might + # want to branch on `tf.reduce_all(condition)`. + return y + # TODO(tomhennigan) Broadcasting with SelectV2 is currently broken. + return tf1.where(condition, x, y) \ No newline at end of file From dd8413fd6e0cf4fa1c460b09321dc4f8b84916d1 Mon Sep 17 00:00:00 2001 From: Sergii Volodko Date: Sun, 28 Jun 2020 22:02:16 +0100 Subject: [PATCH 04/11] Move _unstack_input_sequence from recurrent to recurrent_internals --- sonnet/src/BUILD | 1 + sonnet/src/recurrent.py | 42 +--------------------------- sonnet/src/recurrent_internals.py | 46 ++++++++++++++++++++++++++++++- 3 files changed, 47 insertions(+), 42 deletions(-) diff --git a/sonnet/src/BUILD b/sonnet/src/BUILD index 933f99c6..789823ec 100644 --- a/sonnet/src/BUILD +++ b/sonnet/src/BUILD @@ -371,6 +371,7 @@ snt_py_library( srcs = ["recurrent_internals.py"], deps = [ # pip: tensorflow + # pip: tree ], ) diff --git a/sonnet/src/recurrent.py b/sonnet/src/recurrent.py index f29c9865..a4d72a43 100644 --- a/sonnet/src/recurrent.py +++ b/sonnet/src/recurrent.py @@ -32,7 +32,7 @@ from sonnet.src import once from sonnet.src import types from sonnet.src import utils -from sonnet.src.recurrent_internals import _check_inputs_dtype, _safe_where +from sonnet.src.recurrent_internals import _check_inputs_dtype, _safe_where, _unstack_input_sequence import tensorflow as tf import tree @@ -394,46 +394,6 @@ def dynamic_unroll( return output_sequence, state -def _unstack_input_sequence(input_sequence): - r"""Unstacks the input sequence into a nest of :tf:`TensorArray`\ s. - - This allows to traverse the input sequence using :tf:`TensorArray.read` - instead of a slice, avoiding O(sliced tensor) slice gradient - computation during the backwards pass. - - Args: - input_sequence: See :func:`dynamic_unroll` or :func:`static_unroll`. - - Returns: - num_steps: Number of steps in the input sequence. - input_tas: An arbitrarily nested structure of :tf:`TensorArray`\ s of - size ``num_steps``. - - Raises: - ValueError: If tensors in ``input_sequence`` have inconsistent number - of steps or the number of steps is 0. - """ - flat_input_sequence = tree.flatten(input_sequence) - all_num_steps = {i.shape[0] for i in flat_input_sequence} - if len(all_num_steps) > 1: - raise ValueError( - "input_sequence tensors must have consistent number of time steps") - [num_steps] = all_num_steps - if num_steps == 0: - raise ValueError("input_sequence must have at least a single time step") - elif num_steps is None: - # Number of steps is not known statically, fall back to dynamic shape. - num_steps = tf.shape(flat_input_sequence[0])[0] - # TODO(b/141910613): uncomment when the bug is fixed. 
- # for i in flat_input_sequence[1:]: - # tf.debugging.assert_equal( - # tf.shape(i)[0], num_steps, - # "input_sequence tensors must have consistent number of time steps") - - input_tas = tree.map_structure( - lambda i: tf.TensorArray(i.dtype, num_steps).unstack(i), input_sequence) - return num_steps, input_tas - def _rnn_step(core, input_tas, sequence_length, t, prev_outputs, prev_state): """Performs a single RNN step optionally accounting for variable length.""" outputs, state = core( diff --git a/sonnet/src/recurrent_internals.py b/sonnet/src/recurrent_internals.py index 17bcb56c..616e2955 100644 --- a/sonnet/src/recurrent_internals.py +++ b/sonnet/src/recurrent_internals.py @@ -15,6 +15,8 @@ """Utils for Recurrent Neural Network cores.""" import tensorflow.compat.v1 as tf1 +import tensorflow as tf +import tree def _check_inputs_dtype(inputs, expected_dtype): @@ -31,4 +33,46 @@ def _safe_where(condition, x, y): # pylint: disable=g-doc-args # want to branch on `tf.reduce_all(condition)`. return y # TODO(tomhennigan) Broadcasting with SelectV2 is currently broken. - return tf1.where(condition, x, y) \ No newline at end of file + return tf1.where(condition, x, y) + + + +def _unstack_input_sequence(input_sequence): + r"""Unstacks the input sequence into a nest of :tf:`TensorArray`\ s. + + This allows to traverse the input sequence using :tf:`TensorArray.read` + instead of a slice, avoiding O(sliced tensor) slice gradient + computation during the backwards pass. + + Args: + input_sequence: See :func:`dynamic_unroll` or :func:`static_unroll`. + + Returns: + num_steps: Number of steps in the input sequence. + input_tas: An arbitrarily nested structure of :tf:`TensorArray`\ s of + size ``num_steps``. + + Raises: + ValueError: If tensors in ``input_sequence`` have inconsistent number + of steps or the number of steps is 0. + """ + flat_input_sequence = tree.flatten(input_sequence) + all_num_steps = {i.shape[0] for i in flat_input_sequence} + if len(all_num_steps) > 1: + raise ValueError( + "input_sequence tensors must have consistent number of time steps") + [num_steps] = all_num_steps + if num_steps == 0: + raise ValueError("input_sequence must have at least a single time step") + elif num_steps is None: + # Number of steps is not known statically, fall back to dynamic shape. + num_steps = tf.shape(flat_input_sequence[0])[0] + # TODO(b/141910613): uncomment when the bug is fixed. 
+ # for i in flat_input_sequence[1:]: + # tf.debugging.assert_equal( + # tf.shape(i)[0], num_steps, + # "input_sequence tensors must have consistent number of time steps") + + input_tas = tree.map_structure( + lambda i: tf.TensorArray(i.dtype, num_steps).unstack(i), input_sequence) + return num_steps, input_tas \ No newline at end of file From 08e19983612c0dacc220fdd5c09540a218c0c070 Mon Sep 17 00:00:00 2001 From: Sergii Volodko Date: Sun, 28 Jun 2020 22:18:27 +0100 Subject: [PATCH 05/11] Move _specialize_per_device from recurrent to recurrent_internals --- sonnet/src/recurrent.py | 68 +------------------------------ sonnet/src/recurrent_internals.py | 68 ++++++++++++++++++++++++++++++- 2 files changed, 68 insertions(+), 68 deletions(-) diff --git a/sonnet/src/recurrent.py b/sonnet/src/recurrent.py index a4d72a43..b4aebeb5 100644 --- a/sonnet/src/recurrent.py +++ b/sonnet/src/recurrent.py @@ -22,7 +22,6 @@ import abc import collections import functools -import uuid import six from sonnet.src import base @@ -32,19 +31,13 @@ from sonnet.src import once from sonnet.src import types from sonnet.src import utils -from sonnet.src.recurrent_internals import _check_inputs_dtype, _safe_where, _unstack_input_sequence +from sonnet.src.recurrent_internals import _check_inputs_dtype, _safe_where, _unstack_input_sequence, _specialize_per_device import tensorflow as tf import tree from typing import Optional, Sequence, Text, Tuple, Union -# pylint: disable=g-direct-tensorflow-import -# Required for specializing `UnrolledLSTM` per device. -from tensorflow.python import context as context_lib -from tensorflow.python.eager import function as function_lib -# pylint: enable=g-direct-tensorflow-import - @six.add_metaclass(abc.ABCMeta) class RNNCore(base.Module): @@ -941,65 +934,6 @@ def _initialize(self, input_sequence): self.b = tf.Variable(tf.concat([b_i, b_f, b_g, b_o], axis=0), name="b") -# TODO(b/133740216): consider upstreaming into TensorFlow. -def _specialize_per_device(api_name, specializations, default): - """Create a :tf:`function` specialized per-device. - - Args: - api_name: Name of the function, e.g. ``"lstm"``. - specializations: A mapping from device type (e.g. ``"CPU"`` or ``"TPU``) to - a Python function with a specialized implementation for that device. - default: Default device type to use (typically, ``"CPU"``). - - Returns: - A :tf:`function` which when called dispatches to the specialization - for the current device. - """ - # Cached to avoid redundant ``ModuleWrapper.__getattribute__`` calls. - list_logical_devices = tf.config.experimental.list_logical_devices - - def wrapper(*args, **kwargs): - """Specialized {}. - - In eager mode the specialization is chosen based on the current - device context or, if no device context is active, on availability - of a GPU. - - In graph mode (inside tf.function) the choice is delegated to the - implementation selector pass in Grappler. - - Args: - *args: Positional arguments to pass to the chosen specialization. - **kwargs: Keyword arguments to pass to the chosen specialization. - """.format(api_name) - ctx = context_lib.context() - if ctx.executing_eagerly(): - device = ctx.device_spec.device_type - if device is None: - # Soft-placement will never implicitly place an op an a TPU, so - # we only need to consider CPU/GPU. 
- device = "GPU" if list_logical_devices("GPU") else "CPU" - - specialization = specializations.get(device) or specializations[default] - return specialization(*args, **kwargs) - - # Implementation selector requires a globally unique name for each - # .register() call. - unique_api_name = "{}_{}".format(api_name, uuid.uuid4()) - functions = {} - for device, specialization in specializations.items(): - functions[device] = function_lib.defun_with_attributes( - specialization, - attributes={ - "api_implements": unique_api_name, - "api_preferred_device": device - }) - function_lib.register(functions[device], *args, **kwargs) - return functions[default](*args, **kwargs) - - return wrapper - - def _fallback_unrolled_lstm(input_sequence, initial_state, w_i, w_h, b): """Fallback version of :class:`UnrolledLSTM` which works on any device.""" return dynamic_unroll( diff --git a/sonnet/src/recurrent_internals.py b/sonnet/src/recurrent_internals.py index 616e2955..bd64524f 100644 --- a/sonnet/src/recurrent_internals.py +++ b/sonnet/src/recurrent_internals.py @@ -14,10 +14,17 @@ # ============================================================================ """Utils for Recurrent Neural Network cores.""" +import uuid + import tensorflow.compat.v1 as tf1 import tensorflow as tf import tree +# pylint: disable=g-direct-tensorflow-import +# Required for specializing `UnrolledLSTM` per device. +from tensorflow.python import context as context_lib +from tensorflow.python.eager import function as function_lib +# pylint: enable=g-direct-tensorflow-import def _check_inputs_dtype(inputs, expected_dtype): if inputs.dtype is not expected_dtype: @@ -75,4 +82,63 @@ def _unstack_input_sequence(input_sequence): input_tas = tree.map_structure( lambda i: tf.TensorArray(i.dtype, num_steps).unstack(i), input_sequence) - return num_steps, input_tas \ No newline at end of file + return num_steps, input_tas + + +# TODO(b/133740216): consider upstreaming into TensorFlow. +def _specialize_per_device(api_name, specializations, default): + """Create a :tf:`function` specialized per-device. + + Args: + api_name: Name of the function, e.g. ``"lstm"``. + specializations: A mapping from device type (e.g. ``"CPU"`` or ``"TPU``) to + a Python function with a specialized implementation for that device. + default: Default device type to use (typically, ``"CPU"``). + + Returns: + A :tf:`function` which when called dispatches to the specialization + for the current device. + """ + # Cached to avoid redundant ``ModuleWrapper.__getattribute__`` calls. + list_logical_devices = tf.config.experimental.list_logical_devices + + def wrapper(*args, **kwargs): + """Specialized {}. + + In eager mode the specialization is chosen based on the current + device context or, if no device context is active, on availability + of a GPU. + + In graph mode (inside tf.function) the choice is delegated to the + implementation selector pass in Grappler. + + Args: + *args: Positional arguments to pass to the chosen specialization. + **kwargs: Keyword arguments to pass to the chosen specialization. + """.format(api_name) + ctx = context_lib.context() + if ctx.executing_eagerly(): + device = ctx.device_spec.device_type + if device is None: + # Soft-placement will never implicitly place an op an a TPU, so + # we only need to consider CPU/GPU. 
+ device = "GPU" if list_logical_devices("GPU") else "CPU" + + specialization = specializations.get(device) or specializations[default] + return specialization(*args, **kwargs) + + # Implementation selector requires a globally unique name for each + # .register() call. + unique_api_name = "{}_{}".format(api_name, uuid.uuid4()) + functions = {} + for device, specialization in specializations.items(): + functions[device] = function_lib.defun_with_attributes( + specialization, + attributes={ + "api_implements": unique_api_name, + "api_preferred_device": device + }) + function_lib.register(functions[device], *args, **kwargs) + return functions[default](*args, **kwargs) + + return wrapper \ No newline at end of file From 48b88e09b38cd9b2d4a6fdfcbfb1b9c94680bc7a Mon Sep 17 00:00:00 2001 From: Sergii Volodko Date: Sun, 28 Jun 2020 22:49:35 +0100 Subject: [PATCH 06/11] Move _rnn_step from recurrent to recurrent_internals --- sonnet/src/recurrent.py | 22 +-------------------- sonnet/src/recurrent_internals.py | 33 +++++++++++++++++++++++++------ 2 files changed, 28 insertions(+), 27 deletions(-) diff --git a/sonnet/src/recurrent.py b/sonnet/src/recurrent.py index b4aebeb5..84dd07e4 100644 --- a/sonnet/src/recurrent.py +++ b/sonnet/src/recurrent.py @@ -31,7 +31,7 @@ from sonnet.src import once from sonnet.src import types from sonnet.src import utils -from sonnet.src.recurrent_internals import _check_inputs_dtype, _safe_where, _unstack_input_sequence, _specialize_per_device +from sonnet.src.recurrent_internals import _check_inputs_dtype, _unstack_input_sequence, _specialize_per_device, _rnn_step import tensorflow as tf import tree @@ -387,26 +387,6 @@ def dynamic_unroll( return output_sequence, state -def _rnn_step(core, input_tas, sequence_length, t, prev_outputs, prev_state): - """Performs a single RNN step optionally accounting for variable length.""" - outputs, state = core( - tree.map_structure(lambda i: i.read(t), input_tas), prev_state) - - if prev_outputs is None: - assert t == 0 - prev_outputs = tree.map_structure(tf.zeros_like, outputs) - - # TODO(slebedev): do not go into this block if t < min_len. - if sequence_length is not None: - # Selectively propagate outputs/state to the not-yet-finished - # sequences. - maybe_propagate = functools.partial(_safe_where, t >= sequence_length) - outputs = tree.map_structure(maybe_propagate, prev_outputs, outputs) - state = tree.map_structure(maybe_propagate, prev_state, state) - - return outputs, state - - class VanillaRNN(RNNCore): """Basic fully-connected RNN core. 
diff --git a/sonnet/src/recurrent_internals.py b/sonnet/src/recurrent_internals.py index bd64524f..2ad76c58 100644 --- a/sonnet/src/recurrent_internals.py +++ b/sonnet/src/recurrent_internals.py @@ -14,6 +14,7 @@ # ============================================================================ """Utils for Recurrent Neural Network cores.""" +import functools import uuid import tensorflow.compat.v1 as tf1 @@ -26,11 +27,25 @@ from tensorflow.python.eager import function as function_lib # pylint: enable=g-direct-tensorflow-import -def _check_inputs_dtype(inputs, expected_dtype): - if inputs.dtype is not expected_dtype: - raise TypeError("inputs must have dtype {!r}, got {!r}".format( - expected_dtype, inputs.dtype)) - return expected_dtype + +def _rnn_step(core, input_tas, sequence_length, t, prev_outputs, prev_state): + """Performs a single RNN step optionally accounting for variable length.""" + outputs, state = core( + tree.map_structure(lambda i: i.read(t), input_tas), prev_state) + + if prev_outputs is None: + assert t == 0 + prev_outputs = tree.map_structure(tf.zeros_like, outputs) + + # TODO(slebedev): do not go into this block if t < min_len. + if sequence_length is not None: + # Selectively propagate outputs/state to the not-yet-finished + # sequences. + maybe_propagate = functools.partial(_safe_where, t >= sequence_length) + outputs = tree.map_structure(maybe_propagate, prev_outputs, outputs) + state = tree.map_structure(maybe_propagate, prev_state, state) + + return outputs, state def _safe_where(condition, x, y): # pylint: disable=g-doc-args @@ -43,6 +58,12 @@ def _safe_where(condition, x, y): # pylint: disable=g-doc-args return tf1.where(condition, x, y) +def _check_inputs_dtype(inputs, expected_dtype): + if inputs.dtype is not expected_dtype: + raise TypeError("inputs must have dtype {!r}, got {!r}".format( + expected_dtype, inputs.dtype)) + return expected_dtype + def _unstack_input_sequence(input_sequence): r"""Unstacks the input sequence into a nest of :tf:`TensorArray`\ s. @@ -141,4 +162,4 @@ def wrapper(*args, **kwargs): function_lib.register(functions[device], *args, **kwargs) return functions[default](*args, **kwargs) - return wrapper \ No newline at end of file + return wrapper From f1555dc555755f942e5e3050726b444955af0971 Mon Sep 17 00:00:00 2001 From: Sergii Volodko Date: Sun, 28 Jun 2020 23:01:46 +0100 Subject: [PATCH 07/11] Move LSTMState definition from recurrent to recurrent_internals --- sonnet/src/recurrent.py | 6 +----- sonnet/src/recurrent_internals.py | 4 ++++ 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/sonnet/src/recurrent.py b/sonnet/src/recurrent.py index 84dd07e4..251ef3c5 100644 --- a/sonnet/src/recurrent.py +++ b/sonnet/src/recurrent.py @@ -20,7 +20,6 @@ from __future__ import print_function import abc -import collections import functools import six @@ -31,7 +30,7 @@ from sonnet.src import once from sonnet.src import types from sonnet.src import utils -from sonnet.src.recurrent_internals import _check_inputs_dtype, _unstack_input_sequence, _specialize_per_device, _rnn_step +from sonnet.src.recurrent_internals import _check_inputs_dtype, _unstack_input_sequence, _specialize_per_device, _rnn_step, LSTMState import tensorflow as tf import tree @@ -665,9 +664,6 @@ def deep_rnn_with_residual_connections( name=name) -LSTMState = collections.namedtuple("LSTMState", ["hidden", "cell"]) - - class LSTM(RNNCore): r"""Long short-term memory (LSTM) RNN core. 
diff --git a/sonnet/src/recurrent_internals.py b/sonnet/src/recurrent_internals.py index 2ad76c58..920dccdf 100644 --- a/sonnet/src/recurrent_internals.py +++ b/sonnet/src/recurrent_internals.py @@ -14,6 +14,7 @@ # ============================================================================ """Utils for Recurrent Neural Network cores.""" +import collections import functools import uuid @@ -28,6 +29,9 @@ # pylint: enable=g-direct-tensorflow-import +LSTMState = collections.namedtuple("LSTMState", ["hidden", "cell"]) + + def _rnn_step(core, input_tas, sequence_length, t, prev_outputs, prev_state): """Performs a single RNN step optionally accounting for variable length.""" outputs, state = core( From 5855ea8594dc7e0a15eef14f71e4446be0b81c00 Mon Sep 17 00:00:00 2001 From: Sergii Volodko Date: Sun, 28 Jun 2020 23:15:15 +0100 Subject: [PATCH 08/11] Move _lstm_fn from recurrent to recurrent_internals --- sonnet/src/recurrent.py | 25 +++++-------------------- sonnet/src/recurrent_internals.py | 19 +++++++++++++++++++ 2 files changed, 24 insertions(+), 20 deletions(-) diff --git a/sonnet/src/recurrent.py b/sonnet/src/recurrent.py index 251ef3c5..85323cef 100644 --- a/sonnet/src/recurrent.py +++ b/sonnet/src/recurrent.py @@ -30,7 +30,8 @@ from sonnet.src import once from sonnet.src import types from sonnet.src import utils -from sonnet.src.recurrent_internals import _check_inputs_dtype, _unstack_input_sequence, _specialize_per_device, _rnn_step, LSTMState +from sonnet.src.recurrent_internals import _check_inputs_dtype, _unstack_input_sequence, _specialize_per_device, _rnn_step, _lstm_fn +from sonnet.src.recurrent_internals import LSTMState as LSTMState_definition import tensorflow as tf import tree @@ -664,6 +665,9 @@ def deep_rnn_with_residual_connections( name=name) +LSTMState = LSTMState_definition + + class LSTM(RNNCore): r"""Long short-term memory (LSTM) RNN core. @@ -808,25 +812,6 @@ def _initialize(self, inputs): name="projection") -def _lstm_fn(inputs, prev_state, w_i, w_h, b, projection=None): - """Compute one step of an LSTM.""" - gates_x = tf.matmul(inputs, w_i) - gates_h = tf.matmul(prev_state.hidden, w_h) - gates = gates_x + gates_h + b - - # i = input, f = forget, g = cell updates, o = output. - i, f, g, o = tf.split(gates, num_or_size_splits=4, axis=1) - - next_cell = tf.sigmoid(f) * prev_state.cell - next_cell += tf.sigmoid(i) * tf.tanh(g) - next_hidden = tf.sigmoid(o) * tf.tanh(next_cell) - - if projection is not None: - next_hidden = tf.matmul(next_hidden, projection) - - return next_hidden, LSTMState(hidden=next_hidden, cell=next_cell) - - class UnrolledLSTM(UnrolledRNN): """Unrolled long short-term memory (LSTM). diff --git a/sonnet/src/recurrent_internals.py b/sonnet/src/recurrent_internals.py index 920dccdf..d6ab18a3 100644 --- a/sonnet/src/recurrent_internals.py +++ b/sonnet/src/recurrent_internals.py @@ -32,6 +32,25 @@ LSTMState = collections.namedtuple("LSTMState", ["hidden", "cell"]) +def _lstm_fn(inputs, prev_state, w_i, w_h, b, projection=None): + """Compute one step of an LSTM.""" + gates_x = tf.matmul(inputs, w_i) + gates_h = tf.matmul(prev_state.hidden, w_h) + gates = gates_x + gates_h + b + + # i = input, f = forget, g = cell updates, o = output. 
+ i, f, g, o = tf.split(gates, num_or_size_splits=4, axis=1) + + next_cell = tf.sigmoid(f) * prev_state.cell + next_cell += tf.sigmoid(i) * tf.tanh(g) + next_hidden = tf.sigmoid(o) * tf.tanh(next_cell) + + if projection is not None: + next_hidden = tf.matmul(next_hidden, projection) + + return next_hidden, LSTMState(hidden=next_hidden, cell=next_cell) + + def _rnn_step(core, input_tas, sequence_length, t, prev_outputs, prev_state): """Performs a single RNN step optionally accounting for variable length.""" outputs, state = core( From 6e3f454fa09a5510934caf7fd153de9cb2016f59 Mon Sep 17 00:00:00 2001 From: Sergii Volodko Date: Sun, 28 Jun 2020 23:16:09 +0100 Subject: [PATCH 09/11] Add to recurrent_internals __future__ imports --- sonnet/src/recurrent_internals.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sonnet/src/recurrent_internals.py b/sonnet/src/recurrent_internals.py index d6ab18a3..e20c8e5e 100644 --- a/sonnet/src/recurrent_internals.py +++ b/sonnet/src/recurrent_internals.py @@ -14,6 +14,10 @@ # ============================================================================ """Utils for Recurrent Neural Network cores.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + import collections import functools import uuid From e543f55f63c8bbf9cfec2ee86444f878c740a43e Mon Sep 17 00:00:00 2001 From: Sergii Volodko Date: Sun, 28 Jun 2020 23:36:28 +0100 Subject: [PATCH 10/11] Align recurrent_internals usage with utils one --- sonnet/src/recurrent.py | 33 +++++++++++++++---------------- sonnet/src/recurrent_internals.py | 10 +++++----- 2 files changed, 21 insertions(+), 22 deletions(-) diff --git a/sonnet/src/recurrent.py b/sonnet/src/recurrent.py index 85323cef..1829c877 100644 --- a/sonnet/src/recurrent.py +++ b/sonnet/src/recurrent.py @@ -30,8 +30,7 @@ from sonnet.src import once from sonnet.src import types from sonnet.src import utils -from sonnet.src.recurrent_internals import _check_inputs_dtype, _unstack_input_sequence, _specialize_per_device, _rnn_step, _lstm_fn -from sonnet.src.recurrent_internals import LSTMState as LSTMState_definition +from sonnet.src import recurrent_internals import tensorflow as tf import tree @@ -253,7 +252,7 @@ def static_unroll( ValueError: If ``input_sequence`` is empty or its leading dimension is not known statically. """ - num_steps, input_tas = _unstack_input_sequence(input_sequence) + num_steps, input_tas = recurrent_internals.unstack_input_sequence(input_sequence) if not isinstance(num_steps, six.integer_types): raise ValueError( "input_sequence must have a statically known number of time steps") @@ -262,7 +261,7 @@ def static_unroll( state = initial_state output_accs = None for t in six.moves.range(num_steps): - outputs, state = _rnn_step( + outputs, state = recurrent_internals.rnn_step( core, input_tas, sequence_length, @@ -353,10 +352,10 @@ def dynamic_unroll( Raises: ValueError: If ``input_sequence`` is empty. """ - num_steps, input_tas = _unstack_input_sequence(input_sequence) + num_steps, input_tas = recurrent_internals.unstack_input_sequence(input_sequence) # Unroll the first time step separately to infer outputs structure. 
- outputs, state = _rnn_step( + outputs, state = recurrent_internals.rnn_step( core, input_tas, sequence_length, @@ -373,7 +372,7 @@ def dynamic_unroll( parallel_iterations=parallel_iterations, swap_memory=swap_memory, maximum_iterations=num_steps - 1) - outputs, state = _rnn_step( + outputs, state = recurrent_internals.rnn_step( core, input_tas, sequence_length, @@ -467,7 +466,7 @@ def initial_state(self, batch_size: int) -> tf.Tensor: @once.once def _initialize(self, inputs: tf.Tensor): - dtype = _check_inputs_dtype(inputs, self._dtype) + dtype = recurrent_internals.check_inputs_dtype(inputs, self._dtype) self._b = tf.Variable(self._b_init([self._hidden_size], dtype), name="b") @@ -665,7 +664,7 @@ def deep_rnn_with_residual_connections( name=name) -LSTMState = LSTMState_definition +LSTMState = recurrent_internals.LSTMState class LSTM(RNNCore): @@ -762,7 +761,7 @@ def __init__(self, def __call__(self, inputs, prev_state): """See base class.""" self._initialize(inputs) - return _lstm_fn(inputs, prev_state, self._w_i, self._w_h, self.b, + return recurrent_internals.lstm_fn(inputs, prev_state, self._w_i, self._w_h, self.b, self.projection) def initial_state(self, batch_size: int) -> LSTMState: @@ -783,7 +782,7 @@ def hidden_to_hidden(self): def _initialize(self, inputs): utils.assert_rank(inputs, 2) input_size = inputs.shape[1] - dtype = _check_inputs_dtype(inputs, self._dtype) + dtype = recurrent_internals.check_inputs_dtype(inputs, self._dtype) w_i_init = self._w_i_init or initializers.TruncatedNormal( stddev=1.0 / tf.sqrt(tf.cast(input_size, dtype))) @@ -878,7 +877,7 @@ def hidden_to_hidden(self): def _initialize(self, input_sequence): utils.assert_rank(input_sequence, 3) # [num_steps, batch_size, input_size]. input_size = input_sequence.shape[2] - dtype = _check_inputs_dtype(input_sequence, self._dtype) + dtype = recurrent_internals.check_inputs_dtype(input_sequence, self._dtype) w_i_init = self._w_i_init or initializers.TruncatedNormal( stddev=1.0 / tf.sqrt(tf.cast(input_size, dtype))) @@ -898,7 +897,7 @@ def _initialize(self, input_sequence): def _fallback_unrolled_lstm(input_sequence, initial_state, w_i, w_h, b): """Fallback version of :class:`UnrolledLSTM` which works on any device.""" return dynamic_unroll( - functools.partial(_lstm_fn, w_i=w_i, w_h=w_h, b=b), input_sequence, + functools.partial(recurrent_internals.lstm_fn, w_i=w_i, w_h=w_h, b=b), input_sequence, initial_state) @@ -949,7 +948,7 @@ def _cudnn_unrolled_lstm(input_sequence, initial_state, w_i, w_h, b): if hasattr(tf.raw_ops, "BlockLSTMV2"): _unrolled_lstm_impls["CPU"] = _block_unrolled_lstm -_specialized_unrolled_lstm = _specialize_per_device( +_specialized_unrolled_lstm = recurrent_internals.specialize_per_device( "snt_unrolled_lstm", specializations=_unrolled_lstm_impls, default="TPU") @@ -1193,7 +1192,7 @@ def initial_state(self, batch_size): @once.once def _initialize(self, inputs): - dtype = _check_inputs_dtype(inputs, self._dtype) + dtype = recurrent_internals.check_inputs_dtype(inputs, self._dtype) b_i, b_f, b_g, b_o = tf.split( self._b_init([4 * self._output_channels], dtype), num_or_size_splits=4) b_f += self._forget_bias @@ -1446,7 +1445,7 @@ def hidden_to_hidden(self): def _initialize(self, inputs): utils.assert_rank(inputs, 2) input_size = inputs.shape[1] - dtype = _check_inputs_dtype(inputs, self._dtype) + dtype = recurrent_internals.check_inputs_dtype(inputs, self._dtype) self._w_i = tf.Variable( self._w_i_init([input_size, 3 * self._hidden_size], dtype), name="w_i") self._w_h = tf.Variable( @@ -1555,7 
+1554,7 @@ def initial_state(self, batch_size):
   def _initialize(self, inputs):
     utils.assert_rank(inputs, 3)  # [num_steps, batch_size, input_size].
     input_size = inputs.shape[2]
-    dtype = _check_inputs_dtype(inputs, self._dtype)
+    dtype = recurrent_internals.check_inputs_dtype(inputs, self._dtype)
     self._w_i = tf.Variable(
         self._w_i_init([input_size, 3 * self._hidden_size], dtype), name="w_i")
     self._w_h = tf.Variable(
diff --git a/sonnet/src/recurrent_internals.py b/sonnet/src/recurrent_internals.py
index e20c8e5e..58db164c 100644
--- a/sonnet/src/recurrent_internals.py
+++ b/sonnet/src/recurrent_internals.py
@@ -36,7 +36,7 @@
 LSTMState = collections.namedtuple("LSTMState", ["hidden", "cell"])
 
 
-def _lstm_fn(inputs, prev_state, w_i, w_h, b, projection=None):
+def lstm_fn(inputs, prev_state, w_i, w_h, b, projection=None):
   """Compute one step of an LSTM."""
   gates_x = tf.matmul(inputs, w_i)
   gates_h = tf.matmul(prev_state.hidden, w_h)
@@ -55,7 +55,7 @@ def _lstm_fn(inputs, prev_state, w_i, w_h, b, projection=None):
   return next_hidden, LSTMState(hidden=next_hidden, cell=next_cell)
 
 
-def _rnn_step(core, input_tas, sequence_length, t, prev_outputs, prev_state):
+def rnn_step(core, input_tas, sequence_length, t, prev_outputs, prev_state):
   """Performs a single RNN step optionally accounting for variable length."""
   outputs, state = core(
       tree.map_structure(lambda i: i.read(t), input_tas), prev_state)
@@ -85,14 +85,14 @@ def _safe_where(condition, x, y):  # pylint: disable=g-doc-args
   return tf1.where(condition, x, y)
 
 
-def _check_inputs_dtype(inputs, expected_dtype):
+def check_inputs_dtype(inputs, expected_dtype):
   if inputs.dtype is not expected_dtype:
     raise TypeError("inputs must have dtype {!r}, got {!r}".format(
        expected_dtype, inputs.dtype))
   return expected_dtype
 
 
-def _unstack_input_sequence(input_sequence):
+def unstack_input_sequence(input_sequence):
   r"""Unstacks the input sequence into a nest of :tf:`TensorArray`\ s.
 
   This allows to traverse the input sequence using :tf:`TensorArray.read`
@@ -134,7 +134,7 @@
 
 
 # TODO(b/133740216): consider upstreaming into TensorFlow.
-def _specialize_per_device(api_name, specializations, default):
+def specialize_per_device(api_name, specializations, default):
   """Create a :tf:`function` specialized per-device.
 
   Args:
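Usage sketch (illustrative only, not part of the patches above): after the final patch the helpers are importable from sonnet.src.recurrent_internals under their public names (lstm_fn, rnn_step, check_inputs_dtype, unstack_input_sequence, specialize_per_device), so a single LSTM step can be exercised without going through recurrent.py. The snippet below assumes TF2 eager execution; the batch/input/hidden sizes and zero-valued weights are made up for the example.

import tensorflow as tf
from sonnet.src import recurrent_internals

# Illustrative shapes only; any float32 tensors with matching shapes work.
batch_size, input_size, hidden_size = 2, 3, 4
inputs = tf.zeros([batch_size, input_size])
prev_state = recurrent_internals.LSTMState(
    hidden=tf.zeros([batch_size, hidden_size]),
    cell=tf.zeros([batch_size, hidden_size]))
w_i = tf.zeros([input_size, 4 * hidden_size])   # input-to-hidden weights
w_h = tf.zeros([hidden_size, 4 * hidden_size])  # hidden-to-hidden weights
b = tf.zeros([4 * hidden_size])                 # gate biases

# check_inputs_dtype raises TypeError on a dtype mismatch, else returns the dtype.
dtype = recurrent_internals.check_inputs_dtype(inputs, tf.float32)

# One LSTM step through the relocated helper; returns (outputs, LSTMState).
outputs, next_state = recurrent_internals.lstm_fn(inputs, prev_state, w_i, w_h, b)
assert outputs.shape == (batch_size, hidden_size)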