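# AWD-LSTM language-model configuration for OpenSeq2Seq, trained on
# WikiText-2 (vocabulary of 33,278 tokens). Hyperparameter comments refer to
# Merity et al., "Regularizing and Optimizing LSTM Language Models".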
import tensorflow as tf

from open_seq2seq.models import AWDLSTM
from open_seq2seq.encoders import AWDLSTMEncoder
# from open_seq2seq.encoders import BidirectionalRNNEncoderWithEmbedding
from open_seq2seq.decoders import FakeDecoder
from open_seq2seq.data import LMTextDataLayer, LMTextDataLayerGenerate
from open_seq2seq.parts.rnns.weight_drop import WeightDropLayerNormBasicLSTMCell
# from open_seq2seq.losses import CrossEntropyLoss
from open_seq2seq.losses import BasicSequenceLoss
from open_seq2seq.optimizers.lr_policies import fixed_lr
# from open_seq2seq.data.text2text.text2text import SpecialTextTokens
# from open_seq2seq.optimizers.lr_policies import exp_decay

data_root = "/data/wikitext-2/"

base_model = AWDLSTM
bptt = 72   # truncated backprop-through-time window; the paper uses a variable length centered on 70
steps = 40  # interval, in training steps, between logging/checkpoint events

base_params = {
  # "seed": 1882, # conforming to the AWD-LSTM paper
  "restore_best_checkpoint": True,
  "use_horovod": True,
  "num_gpus": 2,

  "batch_size_per_gpu": 160, # the AWD-LSTM paper uses 80
  "num_epochs": 750, # conforming to the AWD-LSTM paper
  "save_summaries_steps": steps,
  "print_loss_steps": steps,
  "print_samples_steps": steps,
  "save_checkpoint_steps": steps,
  "logdir": "AWDLSTM-EXP64",
  "eval_steps": steps * 2,

  "optimizer": "Adam", # needs to be changed to NT-ASGD
  "optimizer_params": {},
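  # Note: the paper trains with NT-ASGD (non-monotonically triggered averaged
  # SGD): plain SGD runs until validation perplexity stops improving for a set
  # number of evaluations, after which weight averaging kicks in. Adam with a
  # fixed learning rate is a stand-in until that optimizer is available.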
  # luong10 decay scheme

  "lr_policy": fixed_lr,
  "lr_policy_params": {
    "learning_rate": 9e-4
  },

  # "lr_policy": exp_decay,
  # "lr_policy_params": {
  #   "learning_rate": 0.0008,
  #   "begin_decay_at": 170000,
  #   "decay_steps": 17000,
  #   "decay_rate": 0.5,
  #   "use_staircase_decay": True,
  #   "min_lr": 0.0000005,
  # },
  "summaries": ['learning_rate', 'variables', 'gradients',
                'variable_norm', 'gradient_norm', 'global_gradient_norm'],
  # "grad_clip": 0.25, # conforming to the AWD-LSTM paper
  # "max_grad_norm": 0.25, # the paper uses 0.25
  "dtype": tf.float32,
  # "dtype": "mixed",
  # "automatic_loss_scaling": "Backoff",
  "encoder": AWDLSTMEncoder,
  # "encoder": BidirectionalRNNEncoderWithEmbedding,
  "encoder_params": { # still needs updating
    "initializer": tf.random_uniform_initializer,
    "initializer_params": { # embeddings and weights need different initializers
      "minval": -0.1,
      "maxval": 0.1,
    },
    # "core_cell": tf.contrib.rnn.LayerNormBasicLSTMCell,
    "core_cell": WeightDropLayerNormBasicLSTMCell,
    "core_cell_params": {
      "num_units": 800, # the paper uses 1150
      "forget_bias": 1.0,
    },
    "last_cell_params": {
      "num_units": 320,
      "forget_bias": 1.0,
    },
    "encoder_layers": 3,
    "encoder_dp_input_keep_prob": 1.0,
    "encoder_dp_output_keep_prob": 0.6, # the paper uses output dropout 0.3 for middle layers
    "encoder_last_input_keep_prob": 1.0,
    "encoder_last_output_keep_prob": 0.6, # the paper uses output dropout 0.4 at the last layer
    "recurrent_keep_prob": 0.5,
    "encoder_emb_keep_prob": 0.5,
    "encoder_use_skip_connections": False,
    "emb_size": 320,
    "vocab_size": 33278,
    "num_tokens_gen": 10,
    "sampling_prob": 0.0, # 0 means always feed the ground truth
    "fc_use_bias": True,
    "weight_tied": True,
    "awd_initializer": False,
  },
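  # "weight_tied" shares the embedding matrix with the output projection
  # (Inan et al. 2016; Press & Wolf 2016), which is why "emb_size" matches
  # the last layer's "num_units" (320 here).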

  "decoder": FakeDecoder, # a real decoder with AR and TAR still needs to be written
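  # In the paper, AR (activation regularization) is an L2 penalty on the LSTM
  # outputs, and TAR (temporal activation regularization) penalizes the
  # difference between hidden states at consecutive timesteps.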

  "regularizer": tf.contrib.layers.l2_regularizer,
  "regularizer_params": {
    "scale": 2e-6, # alpha
  },

  # "loss": CrossEntropyLoss, # a new loss + regularizer still needs to be written
  "loss": BasicSequenceLoss,
  "loss_params": {
    "offset_target_by_one": False,
    "average_across_timestep": True,
    "do_mask": False,
  }
}

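# The per-mode sections below extend base_params; OpenSeq2Seq merges
# base_params with train_params, eval_params, or infer_params depending on
# the run mode.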
train_params = {
  "data_layer": LMTextDataLayer,
  "data_layer_params": {
    "pad_vocab_to_eight": False,
    "vocab_file": data_root + "vocab.txt",
    "content_file": data_root + "train.ids",
    "rand_start": True,
    "shuffle": True,
    "shuffle_buffer_size": 25000,
    "repeat": True,
    "map_parallel_calls": 16,
    "prefetch_buffer_size": 8,
    "bptt": bptt,
  },
}

eval_params = {
  # "batch_size_per_gpu": 320,
  "data_layer": LMTextDataLayer,
  "data_layer_params": {
    "pad_vocab_to_eight": False,
    "vocab_file": data_root + "vocab.txt",
    "content_file": data_root + "valid.ids",
    "shuffle": False,
    "repeat": False,
    "map_parallel_calls": 16,
    "prefetch_buffer_size": 1,
    "bptt": bptt,
  },
}

infer_params = {
  "data_layer": LMTextDataLayer,
  "data_layer_params": {
    "pad_vocab_to_eight": False,
    "vocab_file": data_root + "vocab.txt",
    "content_file": data_root + "test.ids",
    "shuffle": False,
    "repeat": False,
    "rand_start": False,
    "map_parallel_calls": 16,
    "prefetch_buffer_size": 8,
    "bptt": bptt,
    "seed_tokens": "something The only game", # space-separated seed words for generation
  },
}
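
# A typical invocation (assuming the standard OpenSeq2Seq entry point; with
# "use_horovod": True the script is launched under mpirun), e.g.:
#   mpirun -np 2 python run.py --config_file=<path to this config> --mode=train_eval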