From 85c24e418432a75d6e61806223a95219e536c66b Mon Sep 17 00:00:00 2001
From: pangyoki
Date: Wed, 25 Nov 2020 08:53:30 +0000
Subject: [PATCH 1/3] add AMP for BERT fine-tuning

---
 benchmark/bert/run_glue.py     | 23 +++++++++++++++++++++++
 benchmark/bert/run_glue_amp.sh | 17 +++++++++++++++++
 paddlenlp/data/batchify.py     | 11 +++++++++--
 3 files changed, 49 insertions(+), 2 deletions(-)
 create mode 100755 benchmark/bert/run_glue_amp.sh

diff --git a/benchmark/bert/run_glue.py b/benchmark/bert/run_glue.py
index 241d36b..4a49acb 100644
--- a/benchmark/bert/run_glue.py
+++ b/benchmark/bert/run_glue.py
@@ -27,6 +27,8 @@
 from paddlenlp.data.batchify import Stack, Tuple, Pad
 from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer
 
+import paddle.fluid as fluid
+
 FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
 logging.basicConfig(level=logging.INFO, format=FORMAT)
 logger = logging.getLogger(__name__)
@@ -131,6 +133,21 @@ def parse_args():
         help="Save checkpoint every X updates steps.")
     parser.add_argument(
         "--seed", type=int, default=42, help="Random seed for initialization")
+    parser.add_argument(
+        "--use_fp16",
+        type=bool,
+        default=False,
+        help="Whether to enable half precision training with fp16.")
+    parser.add_argument(
+        "--scale_loss",
+        type=float,
+        default=1.0,
+        help="The initial loss scaling value for fp16.")
+    parser.add_argument(
+        "--use_dynamic_loss_scaling",
+        type=bool,
+        default=True,
+        help="Whether to use dynamic loss scaling.")
     args = parser.parse_args()
     return args
 
@@ -347,10 +364,16 @@ def do_train(args):
         p.name for n, p in model.named_parameters()
         if not any(nd in n for nd in ["bias", "norm"])
     ])
+    if args.use_fp16:
+        optimizer = paddle.fluid.contrib.mixed_precision.decorate(
+            optimizer,
+            init_loss_scaling=args.scale_loss,
+            use_dynamic_loss_scaling=args.use_dynamic_loss_scaling)
     optimizer.minimize(loss)
 
     # Create the metric pass for the validation
     with paddle.static.program_guard(dev_program, startup_program):
+        logits = paddle.fluid.layers.cast(logits, 'float32')
         metric = metric_class()
         correct = metric.compute(logits, labels)
 
diff --git a/benchmark/bert/run_glue_amp.sh b/benchmark/bert/run_glue_amp.sh
new file mode 100755
index 0000000..1c70334
--- /dev/null
+++ b/benchmark/bert/run_glue_amp.sh
@@ -0,0 +1,17 @@
+export CUDA_VISIBLE_DEVICES=0
+export TASK_NAME=SST-2
+
+python -u ./run_glue.py \
+    --model_type bert \
+    --model_name_or_path bert-base-uncased \
+    --task_name $TASK_NAME \
+    --max_seq_length 128 \
+    --batch_size 64 \
+    --learning_rate 2e-5 \
+    --num_train_epochs 3 \
+    --logging_steps 20 \
+    --save_steps 500 \
+    --output_dir ./tmp/$TASK_NAME/ \
+    --use_fp16=true \
+    --scale_loss=128.0 \
+    --use_dynamic_loss_scaling=true \

diff --git a/paddlenlp/data/batchify.py b/paddlenlp/data/batchify.py
index 32dd3dd..b51ef48 100644
--- a/paddlenlp/data/batchify.py
+++ b/paddlenlp/data/batchify.py
@@ -44,6 +44,7 @@ class Stack(object):
              [8 9 1 2]]
         '''
     """
+
     def __init__(self, axis=0, dtype=None):
         self._axis = axis
         self._dtype = dtype
@@ -56,8 +57,10 @@ def __call__(self, data):
         Returns:
             numpy.ndarray: Stacked batch data.
         """
-        data = np.stack(data, axis=self._axis).astype(
-            self._dtype) if self._dtype else np.stack(data, axis=self._axis)
+        data = np.stack(
+            data,
+            axis=self._axis).astype(self._dtype) if self._dtype else np.stack(
+                data, axis=self._axis)
         return data
 
 
@@ -92,6 +95,7 @@ class Pad(object):
              [8. 2. 0. 0.]]
         '''
     """
+
    def __init__(self, pad_val=0, axis=0, ret_length=None, dtype=None):
        self._pad_val = pad_val
        self._axis = axis
@@ -116,6 +120,8 @@ def __call__(self, data):
         arrs = [np.asarray(ele) for ele in data]
         original_length = [ele.shape[self._axis] for ele in arrs]
         max_size = max(original_length)
+        if max_size % 8 != 0:
+            max_size = (int(max_size / 8) + 1) * 8
         ret_shape = list(arrs[0].shape)
         ret_shape[self._axis] = max_size
         ret_shape = (len(arrs), ) + tuple(ret_shape)
@@ -160,6 +166,7 @@ class Tuple(object):
         from paddle.incubate.hapi.text.data_utils import Tuple, Pad, Stack
         batchify_fn = Tuple(Pad(axis=0, pad_val=0), Stack())
     """
+
     def __init__(self, fn, *args):
         if isinstance(fn, (list, tuple)):
             assert len(args) == 0, 'Input pattern not understood. The input of Tuple can be ' \
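Note on [PATCH 1/3]: the Pad change rounds each batch's maximum sequence length up to the next multiple of 8. The patch does not say why, but this is presumably for fp16 throughput: cuBLAS only engages Tensor Cores when the relevant GEMM dimensions are multiples of 8. A standalone sketch of the same rounding, numpy only (the helper name is illustrative, not part of paddlenlp):

    import numpy as np

    def pad_to_multiple_of_8(seqs, pad_val=0):
        # Same rounding as the max_size computation added to Pad.__call__.
        max_size = max(len(s) for s in seqs)
        if max_size % 8 != 0:
            max_size = (max_size // 8 + 1) * 8
        out = np.full((len(seqs), max_size), pad_val, dtype='int64')
        for i, seq in enumerate(seqs):
            out[i, :len(seq)] = seq
        return out

    print(pad_to_multiple_of_8([[1, 2, 3], [4, 5, 6, 7, 8, 9]]).shape)  # (2, 8)

One caveat in the new arguments: argparse's type=bool converts any non-empty string to True, so --use_fp16 false would still enable fp16; the scripts in this series only ever pass true.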
From a57286187a02e0132390ab4f4fda9fd3a6be3fc8 Mon Sep 17 00:00:00 2001
From: pangyoki
Date: Wed, 25 Nov 2020 12:02:55 +0000
Subject: [PATCH 2/3] feed fixed data for benchmarking

---
 benchmark/bert/run_glue.py     | 42 ++++++++++++++++++++++++++++++++++++------
 benchmark/bert/run_glue_amp.sh |  6 +++---
 2 files changed, 39 insertions(+), 9 deletions(-)

diff --git a/benchmark/bert/run_glue.py b/benchmark/bert/run_glue.py
index 4a49acb..483c832 100644
--- a/benchmark/bert/run_glue.py
+++ b/benchmark/bert/run_glue.py
@@ -28,6 +28,7 @@
 from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer
 
 import paddle.fluid as fluid
+from paddle.fluid import profiler
 
 FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
 logging.basicConfig(level=logging.INFO, format=FORMAT)
@@ -364,11 +365,10 @@ def do_train(args):
         p.name for n, p in model.named_parameters()
         if not any(nd in n for nd in ["bias", "norm"])
     ])
-    if args.use_fp16:
-        optimizer = paddle.fluid.contrib.mixed_precision.decorate(
-            optimizer,
-            init_loss_scaling=args.scale_loss,
-            use_dynamic_loss_scaling=args.use_dynamic_loss_scaling)
+    optimizer = paddle.fluid.contrib.mixed_precision.decorate(
+        optimizer,
+        init_loss_scaling=args.scale_loss,
+        use_dynamic_loss_scaling=args.use_dynamic_loss_scaling)
     optimizer.minimize(loss)
 
     # Create the metric pass for the validation
@@ -388,10 +388,36 @@ def do_train(args):
         paddle.static.set_program_state(main_program, reset_state_dict)
 
     global_step = 0
+
+    # construct fixed feeding data
+    seq_len = 48
+    label = np.random.randint(0, 2, [64, 1])  # batch_size = 64
+    input_ids = np.random.randint(0, 11000, [64, seq_len])
+    segment_ids = np.zeros([64, seq_len]).astype('int64')
+    batch = [{
+        'label': label,
+        'input_ids': input_ids,
+        'segment_ids': segment_ids
+    }]
+
     tic_train = time.time()
     for epoch in range(args.num_train_epochs):
-        for step, batch in enumerate(train_data_loader):
+        # remove data_loader
+        # for step, batch in enumerate(train_data_loader):
+        for step in range(1000):
             global_step += 1
+
+            # profiler
+            """
+            if step == 200:
+                # profiler.start_profiler("All")
+                fluid.core.nvprof_start()
+            if step == 210:
+                fluid.core.nvprof_stop()
+                # profiler.stop_profiler("total", "./profile")
+                return
+            """
+
             loss_return = exe.run(main_program, feed=batch, fetch_list=[loss])
             if global_step % args.logging_steps == 0:
                 logger.info(
@@ -400,6 +426,9 @@ def do_train(args):
                     % (global_step, epoch, step, loss_return[0],
                        args.logging_steps / (time.time() - tic_train)))
                 tic_train = time.time()
             lr_scheduler.step()
+
+            # skip evaluation and parameter saving
+            """
             if global_step % args.save_steps == 0:
                 # Validation pass, record the loss and metric
                 evaluate(exe, metric, loss, correct, dev_program,
@@ -410,6 +439,7 @@ def do_train(args):
                     os.makedirs(output_dir)
                 paddle.fluid.io.save_params(exe, output_dir)
                 tokenizer.save_pretrained(output_dir)
+            """
 
 
 if __name__ == "__main__":

diff --git a/benchmark/bert/run_glue_amp.sh b/benchmark/bert/run_glue_amp.sh
index 1c70334..2a4fdd5 100755
--- a/benchmark/bert/run_glue_amp.sh
+++ b/benchmark/bert/run_glue_amp.sh
@@ -12,6 +12,6 @@ python -u ./run_glue.py \
     --logging_steps 20 \
     --save_steps 500 \
     --output_dir ./tmp/$TASK_NAME/ \
-    --use_fp16=true \
-    --scale_loss=128.0 \
-    --use_dynamic_loss_scaling=true \
+    --use_fp16 true \
+    --scale_loss 128.0 \
+    --use_dynamic_loss_scaling true \
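Note on [PATCH 2/3]: the loop now feeds one fixed, randomly generated batch every step, so the reported steps/s measures pure model compute with no input-pipeline noise; the commented-out fluid.core.nvprof_start()/nvprof_stop() pair brackets steps 200-210 so an external nvprof run captures only steady-state iterations. A standalone sketch of the fixed batch, with shapes and value ranges taken from the patch (numpy only; the feed keys must match the program's input variable names):

    import numpy as np

    batch_size, seq_len, vocab_size, num_classes = 64, 48, 11000, 2
    batch = [{
        'label': np.random.randint(0, num_classes, [batch_size, 1]),
        'input_ids': np.random.randint(0, vocab_size, [batch_size, seq_len]),
        'segment_ids': np.zeros([batch_size, seq_len], dtype='int64'),
    }]
    # loss_return = exe.run(main_program, feed=batch, fetch_list=[loss])  # as in the patch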
From 0ccf193973dd152b546843bac53afb3a7d9bfe2d Mon Sep 17 00:00:00 2001
From: pangyoki
Date: Thu, 26 Nov 2020 08:38:37 +0000
Subject: [PATCH 3/3] use ParallelExecutor (PE)

---
 benchmark/bert/run_glue.py | 44 ++++++++++++++++----------------------------
 1 file changed, 16 insertions(+), 28 deletions(-)

diff --git a/benchmark/bert/run_glue.py b/benchmark/bert/run_glue.py
index 483c832..360bd3d 100644
--- a/benchmark/bert/run_glue.py
+++ b/benchmark/bert/run_glue.py
@@ -28,7 +28,6 @@
 from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer
 
 import paddle.fluid as fluid
-from paddle.fluid import profiler
 
 FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
 logging.basicConfig(level=logging.INFO, format=FORMAT)
@@ -365,10 +364,11 @@ def do_train(args):
         p.name for n, p in model.named_parameters()
         if not any(nd in n for nd in ["bias", "norm"])
     ])
-    optimizer = paddle.fluid.contrib.mixed_precision.decorate(
-        optimizer,
-        init_loss_scaling=args.scale_loss,
-        use_dynamic_loss_scaling=args.use_dynamic_loss_scaling)
+    if args.use_fp16:
+        optimizer = paddle.fluid.contrib.mixed_precision.decorate(
+            optimizer,
+            init_loss_scaling=args.scale_loss,
+            use_dynamic_loss_scaling=args.use_dynamic_loss_scaling)
     optimizer.minimize(loss)
 
     # Create the metric pass for the validation
@@ -387,28 +387,22 @@ def do_train(args):
                                          pretrained_state_dict)
         paddle.static.set_program_state(main_program, reset_state_dict)
 
-    global_step = 0
+    exec_strategy = fluid.ExecutionStrategy()
+    exec_strategy.num_threads = 1
+    exec_strategy.num_iteration_per_drop_scope = 10000
 
-    # construct fixed feeding data
-    seq_len = 48
-    label = np.random.randint(0, 2, [64, 1])  # batch_size = 64
-    input_ids = np.random.randint(0, 11000, [64, seq_len])
-    segment_ids = np.zeros([64, seq_len]).astype('int64')
-    batch = [{
-        'label': label,
-        'input_ids': input_ids,
-        'segment_ids': segment_ids
-    }]
+    build_strategy = fluid.BuildStrategy()
 
+    main_program = fluid.CompiledProgram(main_program).with_data_parallel(
+        loss_name=loss.name,
+        exec_strategy=exec_strategy,
+        build_strategy=build_strategy)
+
+    global_step = 0
     tic_train = time.time()
     for epoch in range(args.num_train_epochs):
-        # remove data_loader
-        # for step, batch in enumerate(train_data_loader):
-        for step in range(1000):
+        for step, batch in enumerate(train_data_loader):
             global_step += 1
-
-            # profiler
-            """
             if step == 200:
                 # profiler.start_profiler("All")
                 fluid.core.nvprof_start()
@@ -415,9 +409,7 @@ def do_train(args):
             if step == 210:
                 fluid.core.nvprof_stop()
                 # profiler.stop_profiler("total", "./profile")
                 return
-            """
-
             loss_return = exe.run(main_program, feed=batch, fetch_list=[loss])
             if global_step % args.logging_steps == 0:
                 logger.info(
@@ -426,9 +418,6 @@ def do_train(args):
                     % (global_step, epoch, step, loss_return[0],
                        args.logging_steps / (time.time() - tic_train)))
                 tic_train = time.time()
             lr_scheduler.step()
-
-            # skip evaluation and parameter saving
-            """
             if global_step % args.save_steps == 0:
                 # Validation pass, record the loss and metric
                 evaluate(exe, metric, loss, correct, dev_program,
@@ -439,7 +428,6 @@ def do_train(args):
                     os.makedirs(output_dir)
                 paddle.fluid.io.save_params(exe, output_dir)
                 tokenizer.save_pretrained(output_dir)
-            """
 
 
 if __name__ == "__main__":
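Note on [PATCH 3/3]: the series ends with AMP gated behind --use_fp16 again, the real DataLoader restored, and the program compiled with data parallelism (PE = ParallelExecutor); exec_strategy.num_iteration_per_drop_scope = 10000 makes the executor keep its local scopes alive for many iterations instead of cleaning up every step, which reduces overhead while benchmarking. For reference, a toy sketch of the policy that use_dynamic_loss_scaling selects in mixed_precision.decorate (illustrative only, not Paddle's implementation): the loss is multiplied by a scale before backward so small fp16 gradients don't flush to zero, the scale is cut when an overflow is detected, and grown again after a long run of clean steps.

    def update_loss_scale(scale, good_steps, found_inf,
                          growth_interval=1000, factor=2.0):
        # Toy dynamic loss scaling update, not Paddle's API.
        if found_inf:
            return scale / factor, 0   # overflow: shrink the scale, reset streak
        good_steps += 1
        if good_steps >= growth_interval:
            return scale * factor, 0   # long clean streak: try a larger scale
        return scale, good_steps

Here --scale_loss 128.0 from run_glue_amp.sh seeds the initial scale (init_loss_scaling), and the dynamic policy adjusts it from there during training.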