12 changes: 12 additions & 0 deletions cond_xla.py
@@ -0,0 +1,12 @@
import tensorflow as tf

use_xla = 0

def conditional_xla():
  def decorator(func):
    if use_xla == 1:
      return tf.function(experimental_compile=True, experimental_relax_shapes=True)(func)
    else:
      return func
  return decorator
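A quick usage sketch for `conditional_xla` (the `toy_matmul` function and tensors below are illustrative, not part of this PR). Because the decorator reads the module-level `use_xla` flag at decoration time, the flag has to be set before any decorated function is defined; the snippet also assumes a TF 2.x release that still accepts the `experimental_compile`/`experimental_relax_shapes` arguments of `tf.function` and that XLA is available.

```python
import tensorflow as tf

import cond_xla
from cond_xla import conditional_xla

cond_xla.use_xla = 1  # must be toggled before decorated functions are defined


@conditional_xla()
def toy_matmul(a, b):
  # XLA-compiled tf.function when use_xla == 1, plain Python function otherwise.
  return tf.matmul(a, b)


print(toy_matmul(tf.ones([2, 2]), tf.ones([2, 2])))
```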

35 changes: 35 additions & 0 deletions fp16_utils.py
@@ -0,0 +1,35 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
import numpy as np


def float32_variable_storage_getter(getter, name, shape=None, dtype=None,
                                    initializer=None, regularizer=None,
                                    trainable=True,
                                    *args, **kwargs):
  """Custom variable getter that forces trainable variables to be stored in
  float32 precision and then casts them to the training precision.
  """
  storage_dtype = tf.float32 if trainable else dtype
  variable = getter(name, shape, dtype=storage_dtype,
                    initializer=initializer, regularizer=regularizer,
                    trainable=trainable,
                    *args, **kwargs)
  if trainable and dtype != tf.float32:
    variable = tf.cast(variable, dtype)
  return variable

36 changes: 36 additions & 0 deletions gpu_environment.py
@@ -0,0 +1,36 @@
# coding=utf-8
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
import numpy as np

def float32_variable_storage_getter(getter, name, shape=None, dtype=None,
                                    initializer=None, regularizer=None,
                                    trainable=True,
                                    *args, **kwargs):
  """Custom variable getter that forces trainable variables to be stored in
  float32 precision and then casts them to the training precision.
  """
  storage_dtype = tf.float32 if trainable else dtype
  variable = getter(name, shape, dtype=storage_dtype,
                    initializer=initializer, regularizer=regularizer,
                    trainable=trainable,
                    *args, **kwargs)
  if trainable and dtype != tf.float32:
    variable = tf.cast(variable, dtype)
  return variable

def get_custom_getter(compute_type):
  return float32_variable_storage_getter if compute_type == tf.float16 else None
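For reference, a sketch of how this getter is typically attached to a variable scope so that float16 compute still keeps float32 master weights (the scope name, shape, and the explicit `disable_eager_execution` call are illustrative; the getter targets TF1-style `get_variable`, as used by `modeling.py`):

```python
import tensorflow as tf

from gpu_environment import get_custom_getter

tf.compat.v1.disable_eager_execution()  # getter is aimed at TF1-style variable scopes

compute_type = tf.float16  # assumption: mixed-precision run

with tf.compat.v1.variable_scope(
    "demo", custom_getter=get_custom_getter(compute_type)):
  # Stored as a float32 variable; `w` is the float16 cast handed to the model.
  w = tf.compat.v1.get_variable("w", shape=[4, 4], dtype=compute_type)

print(w.dtype)  # float16 view over float32 storage
```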
46 changes: 31 additions & 15 deletions modeling.py
@@ -26,9 +26,9 @@
import numpy as np
import six
import tensorflow as tf
tf.compat.v1.disable_resource_variables()
tf.compat.v1.disable_eager_execution()

from gpu_environment import get_custom_getter
from cond_xla import conditional_xla, use_xla

class BertConfig(object):
  """Configuration for `BertModel`."""
@@ -137,7 +137,8 @@ def __init__(self,
               input_mask=None,
               token_type_ids=None,
               use_one_hot_embeddings=False,
               scope=None):
               scope=None,
               compute_type=tf.float32):
    """Constructor for BertModel.

    Args:
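A hypothetical call site for the new `compute_type` argument (config values, shapes, and the placeholder are illustrative only, not taken from this PR):

```python
import tensorflow as tf

from modeling import BertConfig, BertModel

config = BertConfig(vocab_size=32000, hidden_size=256,
                    num_hidden_layers=4, num_attention_heads=4)
input_ids = tf.compat.v1.placeholder(tf.int32, shape=[8, 128])

model = BertModel(config=config,
                  is_training=False,
                  input_ids=input_ids,
                  compute_type=tf.float16)  # new argument added by this change
```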
@@ -170,7 +171,7 @@ def __init__(self,
    if token_type_ids is None:
      token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32)

    with tf.compat.v1.variable_scope(scope, default_name="bert"):
    with tf.compat.v1.variable_scope(scope, default_name="bert", custom_getter=get_custom_getter(compute_type)):
      with tf.compat.v1.variable_scope("embeddings"):
        # Perform embedding lookup on the word ids.
        (self.embedding_output, self.embedding_table) = embedding_lookup(
@@ -205,7 +206,8 @@ def __init__(self,
        # Run the stacked transformer.
        # `sequence_output` shape = [batch_size, seq_length, hidden_size].
        self.all_encoder_layers = transformer_model(
            input_tensor=self.embedding_output,
            input_tensor=tf.saturate_cast(self.embedding_output, compute_type) \
                if self.embedding_output.dtype != compute_type else self.embedding_output,
            attention_mask=attention_mask,
            hidden_size=config.hidden_size,
            num_hidden_layers=config.num_hidden_layers,
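Side note on `tf.saturate_cast` vs `tf.cast` when narrowing the embedding output to float16: values outside the float16 range clamp to the largest finite value instead of overflowing to inf. A standalone illustration (values chosen only to show the clamping):

```python
import tensorflow as tf

x = tf.constant([70000.0, -70000.0, 1.0], dtype=tf.float32)

print(tf.cast(x, tf.float16))           # [inf, -inf, 1.0] -- overflows
print(tf.saturate_cast(x, tf.float16))  # clamped to ~65504, the float16 max
```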
@@ -262,7 +264,7 @@ def get_embedding_output(self):
  def get_embedding_table(self):
    return self.embedding_table


@conditional_xla()
def gelu(x):
  """Gaussian Error Linear Unit.

@@ -274,12 +276,14 @@ def gelu(x):
  Returns:
    `x` with the GELU activation applied.
  """
  try:
    return tf.nn.gelu(x)
  except:
    cdf = 0.5 * (1.0 + tf.tanh(
  if not use_xla:
    try:
      return tf.nn.gelu(x)
    except:
      pass
  cdf = 0.5 * (1.0 + tf.tanh(
      (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
  return x * cdf
  return x * cdf


def get_activation(activation_string):
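For context on the fallback branch in `gelu`, a standalone check (not part of the PR) comparing the built-in `tf.nn.gelu`, where available, against the tanh approximation used above:

```python
import numpy as np
import tensorflow as tf

x = tf.constant([-2.0, -0.5, 0.0, 0.5, 2.0])

# Tanh approximation from the fallback path above.
cdf = 0.5 * (1.0 + tf.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))))
approx = x * cdf

# Built-in exact GELU; the try/except guards against TF versions that lack it.
exact = tf.nn.gelu(x)

print(tf.reduce_max(tf.abs(approx - exact)))  # small; the two agree closely
```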
@@ -366,10 +370,22 @@ def dropout(input_tensor, dropout_prob):

def layer_norm(input_tensor, name=None):
  """Run layer normalization on the last dimension of the tensor."""
  # return tf.contrib.layers.layer_norm(
  #     inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name)
  return tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-12)(inputs=input_tensor)

  param_shape = input_tensor.shape[-1:]
  scale = tf.compat.v1.get_variable(name="scale", shape=param_shape,
                                    initializer=tf.ones_initializer())
  offset = tf.compat.v1.get_variable(name="offset", shape=param_shape,
                                     initializer=tf.zeros_initializer())
  @conditional_xla()
  def batch_norm(t, s, o):
    mean, variance = tf.nn.moments(t, axes=[-1], keepdims=True)
    return tf.nn.batch_normalization(
        t,
        mean,
        variance,
        offset=o,
        scale=s,
        variance_epsilon=1e-12)
  return batch_norm(input_tensor, scale, offset)

def layer_norm_and_dropout(input_tensor, dropout_prob, name=None):
  """Runs layer normalization followed by dropout."""