diff --git a/data_pipeline.py b/data_pipeline.py
index 388f459..fd08c53 100644
--- a/data_pipeline.py
+++ b/data_pipeline.py
@@ -58,7 +58,8 @@ def input_fn(tf_records,
 
     dataset = tf.data.TFRecordDataset(tf_records, buffer_size=10000)
     dataset = dataset.shuffle(buffer_size=buffer_size)
-    dataset = dataset.map(parse_example,num_parallel_calls=tf.data.experimental.AUTOTUNE)
+    dataset = dataset.map(lambda
+        x:tf.py_function(func=parse_example,inp=[x],Tout=(tf.int32,tf.int32)),num_parallel_calls=tf.data.AUTOTUNE)
     dataset = dataset.padded_batch(batch_size, padded_shapes=padded_shapes)
     dataset = dataset.repeat(epoch)
     dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
diff --git a/train_gpt2.py b/train_gpt2.py
index 06443c9..b7c4d10 100644
--- a/train_gpt2.py
+++ b/train_gpt2.py
@@ -6,6 +6,9 @@
 from data_pipeline import input_fn
 from gpt2_model import *
 
+from scripts.utils import write_csv
+import timeit
+
 _ROOT = os.path.abspath(os.path.dirname(__file__))
 LOG_DIR = _ROOT + "/log"
 MODEL_DIR = _ROOT + "/model"
@@ -44,6 +47,9 @@ def train(num_layers, embedding_size, num_heads, dff, max_seq_len, vocab_size,
     train_tf_records = tf_records[:train_percent]
     test_tf_records = tf_records[train_percent:]
 
+    start_time = timeit.default_timer()
+    skipped_time = 0
+
     train_dataset = input_fn(train_tf_records, batch_size=batch_size)
     test_dataset = input_fn(test_tf_records, batch_size=batch_size)
 
@@ -70,6 +76,10 @@ def train(num_layers, embedding_size, num_heads, dff, max_seq_len, vocab_size,
 
     model.create_summary_writer(LOG_DIR)
     model.fit([train_dataset, test_dataset], graph_mode)
+
+    time = timeit.default_timer() - start_time - skipped_time
+    write_csv(__file__, time=time)
+
     print("Training Done................")
 
 