8 changes: 4 additions & 4 deletions optimization.py
@@ -153,7 +153,7 @@ def __init__(self,
"""Constructs a AdamWeightDecayOptimizer."""
super(AdamWeightDecayOptimizer, self).__init__(False, name)

-    self.learning_rate = learning_rate
+    self.learning_rate = tf.identity(learning_rate, name='learning_rate')
self.weight_decay_rate = weight_decay_rate
self.beta_1 = beta_1
self.beta_2 = beta_2
@@ -253,7 +253,7 @@ def __init__(self,
"""Constructs a LAMBOptimizer."""
super(LAMBOptimizer, self).__init__(False, name)

-    self.learning_rate = learning_rate
+    self.learning_rate = tf.identity(learning_rate, name='learning_rate')
self.weight_decay_rate = weight_decay_rate
self.beta_1 = beta_1
self.beta_2 = beta_2
@@ -369,7 +369,7 @@ def __init__(self,
"""Constructs a NadamWeightDecayOptimizer."""
super(NadamWeightDecayOptimizer, self).__init__(False, name)

-    self.learning_rate = learning_rate
+    self.learning_rate = tf.identity(learning_rate, name='learning_rate')
self.weight_decay_rate = weight_decay_rate
self.beta_1 = beta_1
self.beta_2 = beta_2
@@ -481,7 +481,7 @@ def __init__(self,
"""Constructs a NlamOptimizer."""
super(NlambOptimizer, self).__init__(False, name)

-    self.learning_rate = learning_rate
+    self.learning_rate = tf.identity(learning_rate, name='learning_rate')
self.weight_decay_rate = weight_decay_rate
self.beta_1 = beta_1
self.beta_2 = beta_2
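The four identical edits above serve one purpose: wrapping the (possibly scheduled) learning rate in tf.identity pins a stable tensor name in the graph, so the logging hook added in run_pretraining.py can fetch it as 'learning_rate:0' without holding a Python reference. A minimal sketch of the mechanism, assuming TF1-style graph mode (the decay schedule and values are illustrative, not from this PR):

import tensorflow as tf

tf.compat.v1.disable_eager_execution()

global_step = tf.compat.v1.train.get_or_create_global_step()
# Any schedule works here; polynomial decay is just an example.
lr = tf.compat.v1.train.polynomial_decay(
    learning_rate=1e-4, global_step=global_step, decay_steps=10000)
# Pin a well-known name so hooks can fetch the tensor as 'learning_rate:0'.
lr = tf.identity(lr, name='learning_rate')

with tf.compat.v1.Session() as sess:
  sess.run(tf.compat.v1.global_variables_initializer())
  print(sess.run('learning_rate:0'))  # fetched purely by name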
79 changes: 75 additions & 4 deletions run_pretraining.py
@@ -24,6 +24,10 @@
import tensorflow as tf
tf.compat.v1.disable_resource_variables()

import time
from tensorflow.python.training.summary_io import SummaryWriterCache
from tensorflow.core.framework.summary_pb2 import Summary

# Add Horovod to run_pretraining
try:
import horovod.tensorflow as hvd
@@ -116,6 +120,63 @@

flags.DEFINE_string("optimizer_type", "adam", "Optimizer used for training - adam (default), lamb, nadam and nlamb")

flags.DEFINE_integer(
    "num_report_steps", 10,
    "How often (in steps) summary information is reported and recorded.")


class LogSessionRunHook(tf.estimator.SessionRunHook):

def __init__(self,
global_batch_size,
num_report_steps=10,
output_dir=None):
self.global_batch_size = global_batch_size
self.num_report_steps = num_report_steps
    self.output_dir = output_dir
    self.summary_writer = None

def begin(self):
if self.summary_writer is None and self.output_dir:
self.summary_writer = SummaryWriterCache.get(self.output_dir)

def after_create_session(self, session, coord):
self.elapsed_secs = 0.
self.count = 0

  def before_run(self, run_context):
    self.t0 = time.time()
    global_step = tf.compat.v1.train.get_global_step()
    # Fetch by the graph names pinned with tf.identity: 'learning_rate:0' in
    # optimization.py and the three loss tensors named in model_fn below.
    fetches = [global_step, 'learning_rate:0', 'total_loss:0', 'mlm_loss:0', 'nsp_loss:0']
    return tf.estimator.SessionRunArgs(fetches=fetches)

def _log_and_record(self, global_step, learning_rate, total_loss, mlm_loss, nsp_loss):
time_per_step = self.elapsed_secs / self.count
throughput = self.global_batch_size / time_per_step
log_string = ' '
log_string += 'Step = %6i'%(global_step)
log_string += ', throughput = %6.1f'%(throughput)
log_string += ', total_loss = %6.3f'%(total_loss)
    log_string += ', mlm_loss = %6.4e'%(mlm_loss)
log_string += ', nsp_loss = %6.4e'%(nsp_loss)
log_string += ', learning_rate = %6.4e'%(learning_rate)
tf.compat.v1.logging.info(log_string)

if self.summary_writer is not None:
throughput_summary = Summary(value=[Summary.Value(tag='throughput', simple_value=throughput)])
self.summary_writer.add_summary(throughput_summary, global_step)
total_loss_summary = Summary(value=[Summary.Value(tag='total_loss', simple_value=total_loss)])
self.summary_writer.add_summary(total_loss_summary, global_step)

def after_run(self, run_context, run_values):
self.elapsed_secs += time.time() - self.t0
self.count += 1
global_step, learning_rate, total_loss, mlm_loss, nsp_loss = run_values.results[0:5]
if (global_step % self.num_report_steps) == 0:
self._log_and_record(global_step, learning_rate, total_loss, mlm_loss, nsp_loss)
self.elapsed_secs = 0.
self.count = 0
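For reference, the TensorBoard side of this hook boils down to writing raw Summary protos through a cached FileWriter. A standalone sketch with a made-up log directory and values (e.g. global_batch_size = 256 at 0.5 s per step yields the 512 examples/sec written below):

from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.python.training.summary_io import SummaryWriterCache

writer = SummaryWriterCache.get('/tmp/pretrain_logs')  # hypothetical directory
event = Summary(value=[Summary.Value(tag='throughput', simple_value=512.0)])
writer.add_summary(event, global_step=100)
writer.flush()  # push the point to disk so TensorBoard sees it immediately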


def model_fn_builder(bert_config, init_checkpoint, learning_rate,
num_train_steps, num_warmup_steps, use_tpu,
@@ -158,6 +219,10 @@ def model_fn(features, labels, mode, params): # pylint: disable=unused-argument

total_loss = masked_lm_loss + next_sentence_loss

    # Pin stable graph names so LogSessionRunHook can fetch these by name.
    masked_lm_loss = tf.identity(masked_lm_loss, name='mlm_loss')
    next_sentence_loss = tf.identity(next_sentence_loss, name='nsp_loss')
    total_loss = tf.identity(total_loss, name='total_loss')

tvars = tf.compat.v1.trainable_variables()

initialized_variable_names = {}
@@ -416,7 +481,10 @@ def _decode_record(record, name_to_features):

def main(_):
  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

  # Stop log records propagating to the root logger, which would otherwise
  # print every message twice.
  tf.compat.v1.get_logger().propagate = False

use_hvd = False
if FLAGS.use_horovod and hvd != None:
use_hvd = True
@@ -468,7 +536,7 @@ def main(_):
iterations_per_loop=FLAGS.iterations_per_loop,
num_shards=FLAGS.num_tpu_cores,
per_host_input_for_training=is_per_host),
-      log_step_count_steps=25,
+      log_step_count_steps=FLAGS.num_report_steps * FLAGS.iterations_per_loop,
session_config=config)

model_fn = model_fn_builder(
@@ -499,10 +567,13 @@ def main(_):
max_predictions_per_seq=FLAGS.max_predictions_per_seq,
is_training=True)

-    hooks = None
+    hooks = []
if (not use_hvd) or (hvd.rank() == 0):
global_batch_size = FLAGS.train_batch_size if not use_hvd else FLAGS.train_batch_size * hvd.size()
hooks.append(LogSessionRunHook(global_batch_size, FLAGS.num_report_steps, FLAGS.output_dir))
if use_hvd:
      # [HVD] Ensure all GPUs start with the same weights.
-      hooks = [hvd.BroadcastGlobalVariablesHook(0)]
+      hooks.append(hvd.BroadcastGlobalVariablesHook(0))
estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps, hooks=hooks)
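Two details are easy to miss in this hunk: the logging hook is attached only on rank 0 so summaries are not written once per worker, while BroadcastGlobalVariablesHook(0) runs on every worker so training starts from rank 0's weights. The global-batch arithmetic behind the hook's throughput number, spelled out with assumed values:

per_worker_batch = 64                          # FLAGS.train_batch_size (assumed)
num_workers = 8                                # hvd.size() on an 8-GPU job (assumed)
global_batch = per_worker_batch * num_workers  # 512 examples per optimizer step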

if FLAGS.do_eval: