diff --git a/cond_xla.py b/cond_xla.py new file mode 100644 index 000000000..9921d5f31 --- /dev/null +++ b/cond_xla.py @@ -0,0 +1,12 @@ +import tensorflow as tf + +use_xla = 0 + +def conditional_xla(): + def decorator(func): + if use_xla==1: + return tf.function(experimental_compile=True,experimental_relax_shapes=True)(func) + else: + return func + return decorator + diff --git a/fp16_utils.py b/fp16_utils.py new file mode 100644 index 000000000..6b8bda985 --- /dev/null +++ b/fp16_utils.py @@ -0,0 +1,35 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import tensorflow as tf +import numpy as np + + +def float32_variable_storage_getter(getter, name, shape=None, dtype=None, + initializer=None, regularizer=None, + trainable=True, + *args, **kwargs): + """Custom variable getter that forces trainable variables to be stored in + float32 precision and then casts them to the training precision. + """ + storage_dtype = tf.float32 if trainable else dtype + variable = getter(name, shape, dtype=storage_dtype, + initializer=initializer, regularizer=regularizer, + trainable=trainable, + *args, **kwargs) + if trainable and dtype != tf.float32: + variable = tf.cast(variable, dtype) + return variable + diff --git a/gpu_environment.py b/gpu_environment.py new file mode 100644 index 000000000..948c3fa44 --- /dev/null +++ b/gpu_environment.py @@ -0,0 +1,36 @@ +# coding=utf-8 +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow as tf +import numpy as np + +def float32_variable_storage_getter(getter, name, shape=None, dtype=None, + initializer=None, regularizer=None, + trainable=True, + *args, **kwargs): + """Custom variable getter that forces trainable variables to be stored in + float32 precision and then casts them to the training precision. + """ + storage_dtype = tf.float32 if trainable else dtype + variable = getter(name, shape, dtype=storage_dtype, + initializer=initializer, regularizer=regularizer, + trainable=trainable, + *args, **kwargs) + if trainable and dtype != tf.float32: + variable = tf.cast(variable, dtype) + return variable + +def get_custom_getter(compute_type): + return float32_variable_storage_getter if compute_type == tf.float16 else None diff --git a/modeling.py b/modeling.py index c523e2701..cffa748aa 100644 --- a/modeling.py +++ b/modeling.py @@ -26,9 +26,9 @@ import numpy as np import six import tensorflow as tf -tf.compat.v1.disable_resource_variables() -tf.compat.v1.disable_eager_execution() +from gpu_environment import get_custom_getter +from cond_xla import conditional_xla, use_xla class BertConfig(object): """Configuration for `BertModel`.""" @@ -137,7 +137,8 @@ def __init__(self, input_mask=None, token_type_ids=None, use_one_hot_embeddings=False, - scope=None): + scope=None, + compute_type=tf.float32): """Constructor for BertModel. Args: @@ -170,7 +171,7 @@ def __init__(self, if token_type_ids is None: token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32) - with tf.compat.v1.variable_scope(scope, default_name="bert"): + with tf.compat.v1.variable_scope(scope, default_name="bert", custom_getter=get_custom_getter(compute_type)): with tf.compat.v1.variable_scope("embeddings"): # Perform embedding lookup on the word ids. (self.embedding_output, self.embedding_table) = embedding_lookup( @@ -205,7 +206,8 @@ def __init__(self, # Run the stacked transformer. # `sequence_output` shape = [batch_size, seq_length, hidden_size]. self.all_encoder_layers = transformer_model( - input_tensor=self.embedding_output, + input_tensor=tf.saturate_cast(self.embedding_output, compute_type) \ + if self.embedding_output.dtype!=compute_type else self.embedding_output, attention_mask=attention_mask, hidden_size=config.hidden_size, num_hidden_layers=config.num_hidden_layers, @@ -262,7 +264,7 @@ def get_embedding_output(self): def get_embedding_table(self): return self.embedding_table - +@conditional_xla() def gelu(x): """Gaussian Error Linear Unit. @@ -274,12 +276,14 @@ def gelu(x): Returns: `x` with the GELU activation applied. """ - try: - return tf.nn.gelu(x) - except: - cdf = 0.5 * (1.0 + tf.tanh( + if not use_xla: + try: + return tf.nn.gelu(x) + except: + pass + cdf = 0.5 * (1.0 + tf.tanh( (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))))) - return x * cdf + return x * cdf def get_activation(activation_string): @@ -366,10 +370,22 @@ def dropout(input_tensor, dropout_prob): def layer_norm(input_tensor, name=None): """Run layer normalization on the last dimension of the tensor.""" - # return tf.contrib.layers.layer_norm( - # inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name) - return tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-12)(inputs=input_tensor) - + param_shape = input_tensor.shape[-1:] + scale = tf.compat.v1.get_variable(name="scale", shape=param_shape, + initializer=tf.ones_initializer()) + offset = tf.compat.v1.get_variable(name="offset", shape=param_shape, + initializer=tf.zeros_initializer()) + @conditional_xla() + def batch_norm(t, s, o): + mean, variance = tf.nn.moments(t, axes=[-1], keepdims=True) + return tf.nn.batch_normalization( + t, + mean, + variance, + offset=o, + scale=s, + variance_epsilon=1e-12) + return batch_norm(input_tensor, scale, offset) def layer_norm_and_dropout(input_tensor, dropout_prob, name=None): """Runs layer normalization followed by dropout.""" diff --git a/optimization.py b/optimization.py index 0fccccdca..0e9db1450 100644 --- a/optimization.py +++ b/optimization.py @@ -18,17 +18,19 @@ from __future__ import division from __future__ import print_function +import os import re import tensorflow as tf -tf.compat.v1.disable_resource_variables() -tf.compat.v1.disable_eager_execution() try: import horovod.tensorflow as hvd + from horovod.tensorflow.compression import Compression except: hvd = None -def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu, use_hvd=False, optimizer_type="adam"): +from cond_xla import conditional_xla + +def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu, use_hvd=False, optimizer_type="adam", use_fp16=False, manual_fp16=False): """Creates an optimizer training op.""" global_step = tf.compat.v1.train.get_or_create_global_step() @@ -43,7 +45,7 @@ def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu, power = 1.0 else: power = 0.5 - + learning_rate = tf.compat.v1.train.polynomial_decay( learning_rate, global_step, @@ -51,9 +53,6 @@ def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu, end_learning_rate=0.0, power=power, cycle=False) - # if use_hvd: - # # May want to scale learning rate by number of GPUs - # learning_rate *= hvd.size() # Implements linear warmup. I.e., if global_step < num_warmup_steps, the # learning rate will be `global_step/num_warmup_steps * init_lr`. @@ -82,7 +81,9 @@ def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu, beta_1=0.9, beta_2=0.999, epsilon=1e-6, - exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) + exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"], + use_fp16=use_fp16, + manual_fp16=manual_fp16) elif optimizer_type == "lamb": print("Initializing LAMB Optimizer") optimizer = LAMBOptimizer( @@ -110,15 +111,24 @@ def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu, beta_2=0.999, epsilon=1e-6, exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) - + +# if manual_fp16 or use_fp16: +# os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE"] = "1" + +# if use_fp16: +# optimizer=tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer) + if use_hvd: # [HVD] Wrap the original optimizer by Horovod's distributed optimizer, which handles all the under the hood allreduce calls. # Notice Horovod only does synchronized parameter update. - optimizer = hvd.DistributedOptimizer(optimizer) + optimizer = hvd.DistributedOptimizer(optimizer, sparse_as_dense=True, compression=Compression.fp16 if (use_fp16 or manual_fp16) else Compression.none) if use_tpu: optimizer = tf.compat.v1.tpu.CrossShardOptimizer(optimizer) + if manual_fp16 or use_fp16: + optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer) + tvars = tf.compat.v1.trainable_variables() if use_hvd: # [HVD] Use distributed optimizer to compute gradients @@ -137,9 +147,27 @@ def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu, # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use # a different optimizer, you should probably take this line out. new_global_step = global_step + 1 + new_global_step = tf.identity(new_global_step, name='step_update') train_op = tf.group(train_op, [global_step.assign(new_global_step)]) return train_op +@conditional_xla() +def apply_adam_with_decay(beta_1, beta_2, learning_rate, epsilon, do_decay, weight_decay_rate, grad, m, v, param): + # Standard Adam update. + next_m = ( + tf.multiply(beta_1, m) + tf.multiply(1.0 - beta_1, grad)) + next_v = ( + tf.multiply(beta_2, v) + tf.multiply(1.0 - beta_2, + tf.square(grad))) + update = next_m / (tf.sqrt(next_v) + epsilon) + if do_decay: + update += weight_decay_rate * param + + update_with_lr = learning_rate * update + + next_param = param - update_with_lr + return next_param, next_m, next_v + class AdamWeightDecayOptimizer(tf.compat.v1.train.Optimizer): """A basic Adam optimizer that includes "correct" L2 weight decay.""" @@ -150,17 +178,29 @@ def __init__(self, beta_2=0.999, epsilon=1e-6, exclude_from_weight_decay=None, - name="AdamWeightDecayOptimizer"): + name="AdamWeightDecayOptimizer", + use_fp16=False, + manual_fp16=False): """Constructs a AdamWeightDecayOptimizer.""" super(AdamWeightDecayOptimizer, self).__init__(False, name) +<<<<<<< HEAD self.learning_rate = tf.identity(learning_rate, name='learning_rate') self.weight_decay_rate = weight_decay_rate self.beta_1 = beta_1 self.beta_2 = beta_2 self.epsilon = epsilon +======= + self.learning_rate = tf.identity(learning_rate, name='learning_rate') + self.weight_decay_on = (weight_decay_rate!=0.0) + self.weight_decay_rate = tf.constant(weight_decay_rate) + self.beta_1 = tf.constant(beta_1) + self.beta_2 = tf.constant(beta_2) + self.epsilon = tf.constant(epsilon) +>>>>>>> Adding support of XLA self.exclude_from_weight_decay = exclude_from_weight_decay - + self.use_fp16 = use_fp16 + self.manual_fp16 = manual_fp16 def apply_gradients(self, grads_and_vars, global_step=None, name=None): """See base class.""" assignments = [] @@ -170,6 +210,17 @@ def apply_gradients(self, grads_and_vars, global_step=None, name=None): param_name = self._get_variable_name(param.name) + has_shadow = self.manual_fp16 and param.dtype.base_dtype != tf.float32 + if has_shadow: + # create shadow fp32 weights for fp16 variable + param_fp32 = tf.compat.v1.get_variable( + name=param_name + "/shadow", + dtype=tf.float32, + trainable=False, + initializer=tf.cast(param.initialized_value(),tf.float32)) + else: + param_fp32 = param + m = tf.compat.v1.get_variable( name=param_name + "/adam_m", shape=param.shape.as_list(), @@ -183,38 +234,24 @@ def apply_gradients(self, grads_and_vars, global_step=None, name=None): trainable=False, initializer=tf.compat.v1.zeros_initializer()) - # Standard Adam update. - next_m = ( - tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) - next_v = ( - tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, - tf.square(grad))) + next_param, next_m, next_v = apply_adam_with_decay(self.beta_1, + self.beta_2, self.learning_rate, self.epsilon, + self._do_use_weight_decay(param_name), + self.weight_decay_rate, grad, m, v, param) - update = next_m / (tf.sqrt(next_v) + self.epsilon) - - # Just adding the square of the weights to the loss function is *not* - # the correct way of using L2 regularization/weight decay with Adam, - # since that will interact with the m and v parameters in strange ways. - # - # Instead we want ot decay the weights in a manner that doesn't interact - # with the m/v parameters. This is equivalent to adding the square - # of the weights to the loss with plain (non-momentum) SGD. - if self._do_use_weight_decay(param_name): - update += self.weight_decay_rate * param - - update_with_lr = self.learning_rate * update - - next_param = param - update_with_lr + if has_shadow: + # cast shadow fp32 weights to fp16 and assign to trainable variable + param.assign(tf.cast(next_param, param.dtype.base_dtype)) assignments.extend( - [param.assign(next_param), + [param_fp32.assign(next_param), m.assign(next_m), v.assign(next_v)]) return tf.group(*assignments, name=name) def _do_use_weight_decay(self, param_name): """Whether to use L2 weight decay for `param_name`.""" - if not self.weight_decay_rate: + if not self.weight_decay_on: return False if self.exclude_from_weight_decay: for r in self.exclude_from_weight_decay: @@ -229,6 +266,49 @@ def _get_variable_name(self, param_name): param_name = m.group(1) return param_name +@conditional_xla() +def apply_lamb(beta_1, beta_2, learning_rate, epsilon, do_decay, weight_decay_rate, grad, m, v, param): + # Standard Adam update. + next_m = ( + tf.multiply(beta_1, m) + tf.multiply(1.0 - beta_1, grad)) + next_v = ( + tf.multiply(beta_2, v) + tf.multiply(1.0 - beta_2, + tf.square(grad))) + + update = next_m / (tf.sqrt(next_v) + epsilon) + + # Just adding the square of the weights to the loss function is *not* + # the correct way of using L2 regularization/weight decay with Adam, + # since that will interact with the m and v parameters in strange ways. + # + # Instead we want ot decay the weights in a manner that doesn't interact + # with the m/v parameters. This is equivalent to adding the square + # of the weights to the loss with plain (non-momentum) SGD. + if do_decay: + update += weight_decay_rate * param + + ############## BELOW ARE THE SPECIFIC PARTS FOR LAMB ############## + + # Note: Here are two choices for scaling function \phi(z) + # minmax: \phi(z) = min(max(z, \gamma_l), \gamma_u) + # identity: \phi(z) = z + # The authors does not mention what is \gamma_l and \gamma_u + # UPDATE: after asking authors, they provide me the code below. + # ratio = array_ops.where(math_ops.greater(w_norm, 0), array_ops.where( + # math_ops.greater(g_norm, 0), (w_norm / g_norm), 1.0), 1.0) + + r1 = tf.sqrt(tf.reduce_sum(input_tensor=tf.square(param))) + r2 = tf.sqrt(tf.reduce_sum(input_tensor=tf.square(update))) + + r = tf.compat.v1.where(tf.greater(r1, 0.0), tf.compat.v1.where( + tf.greater(r2, 0.0), r1/r2, 1.0), 1.0) + + eta = learning_rate * r + + update_with_lr = eta * update + + next_param = param - update_with_lr + return next_param, next_m, next_v class LAMBOptimizer(tf.compat.v1.train.Optimizer): """ @@ -282,47 +362,10 @@ def apply_gradients(self, grads_and_vars, global_step=None, name=None): dtype=tf.float32, trainable=False, initializer=tf.compat.v1.zeros_initializer()) - - # Standard Adam update. - next_m = ( - tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) - next_v = ( - tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, - tf.square(grad))) - - update = next_m / (tf.sqrt(next_v) + self.epsilon) - - # Just adding the square of the weights to the loss function is *not* - # the correct way of using L2 regularization/weight decay with Adam, - # since that will interact with the m and v parameters in strange ways. - # - # Instead we want ot decay the weights in a manner that doesn't interact - # with the m/v parameters. This is equivalent to adding the square - # of the weights to the loss with plain (non-momentum) SGD. - if self._do_use_weight_decay(param_name): - update += self.weight_decay_rate * param - - ############## BELOW ARE THE SPECIFIC PARTS FOR LAMB ############## - - # Note: Here are two choices for scaling function \phi(z) - # minmax: \phi(z) = min(max(z, \gamma_l), \gamma_u) - # identity: \phi(z) = z - # The authors does not mention what is \gamma_l and \gamma_u - # UPDATE: after asking authors, they provide me the code below. - # ratio = array_ops.where(math_ops.greater(w_norm, 0), array_ops.where( - # math_ops.greater(g_norm, 0), (w_norm / g_norm), 1.0), 1.0) - - r1 = tf.sqrt(tf.reduce_sum(input_tensor=tf.square(param))) - r2 = tf.sqrt(tf.reduce_sum(input_tensor=tf.square(update))) - - r = tf.compat.v1.where(tf.greater(r1, 0.0), tf.compat.v1.where( - tf.greater(r2, 0.0), r1/r2, 1.0), 1.0) - - eta = self.learning_rate * r - - update_with_lr = eta * update - - next_param = param - update_with_lr + next_param, next_m, next_v = apply_lamb(self.beta_1, self.beta_2, + self.learning_rate, self.epsilon, + self._do_use_weight_decay(param_name), + self.weight_decay_rate, grad, m, v, param) assignments.extend( [param.assign(next_param), diff --git a/run_pretraining.py b/run_pretraining.py index d071d30f4..a57ee5ee7 100644 --- a/run_pretraining.py +++ b/run_pretraining.py @@ -19,15 +19,13 @@ from __future__ import print_function import os -import modeling -import optimization +import sys +import time import tensorflow as tf -tf.compat.v1.disable_resource_variables() -tf.compat.v1.disable_eager_execution() - import time from tensorflow.python.training.summary_io import SummaryWriterCache from tensorflow.core.framework.summary_pb2 import Summary +from tensorflow.core.protobuf import rewriter_config_pb2 # Add Horovod to run_pretraining try: @@ -37,7 +35,19 @@ flags = tf.compat.v1.flags -FLAGS = flags.FLAGS + +import multiprocessing.spawn +_old_preparation_data = multiprocessing.spawn.get_preparation_data + +def _patched_preparation_data(name): + try: + return _old_preparation_data(name) + except AttributeError: + main_module = sys.modules['__main__'] + # Any string for __spec__ does the job + main_module.__spec__ = '' + return _old_preparation_data(name) +multiprocessing.spawn.get_preparation_data = _patched_preparation_data ## Required parameters flags.DEFINE_string( @@ -58,6 +68,9 @@ "init_checkpoint", None, "Initial checkpoint (usually from a pre-trained BERT model).") +flags.DEFINE_string( + "eval_filename", None, "Eval output filename") + flags.DEFINE_integer( "max_seq_length", 128, "The maximum total input sequence length after WordPiece tokenization. " @@ -83,7 +96,7 @@ flags.DEFINE_integer("num_warmup_steps", 10000, "Number of warmup steps.") -flags.DEFINE_integer("save_checkpoints_steps", 1000, +flags.DEFINE_integer("save_checkpoints_steps", 25000, "How often to save the model checkpoint.") flags.DEFINE_integer("iterations_per_loop", 1000, @@ -93,6 +106,8 @@ flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") +flags.DEFINE_integer("use_xla", 0, "XLA optimizations: 0 - off, 1 - restricted, 2 - full") + flags.DEFINE_string( "tpu_name", None, "The Cloud TPU to use for training. This should be either the name " @@ -116,7 +131,11 @@ flags.DEFINE_integer( "num_tpu_cores", 8, "Only used if `use_tpu` is True. Total number of TPU cores to use.") - + +flags.DEFINE_bool("use_fp16", False, "Whether to use FP16 via AMP.") + +flags.DEFINE_bool("manual_fp16", False, "Whether to use NVIDIA manual FP16.") + flags.DEFINE_bool("use_horovod", False, "Whether to use Horovod.") flags.DEFINE_string("optimizer_type", "adam", "Optimizer used for training - adam (default), lamb, nadam and nlamb") @@ -125,6 +144,15 @@ "num_report_steps", 10, "How frequently should summary information be reported and recorded.") +FLAGS = flags.FLAGS + +import cond_xla +# must be done _before_ we import modeling or optimization +# (otherwise use_xla won't stick) +cond_xla.use_xla = FLAGS.use_xla + +import modeling +import optimization class LogSessionRunHook(tf.estimator.SessionRunHook): @@ -186,7 +214,6 @@ def after_run(self, run_context, run_values): "Only used if `enable_timeline` is True. " "Generate timeline for every Nth step.") - def model_fn_builder(bert_config, init_checkpoint, learning_rate, num_train_steps, num_warmup_steps, use_tpu, use_one_hot_embeddings, use_hvd): @@ -215,7 +242,8 @@ def model_fn(features, labels, mode, params): # pylint: disable=unused-argument input_ids=input_ids, input_mask=input_mask, token_type_ids=segment_ids, - use_one_hot_embeddings=use_one_hot_embeddings) + use_one_hot_embeddings=use_one_hot_embeddings, + compute_type=tf.float16 if FLAGS.manual_fp16 else tf.float32) (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output( @@ -226,12 +254,10 @@ def model_fn(features, labels, mode, params): # pylint: disable=unused-argument next_sentence_log_probs) = get_next_sentence_output( bert_config, model.get_pooled_output(), next_sentence_labels) + masked_lm_loss = tf.identity(masked_lm_loss, name="mlm_loss") + next_sentence_loss = tf.identity(next_sentence_loss, name="nsp_loss") total_loss = masked_lm_loss + next_sentence_loss - - masked_lm_loss = tf.identity(masked_lm_loss, name='mlm_loss') - next_sentence_loss = tf.identity(next_sentence_loss, name='nsp_loss') total_loss = tf.identity(total_loss, name='total_loss') - tvars = tf.compat.v1.trainable_variables() initialized_variable_names = {} @@ -250,17 +276,22 @@ def tpu_scaffold(): tf.compat.v1.train.init_from_checkpoint(init_checkpoint, assignment_map) tf.compat.v1.logging.info("**** Trainable Variables ****") + row = 0 for var in tvars: init_string = "" if var.name in initialized_variable_names: init_string = ", *INIT_FROM_CKPT*" - tf.compat.v1.logging.info(" name = %s, shape = %s%s", var.name, var.shape, + if row<10 or row>len(tvars)-10: + tf.compat.v1.logging.info(" name = %s, shape = %s%s", var.name, var.shape, init_string) + if row==10: + tf.compat.v1.logging.info("...") + row+=1 output_spec = None if mode == tf.estimator.ModeKeys.TRAIN: train_op = optimization.create_optimizer( - total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu, use_hvd, FLAGS.optimizer_type) + total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu, use_hvd, FLAGS.optimizer_type, use_fp16=FLAGS.use_fp16, manual_fp16=FLAGS.manual_fp16) output_spec = tf.compat.v1.estimator.tpu.TPUEstimatorSpec( mode=mode, @@ -454,9 +485,6 @@ def input_fn(params): d = d.shuffle(buffer_size=100) else: d = tf.data.TFRecordDataset(input_files) - # Since we evaluate for a fixed number of steps we don't want to encounter - # out-of-range exceptions. - d = d.repeat() # We must `drop_remainder` on training because the TPU requires fixed # size dimensions. For eval, we assume we are evaluating on the CPU or GPU @@ -487,13 +515,11 @@ def _decode_record(record, name_to_features): return example - def main(_): tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO) # disable the log messages from being printed twice tf.compat.v1.get_logger().propagate = False - use_hvd = False if FLAGS.use_horovod and hvd != None: use_hvd = True @@ -512,7 +538,6 @@ def main(_): raise ValueError("At least one of `do_train` or `do_eval` must be True.") bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) - tf.io.gfile.makedirs(FLAGS.output_dir) input_files = [] @@ -520,8 +545,11 @@ def main(_): input_files.extend(tf.io.gfile.glob(input_pattern)) tf.compat.v1.logging.info("*** Input Files ***") + row=0 for input_file in input_files: - tf.compat.v1.logging.info(" %s" % input_file) + if row<10 or row>len(input_files)-10: + tf.compat.v1.logging.info(" %s" % input_file) + row+=1 tpu_cluster_resolver = None if FLAGS.use_tpu and FLAGS.tpu_name: @@ -530,10 +558,12 @@ def main(_): is_per_host = tf.compat.v1.estimator.tpu.InputPipelineConfig.PER_HOST_V2 - config = None + config = tf.compat.v1.ConfigProto() + if FLAGS.use_xla==2: + config.graph_options.optimizer_options.global_jit_level = tf.compat.v1.OptimizerOptions.ON_1 + config.graph_options.rewrite_options.memory_optimization = rewriter_config_pb2.RewriterConfig.NO_MEM_OPT if use_hvd: # [HVD] Pin each worker to a GPU (make sure one worker uses only one GPU). - config = tf.compat.v1.ConfigProto() config.gpu_options.visible_device_list = str(hvd.local_rank()) run_config = tf.compat.v1.estimator.tpu.RunConfig( @@ -605,9 +635,9 @@ def main(_): is_training=False) result = estimator.evaluate( - input_fn=eval_input_fn, steps=FLAGS.max_eval_steps) + input_fn=eval_input_fn, steps=FLAGS.max_eval_steps if FLAGS.max_eval_steps>0 else None) - output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") + output_eval_file = os.path.join(FLAGS.output_dir, FLAGS.eval_filename or "eval_results.txt") with tf.io.gfile.GFile(output_eval_file, "w") as writer: tf.compat.v1.logging.info("***** Eval results *****") for key in sorted(result.keys()):