diff --git a/optimization.py b/optimization.py
index 2bb80964e..72f486f6f 100644
--- a/optimization.py
+++ b/optimization.py
@@ -153,7 +153,7 @@ def __init__(self,
     """Constructs a AdamWeightDecayOptimizer."""
     super(AdamWeightDecayOptimizer, self).__init__(False, name)
 
-    self.learning_rate = learning_rate
+    self.learning_rate = tf.identity(learning_rate, name='learning_rate')
     self.weight_decay_rate = weight_decay_rate
     self.beta_1 = beta_1
     self.beta_2 = beta_2
@@ -253,7 +253,7 @@ def __init__(self,
     """Constructs a LAMBOptimizer."""
     super(LAMBOptimizer, self).__init__(False, name)
 
-    self.learning_rate = learning_rate
+    self.learning_rate = tf.identity(learning_rate, name='learning_rate')
     self.weight_decay_rate = weight_decay_rate
     self.beta_1 = beta_1
     self.beta_2 = beta_2
@@ -369,7 +369,7 @@ def __init__(self,
     """Constructs a NadamWeightDecayOptimizer."""
     super(NadamWeightDecayOptimizer, self).__init__(False, name)
 
-    self.learning_rate = learning_rate
+    self.learning_rate = tf.identity(learning_rate, name='learning_rate')
     self.weight_decay_rate = weight_decay_rate
     self.beta_1 = beta_1
     self.beta_2 = beta_2
@@ -481,7 +481,7 @@ def __init__(self,
     """Constructs a NlamOptimizer."""
     super(NlambOptimizer, self).__init__(False, name)
 
-    self.learning_rate = learning_rate
+    self.learning_rate = tf.identity(learning_rate, name='learning_rate')
     self.weight_decay_rate = weight_decay_rate
     self.beta_1 = beta_1
     self.beta_2 = beta_2
diff --git a/run_pretraining.py b/run_pretraining.py
index 4e7ac43b1..46086c281 100644
--- a/run_pretraining.py
+++ b/run_pretraining.py
@@ -24,6 +24,10 @@
 import tensorflow as tf
 tf.compat.v1.disable_resource_variables()
 
+import time
+from tensorflow.python.training.summary_io import SummaryWriterCache
+from tensorflow.core.framework.summary_pb2 import Summary
+
 # Add Horovod to run_pretraining
 try:
   import horovod.tensorflow as hvd
@@ -116,6 +120,63 @@
 flags.DEFINE_string("optimizer_type", "adam",
                     "Optimizer used for training - adam (default), lamb, nadam and nlamb")
 
+flags.DEFINE_integer(
+    "num_report_steps", 10,
+    "How frequently should summary information be reported and recorded.")
+
+
+class LogSessionRunHook(tf.estimator.SessionRunHook):
+
+  def __init__(self,
+               global_batch_size,
+               num_report_steps=10,
+               output_dir=None):
+    self.global_batch_size = global_batch_size
+    self.num_report_steps = num_report_steps
+    self.output_dir = output_dir
+    self.summary_writer = None
+
+  def begin(self):
+    if self.summary_writer is None and self.output_dir:
+      self.summary_writer = SummaryWriterCache.get(self.output_dir)
+
+  def after_create_session(self, session, coord):
+    self.elapsed_secs = 0.
+    self.count = 0
+
+  def before_run(self, run_context):
+    self.t0 = time.time()
+    global_step = tf.compat.v1.train.get_global_step()
+    fetches = [global_step, 'learning_rate:0', 'total_loss:0', 'mlm_loss:0', 'nsp_loss:0']
+    return tf.estimator.SessionRunArgs(fetches=fetches)
+
+  def _log_and_record(self, global_step, learning_rate, total_loss, mlm_loss, nsp_loss):
+    time_per_step = self.elapsed_secs / self.count
+    throughput = self.global_batch_size / time_per_step
+    log_string = ' '
+    log_string += 'Step = %6i' % (global_step)
+    log_string += ', throughput = %6.1f' % (throughput)
+    log_string += ', total_loss = %6.3f' % (total_loss)
+    log_string += ', mlm_loss = %6.4e' % (mlm_loss)
+    log_string += ', nsp_loss = %6.4e' % (nsp_loss)
+    log_string += ', learning_rate = %6.4e' % (learning_rate)
+    tf.compat.v1.logging.info(log_string)
+
+    if self.summary_writer is not None:
+      throughput_summary = Summary(value=[Summary.Value(tag='throughput', simple_value=throughput)])
+      self.summary_writer.add_summary(throughput_summary, global_step)
+      total_loss_summary = Summary(value=[Summary.Value(tag='total_loss', simple_value=total_loss)])
+      self.summary_writer.add_summary(total_loss_summary, global_step)
+
+  def after_run(self, run_context, run_values):
+    self.elapsed_secs += time.time() - self.t0
+    self.count += 1
+    global_step, learning_rate, total_loss, mlm_loss, nsp_loss = run_values.results[0:5]
+    if (global_step % self.num_report_steps) == 0:
+      self._log_and_record(global_step, learning_rate, total_loss, mlm_loss, nsp_loss)
+      self.elapsed_secs = 0.
+      self.count = 0
+
+
 def model_fn_builder(bert_config, init_checkpoint, learning_rate,
                      num_train_steps, num_warmup_steps, use_tpu,
@@ -158,6 +219,10 @@ def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
     total_loss = masked_lm_loss + next_sentence_loss
 
+    masked_lm_loss = tf.identity(masked_lm_loss, name='mlm_loss')
+    next_sentence_loss = tf.identity(next_sentence_loss, name='nsp_loss')
+    total_loss = tf.identity(total_loss, name='total_loss')
+
     tvars = tf.compat.v1.trainable_variables()
 
     initialized_variable_names = {}
@@ -416,7 +481,10 @@ def _decode_record(record, name_to_features):
 
 def main(_):
   tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
-
+
+  # Disable the log messages from being printed twice.
+  tf.compat.v1.get_logger().propagate = False
+
   use_hvd = False
   if FLAGS.use_horovod and hvd != None:
     use_hvd = True
@@ -468,7 +536,7 @@ def main(_):
           iterations_per_loop=FLAGS.iterations_per_loop,
           num_shards=FLAGS.num_tpu_cores,
          per_host_input_for_training=is_per_host),
-      log_step_count_steps=25,
+      log_step_count_steps=FLAGS.num_report_steps * FLAGS.iterations_per_loop,
       session_config=config)
 
   model_fn = model_fn_builder(
@@ -499,10 +567,13 @@ def main(_):
         max_predictions_per_seq=FLAGS.max_predictions_per_seq,
         is_training=True)
 
-    hooks = None
+    hooks = []
+    if (not use_hvd) or (hvd.rank() == 0):
+      global_batch_size = FLAGS.train_batch_size if not use_hvd else FLAGS.train_batch_size * hvd.size()
+      hooks.append(LogSessionRunHook(global_batch_size, FLAGS.num_report_steps, FLAGS.output_dir))
     if use_hvd:
       # [HVD] Ensure all GPU's start with the same weights.
-      hooks = [hvd.BroadcastGlobalVariablesHook(0)]
+      hooks.append(hvd.BroadcastGlobalVariablesHook(0))
     estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps, hooks=hooks)
 
   if FLAGS.do_eval:
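
Note on the pattern used above: the optimizer and model_fn changes wrap key tensors in tf.identity with fixed names, and LogSessionRunHook then fetches them by string name ('learning_rate:0', 'total_loss:0', ...) from inside the Estimator's training session, so the hook needs no direct references to the graph built in model_fn. The following is a minimal, self-contained sketch of that mechanism, assuming TF 1.x graph-mode semantics; FetchByNameHook is an illustrative name, not part of this change.

import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # graph mode, as in run_pretraining.py

# Naming the identity op makes the tensor addressable as 'total_loss:0' later.
loss = tf.identity(tf.constant(0.5), name='total_loss')

class FetchByNameHook(tf.estimator.SessionRunHook):  # illustrative hook only
  def before_run(self, run_context):
    # 'total_loss:0' is output 0 of the op named 'total_loss' above.
    return tf.estimator.SessionRunArgs(fetches=['total_loss:0'])

  def after_run(self, run_context, run_values):
    # run_values.results mirrors the fetches list passed in before_run.
    print('total_loss =', run_values.results[0])

with tf.compat.v1.train.MonitoredSession(hooks=[FetchByNameHook()]) as sess:
  sess.run(loss)  # the hook's fetches ride along with this run call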