diff --git a/benchmark/bert/run_glue.py b/benchmark/bert/run_glue.py index 241d36b..4dd3aca 100644 --- a/benchmark/bert/run_glue.py +++ b/benchmark/bert/run_glue.py @@ -21,6 +21,7 @@ import numpy as np import paddle +import paddle.fluid as fluid from paddle.io import DataLoader from paddlenlp.datasets import SimpleDataset, GlueQNLI, GlueSST2 @@ -101,6 +102,21 @@ def parse_args(): default=1e-8, type=float, help="Epsilon for Adam optimizer.") + parser.add_argument( + "--momentum_rate", + default=0.9, + type=float, + help="The value of momentum_rate.") + parser.add_argument( + "--l2_decay", + default=1e-4, + type=float, + help="The l2_decay parameter.") + parser.add_argument( + "--multi_precision", + default=False, + type=lambda x: x.lower() == 'true', + help="Whether to enable multi-precision training with fp16.") parser.add_argument( "--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") parser.add_argument( @@ -131,6 +147,14 @@ def parse_args(): help="Save checkpoint every X updates steps.") parser.add_argument( "--seed", type=int, default=42, help="Random seed for initialization") + parser.add_argument( + "--use_pure_fp16", type=lambda x: x.lower() == 'true', default=False, help="Whether to enable half precision training with pure fp16.") + parser.add_argument( + "--use_amp", type=lambda x: x.lower() == 'true', default=False, help="Whether to enable half precision training with AMP.") + parser.add_argument( + "--scale_loss", type=float, default=1.0, help="The value of scale_loss for fp16.") + parser.add_argument( + "--use_dynamic_loss_scaling", type=lambda x: x.lower() == 'true', default=True, help="Whether to use dynamic loss scaling.") args = parser.parse_args() return args @@ -145,6 +169,17 @@ def create_data_holder(): return [input_ids, segment_ids, label] +def get_gpu_num(): + visible_device = os.getenv('CUDA_VISIBLE_DEVICES') + if visible_device: + device_num = len(visible_device.split(',')) + else: + device_num = subprocess.check_output( + [str.encode('nvidia-smi'), str.encode('-L')]).decode('utf-8').count( + '\n') + return device_num + + 
def reset_program_state_dict(model, state_dict, pretrained_state_dict): reset_state_dict = {} scale = model.initializer_range if hasattr(model, "initializer_range")\ @@ -278,9 +313,12 @@ def do_train(args): train_dataset = SimpleDataset(train_dataset).apply(trans_func, lazy=True) + use_tensor_core = False + if args.use_amp or args.use_pure_fp16: + use_tensor_core = True batchify_fn = lambda samples, fn=Tuple( - Pad(axis=0, pad_val=tokenizer.pad_token_id), # input - Pad(axis=0, pad_val=tokenizer.pad_token_id), # segment + Pad(axis=0, pad_val=tokenizer.pad_token_id, use_tensor_core=use_tensor_core), # input + Pad(axis=0, pad_val=tokenizer.pad_token_id, use_tensor_core=use_tensor_core), # segment Stack(dtype="int64" if train_dataset.get_labels() else "float32") # label ): [data for i, data in enumerate(fn(samples))] @@ -321,7 +359,9 @@ def do_train(args): num_classes=len(train_dataset.get_labels())) loss_fct = paddle.nn.loss.CrossEntropyLoss( ) if train_dataset.get_labels() else paddle.nn.loss.MSELoss() - logits = model(input_ids, segment_ids) + logits = model(input_ids, segment_ids, use_pure_fp16=args.use_pure_fp16) + if args.use_pure_fp16: + logits = fluid.layers.cast(logits, "float32") loss = loss_fct(logits, labels) dev_program = main_program.clone(for_test=True) @@ -338,19 +378,24 @@ def do_train(args): 0.0, float(num_training_steps - current_step) / float( max(1, num_training_steps - num_warmup_steps)))) - optimizer = paddle.optimizer.AdamW( + rescale_grad = 1.0 / (args.batch_size / get_gpu_num()) + optimizer = fluid.contrib.optimizer.Momentum( learning_rate=lr_scheduler, - epsilon=args.adam_epsilon, - parameters=model.parameters(), - weight_decay=args.weight_decay, - apply_decay_param_fun=lambda x: x in [ - p.name for n, p in model.named_parameters() - if not any(nd in n for nd in ["bias", "norm"]) - ]) + momentum=args.momentum_rate, + regularization=fluid.regularizer.L2Decay(args.l2_decay), + multi_precision=args.multi_precision, + rescale_grad=rescale_grad) + 
if args.use_amp: + optimizer = paddle.fluid.contrib.mixed_precision.decorate( + optimizer, + init_loss_scaling=args.scale_loss, + use_dynamic_loss_scaling=args.use_dynamic_loss_scaling) optimizer.minimize(loss) # Create the metric pass for the validation with paddle.static.program_guard(dev_program, startup_program): + if args.use_amp: + logits = paddle.cast(logits, 'float32') metric = metric_class() correct = metric.compute(logits, labels) @@ -364,6 +409,17 @@ def do_train(args): pretrained_state_dict) paddle.static.set_program_state(main_program, reset_state_dict) + exec_strategy = fluid.ExecutionStrategy() + exec_strategy.num_threads = 1 + exec_strategy.num_iteration_per_drop_scope = 10000 + + build_strategy = fluid.BuildStrategy() + + main_program = fluid.CompiledProgram(main_program).with_data_parallel( + loss_name=loss.name, + exec_strategy=exec_strategy, + build_strategy=build_strategy) + global_step = 0 tic_train = time.time() for epoch in range(args.num_train_epochs): diff --git a/benchmark/bert/run_glue_amp.sh b/benchmark/bert/run_glue_amp.sh new file mode 100755 index 0000000..eb6a79a --- /dev/null +++ b/benchmark/bert/run_glue_amp.sh @@ -0,0 +1,17 @@ +export CUDA_VISIBLE_DEVICES=0 +export TASK_NAME=SST-2 + +python -u ./run_glue.py \ + --model_type bert \ + --model_name_or_path bert-base-uncased \ + --task_name $TASK_NAME \ + --max_seq_length 128 \ + --batch_size 64 \ + --learning_rate 2e-5 \ + --num_train_epochs 3 \ + --logging_steps 100 \ + --save_steps 500 \ + --output_dir ./tmp/$TASK_NAME/ \ + --use_amp true \ + --scale_loss 128.0 \ + --use_dynamic_loss_scaling true \ diff --git a/benchmark/bert/run_glue_fp32.sh b/benchmark/bert/run_glue_fp32.sh new file mode 100755 index 0000000..6cb5988 --- /dev/null +++ b/benchmark/bert/run_glue_fp32.sh @@ -0,0 +1,15 @@ +export CUDA_VISIBLE_DEVICES=0 +export TASK_NAME=SST-2 + +python -u ./run_glue.py \ + --model_type bert \ + --model_name_or_path bert-base-uncased \ + --task_name $TASK_NAME \ + --max_seq_length 
128 \ + --batch_size 64 \ + --learning_rate 2e-5 \ + --num_train_epochs 3 \ + --logging_steps 100 \ + --save_steps 500 \ + --output_dir ./tmp/$TASK_NAME/ \ + diff --git a/benchmark/bert/run_glue_pure_fp16.sh b/benchmark/bert/run_glue_pure_fp16.sh new file mode 100755 index 0000000..b368e2c --- /dev/null +++ b/benchmark/bert/run_glue_pure_fp16.sh @@ -0,0 +1,16 @@ +export CUDA_VISIBLE_DEVICES=0 +export TASK_NAME=SST-2 + +python -u ./run_glue.py \ + --model_type bert \ + --model_name_or_path bert-base-uncased \ + --task_name $TASK_NAME \ + --max_seq_length 128 \ + --batch_size 64 \ + --learning_rate 2e-5 \ + --num_train_epochs 3 \ + --logging_steps 100 \ + --save_steps 500 \ + --output_dir ./tmp/$TASK_NAME/ \ + --use_pure_fp16 true \ + --multi_precision true diff --git a/paddlenlp/data/batchify.py b/paddlenlp/data/batchify.py index 32dd3dd..faf2c1e 100644 --- a/paddlenlp/data/batchify.py +++ b/paddlenlp/data/batchify.py @@ -92,11 +92,12 @@ class Pad(object): [8. 2. 0. 0.]] ''' """ - def __init__(self, pad_val=0, axis=0, ret_length=None, dtype=None): + def __init__(self, pad_val=0, axis=0, ret_length=None, dtype=None, use_tensor_core=False): self._pad_val = pad_val self._axis = axis self._ret_length = ret_length self._dtype = dtype + self._use_tensor_core = use_tensor_core def __call__(self, data): """ @@ -116,6 +117,8 @@ def __call__(self, data): arrs = [np.asarray(ele) for ele in data] original_length = [ele.shape[self._axis] for ele in arrs] max_size = max(original_length) + if self._use_tensor_core and max_size % 8 != 0: + max_size = (int(max_size / 8) + 1) * 8 ret_shape = list(arrs[0].shape) ret_shape[self._axis] = max_size ret_shape = (len(arrs), ) + tuple(ret_shape) diff --git a/paddlenlp/transformers/bert/modeling.py b/paddlenlp/transformers/bert/modeling.py index e7a7e6d..ea8fbd5 100644 --- a/paddlenlp/transformers/bert/modeling.py +++ b/paddlenlp/transformers/bert/modeling.py @@ -271,7 +271,8 @@ def forward(self, input_ids, token_type_ids=None, 
position_ids=None, - attention_mask=None): + attention_mask=None, + use_pure_fp16=False): if attention_mask is None: attention_mask = paddle.unsqueeze( (input_ids == self.pad_token_id @@ -281,6 +282,9 @@ def forward(self, input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids) + if use_pure_fp16: + attention_mask = paddle.cast(attention_mask, "float16") + embedding_output = paddle.cast(embedding_output, "float16") encoder_outputs = self.encoder(embedding_output, attention_mask) sequence_output = encoder_outputs pooled_output = self.pooler(sequence_output) @@ -333,12 +337,14 @@ def forward(self, input_ids, token_type_ids=None, position_ids=None, - attention_mask=None): + attention_mask=None, + use_pure_fp16=False): _, pooled_output = self.bert( input_ids, token_type_ids=token_type_ids, position_ids=position_ids, - attention_mask=attention_mask) + attention_mask=attention_mask, + use_pure_fp16=use_pure_fp16) pooled_output = self.dropout(pooled_output) logits = self.classifier(pooled_output)