From 092492c107e71e0759b9add758431016c092b619 Mon Sep 17 00:00:00 2001 From: Hanzhi Zhou Date: Thu, 10 Jul 2025 20:46:26 -0700 Subject: [PATCH] Add fuji 150b test --- .../fuji-150B-v2-flash-fp8.txt | 313 ++++++++++++++++++ .../fuji-150B-v2-flash-fp8_init.txt | 10 + .../fuji-150B-v2-flash-fp8_regularizer.txt | 11 + .../fuji-150B-v2-flash.txt | 313 ++++++++++++++++++ .../fuji-150B-v2-flash_init.txt | 10 + .../fuji-150B-v2-flash_regularizer.txt | 11 + .../fuji-150B-v2-fp8.txt | 280 ++++++++++++++++ .../fuji-150B-v2-fp8_init.txt | 10 + .../fuji-150B-v2-fp8_regularizer.txt | 11 + .../fuji-150B-v2.txt | 280 ++++++++++++++++ .../fuji-150B-v2_init.txt | 10 + .../fuji-150B-v2_regularizer.txt | 11 + .../fuji-150B-v3-flash-fp8.txt | 313 ++++++++++++++++++ .../fuji-150B-v3-flash-fp8_init.txt | 10 + .../fuji-150B-v3-flash-fp8_regularizer.txt | 11 + .../fuji-150B-v3-flash.txt | 313 ++++++++++++++++++ .../fuji-150B-v3-flash_init.txt | 10 + .../fuji-150B-v3-flash_regularizer.txt | 11 + .../fuji-150B-v3-fp8.txt | 280 ++++++++++++++++ .../fuji-150B-v3-fp8_init.txt | 10 + .../fuji-150B-v3-fp8_regularizer.txt | 11 + .../fuji-150B-v3-tiktoken-flash-fp8.txt | 313 ++++++++++++++++++ .../fuji-150B-v3-tiktoken-flash-fp8_init.txt | 10 + ...150B-v3-tiktoken-flash-fp8_regularizer.txt | 11 + .../fuji-150B-v3-tiktoken-flash.txt | 313 ++++++++++++++++++ .../fuji-150B-v3-tiktoken-flash_init.txt | 10 + ...uji-150B-v3-tiktoken-flash_regularizer.txt | 11 + .../fuji-150B-v3-tiktoken-fp8.txt | 280 ++++++++++++++++ .../fuji-150B-v3-tiktoken-fp8_init.txt | 10 + .../fuji-150B-v3-tiktoken-fp8_regularizer.txt | 11 + .../fuji-150B-v3-tiktoken.txt | 280 ++++++++++++++++ .../fuji-150B-v3-tiktoken_init.txt | 10 + .../fuji-150B-v3-tiktoken_regularizer.txt | 11 + .../fuji-150B-v3.txt | 280 ++++++++++++++++ .../fuji-150B-v3_init.txt | 10 + .../fuji-150B-v3_regularizer.txt | 11 + axlearn/experiments/text/gpt/fuji.py | 53 ++- patches/shard_map.py.patch | 27 ++ 38 files changed, 3889 insertions(+), 1 deletion(-) create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-flash-fp8.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-flash-fp8_init.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-flash-fp8_regularizer.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-flash.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-flash_init.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-flash_regularizer.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-fp8.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-fp8_init.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-fp8_regularizer.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2_init.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2_regularizer.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-flash-fp8.txt create mode 100644 
axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-flash-fp8_init.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-flash-fp8_regularizer.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-flash.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-flash_init.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-flash_regularizer.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-fp8.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-fp8_init.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-fp8_regularizer.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-flash-fp8.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-flash-fp8_init.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-flash-fp8_regularizer.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-flash.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-flash_init.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-flash_regularizer.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-fp8.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-fp8_init.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-fp8_regularizer.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken_init.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken_regularizer.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3_init.txt create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3_regularizer.txt create mode 100644 patches/shard_map.py.patch diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-flash-fp8.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-flash-fp8.txt new file mode 100644 index 000000000..30aaebc9c --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-flash-fp8.txt @@ -0,0 +1,313 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 
'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 524288 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 524288 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.input_dispatcher.global_logical_batch_size: 1024 +evalers['train'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 4096 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['train'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 524288 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.input_dispatcher.global_logical_batch_size: 1024 +evalers['validation'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 4096 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' 
+evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['validation'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.input_dispatcher.global_logical_batch_size: 1024 +input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][1]: 'seq' +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 4096 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.00015 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 524288 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 
+learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 524288 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'tpu-v6e-256.*' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 64 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.common.utils.policy' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier' +mesh_rules[0][1].config_modifiers[2].tpu_block_size: 1024 +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 64 +mesh_shape[4]: 1 +mesh_shape[5]: 4 +model.batch_axis_names: None +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 12288 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 1 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' +model.decoder.lm_head.param_partition_spec[0]: None +model.decoder.lm_head.param_partition_spec[1]: 'model' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 0 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' 
+model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256 +model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' 
+model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 10000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention' +model.decoder.transformer.layer.self_attention.attention.kv_cache.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.kv_cache.klass: 'axlearn.common.attention.KVCache' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq' 
+model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.num_heads: 96 +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.attention.tpu_block_size: 512 +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 
'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 80 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 32768 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.metrics.klass: 'axlearn.common.causal_lm.CompositeLossMetrics' +model.metrics.metrics['lm'].klass: 'axlearn.common.causal_lm.CrossEntropyLossMetrics' +model.metrics.metrics['aux'].klass: 'axlearn.common.causal_lm.AuxLossMetrics' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-flash-fp8_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-flash-fp8_init.txt new file mode 100644 index 000000000..d3f162fe3 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-flash-fp8_init.txt @@ -0,0 +1,10 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 12288], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 112, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 96, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(43008, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) +decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(32768, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-flash-fp8_regularizer.txt 
b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-flash-fp8_regularizer.txt new file mode 100644 index 000000000..65733fb7f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-flash-fp8_regularizer.txt @@ -0,0 +1,11 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/lm_head/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-flash.txt new file mode 100644 index 000000000..30aaebc9c --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-flash.txt @@ -0,0 +1,313 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 524288 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 524288 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.input_dispatcher.global_logical_batch_size: 1024 +evalers['train'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 4096 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['train'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 
'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 524288 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.input_dispatcher.global_logical_batch_size: 1024 +evalers['validation'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 4096 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['validation'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.input_dispatcher.global_logical_batch_size: 1024 +input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][1]: 'seq' +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' 
+input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 4096 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.00015 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 524288 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 524288 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'tpu-v6e-256.*' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 64 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.common.utils.policy' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier' +mesh_rules[0][1].config_modifiers[2].tpu_block_size: 1024 +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 64 +mesh_shape[4]: 1 +mesh_shape[5]: 4 +model.batch_axis_names: None +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 12288 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 
'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 1 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' +model.decoder.lm_head.param_partition_spec[0]: None +model.decoder.lm_head.param_partition_spec[1]: 'model' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 0 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256 +model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' 
+model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 10000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False 
+model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention' +model.decoder.transformer.layer.self_attention.attention.kv_cache.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.kv_cache.klass: 'axlearn.common.attention.KVCache' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.num_heads: 96 +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert' 
+model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.attention.tpu_block_size: 512 +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 80 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 32768 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.metrics.klass: 'axlearn.common.causal_lm.CompositeLossMetrics' +model.metrics.metrics['lm'].klass: 'axlearn.common.causal_lm.CrossEntropyLossMetrics' +model.metrics.metrics['aux'].klass: 'axlearn.common.causal_lm.AuxLossMetrics' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-flash_init.txt 
b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-flash_init.txt new file mode 100644 index 000000000..d3f162fe3 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-flash_init.txt @@ -0,0 +1,10 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 12288], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 112, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 96, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(43008, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) +decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(32768, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-flash_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-flash_regularizer.txt new file mode 100644 index 000000000..65733fb7f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-flash_regularizer.txt @@ -0,0 +1,11 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/lm_head/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-fp8.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-fp8.txt new file mode 100644 index 000000000..95e954a3b --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-fp8.txt @@ -0,0 +1,280 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 524288 +checkpointer.save_policy.min_step: 1 
+checkpointer.save_policy.n: 5000
+checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage'
+checkpointer.storage.timeout_secs: 3600
+evalers['train'].eval_dtype: 'jax.numpy.bfloat16'
+evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy'
+evalers['train'].eval_policy.max_step: 524288
+evalers['train'].eval_policy.min_step: 1
+evalers['train'].eval_policy.n: 5000
+evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch'
+evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn'
+evalers['train'].input.batcher.prefetch_buffer_size: -1
+evalers['train'].input.input_dispatcher.global_logical_batch_size: 1024
+evalers['train'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher'
+evalers['train'].input.is_training: False
+evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input'
+evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity'
+evalers['train'].input.source.dataset_name: 'c4/en:3.0.1'
+evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input'
+evalers['train'].input.source.is_training: False
+evalers['train'].input.source.max_sequence_length: 4096
+evalers['train'].input.source.replace_newlines_with: '\n'
+evalers['train'].input.source.split: 'train[:8192]'
+evalers['train'].input.source.train_shuffle_buffer_size: 16384
+evalers['train'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab'
+evalers['train'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model'
+evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler'
+evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator'
+evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator'
+evalers['train'].metric_calculator.model_method: 'forward'
+evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter'
+evalers['train'].summary_writer.write_every_n_steps: 1
+evalers['validation'].eval_dtype: 'jax.numpy.bfloat16'
+evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy'
+evalers['validation'].eval_policy.max_step: 524288
+evalers['validation'].eval_policy.min_step: 1
+evalers['validation'].eval_policy.n: 5000
+evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch'
+evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn'
+evalers['validation'].input.batcher.prefetch_buffer_size: -1
+evalers['validation'].input.input_dispatcher.global_logical_batch_size: 1024
+evalers['validation'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher'
+evalers['validation'].input.is_training: False
+evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input'
+evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity'
+evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1'
+evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input'
+evalers['validation'].input.source.is_training: False
+evalers['validation'].input.source.max_sequence_length: 4096
+evalers['validation'].input.source.replace_newlines_with: '\n'
+evalers['validation'].input.source.split: 'validation'
+evalers['validation'].input.source.train_shuffle_buffer_size: 16384
+evalers['validation'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab'
+evalers['validation'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model'
+evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler'
+evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator'
+evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator'
+evalers['validation'].metric_calculator.model_method: 'forward'
+evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter'
+evalers['validation'].summary_writer.write_every_n_steps: 1
+input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch'
+input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn'
+input.batcher.prefetch_buffer_size: -1
+input.input_dispatcher.global_logical_batch_size: 1024
+input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher'
+input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank'
+input.input_partitioner.path_rank_to_partition[(None, 1)][0][0]: 'data'
+input.input_partitioner.path_rank_to_partition[(None, 1)][0][1]: 'expert'
+input.input_partitioner.path_rank_to_partition[(None, 1)][0][2]: 'fsdp'
+input.input_partitioner.path_rank_to_partition[(None, 2)][0][0]: 'data'
+input.input_partitioner.path_rank_to_partition[(None, 2)][0][1]: 'expert'
+input.input_partitioner.path_rank_to_partition[(None, 2)][0][2]: 'fsdp'
+input.input_partitioner.path_rank_to_partition[(None, 2)][1]: 'seq'
+input.is_training: True
+input.klass: 'axlearn.common.input_tf_data.Input'
+input.processor.fn: 'axlearn.common.input_tf_data.identity'
+input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1'
+input.source.data_mixture_components[0]['weight']: 1.0
+input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192
+input.source.data_mixture_components[0]['split']: 'train'
+input.source.data_mixture_components[0]['info']: ''
+input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source'
+input.source.max_sequence_length: 4096
+input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor'
+input.source.preprocessor.max_padding_fraction: 0.5
+input.source.preprocessor.shuffle_buffer_size: 8192
+input.source.preprocessor.window_size: 128
+input.source.replace_newlines_with: ''
+input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab'
+input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model'
+klass: 'axlearn.common.trainer.SpmdTrainer'
+learner.ema.fn: 'axlearn.common.optimizers.param_ema'
+learner.enable_per_variable_summaries: False
+learner.klass: 'axlearn.common.learner.Learner'
+learner.optimizer.args[0].eps: 1e-08
+learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm'
+learner.optimizer.args[0].max_norm: 1
+learner.optimizer.args[1].b1: 0.9
+learner.optimizer.args[1].b2: 0.95
+learner.optimizer.args[1].eps: 1e-08
+learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer'
+learner.optimizer.args[1].learning_rate: 0.00015
+learner.optimizer.args[1].update_schedule.alpha: 0.1
+learner.optimizer.args[1].update_schedule.begin_value: 0.0
+learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup'
+learner.optimizer.args[1].update_schedule.max_step: 524288
+learner.optimizer.args[1].update_schedule.peak_lr: 1.0
+learner.optimizer.args[1].update_schedule.warmup_steps: 2000
+learner.optimizer.args[1].weight_decay: 0.1
+learner.optimizer.fn: 'axlearn.common.optimizers.chain'
+max_step: 524288
+mesh_axis_names[0]: 'pipeline'
+mesh_axis_names[1]: 'data'
+mesh_axis_names[2]: 'expert'
+mesh_axis_names[3]: 'fsdp'
+mesh_axis_names[4]: 'seq'
+mesh_axis_names[5]: 'model'
+mesh_rules[0][0]: 'tpu-v6e-256.*'
+mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1
+mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 64
+mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
+mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
+mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.common.utils.policy'
+mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
+mesh_rules[0][1].config_modifiers[2].tpu_block_size: 1024
+mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
+mesh_shape[0]: 1
+mesh_shape[1]: -1
+mesh_shape[2]: 1
+mesh_shape[3]: 64
+mesh_shape[4]: 1
+mesh_shape[5]: 4
+model.batch_axis_names: None
+model.decoder.attention_mask: None
+model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer'
+model.decoder.dim: 12288
+model.decoder.dropout_rate: 0.0
+model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout'
+model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings'
+model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding'
+model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal'
+model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out'
+model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer'
+model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0
+model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
+model.decoder.emb.token_emb.param_partition_spec[0]: None
+model.decoder.emb.token_emb.param_partition_spec[1]: 'model'
+model.decoder.eos_token_id: 1
+model.decoder.klass: 'axlearn.common.decoder.Decoder'
+model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead'
+model.decoder.lm_head.param_partition_spec[0]: None
+model.decoder.lm_head.param_partition_spec[1]: 'model'
+model.decoder.logits_partition_spec[0][0]: 'data'
+model.decoder.logits_partition_spec[0][1]: 'expert'
+model.decoder.logits_partition_spec[0][2]: 'fsdp'
+model.decoder.logits_partition_spec[1]: 'seq'
+model.decoder.logits_partition_spec[2]: 'model'
+model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout'
+model.decoder.output_norm.eps: 1e-05
+model.decoder.output_norm.forward_dtype: None
+model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm'
+model.decoder.pad_token_id: 0
+model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer'
+model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu'
+model.decoder.transformer.layer.feed_forward.activation[1]: 'linear'
+model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
+model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn'
+model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256
+model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5
+model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
+model.decoder.transformer.layer.feed_forward.linear1.bias: False
+model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
+model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data'
+model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert'
+model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp'
+model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq'
+model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model'
+model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert'
+model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp'
+model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq'
+model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model'
+model.decoder.transformer.layer.feed_forward.linear2.bias: False
+model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
+model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data'
+model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert'
+model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp'
+model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq'
+model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model'
+model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model'
+model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert'
+model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp'
+model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq'
+model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05
+model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None
+model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm'
+model.decoder.transformer.layer.feed_forward.residual_weight: 1.0
+model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row'
+model.decoder.transformer.layer.feed_forward.structure: 'prenorm'
+model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer'
+model.decoder.transformer.layer.remat_spec['prevent_cse']: False
+model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
+model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
+model.decoder.transformer.layer.self_attention.attention.causal: True
+model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
+model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear'
+model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False
+model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert'
+model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp'
+model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq'
+model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model'
+model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None
+model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8
+model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear'
+model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding'
+model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 10000.0
+model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False
+model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey'
+model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.attention.GroupedQueryAttention'
+model.decoder.transformer.layer.self_attention.attention.kv_cache.cache_dtype: 'jax.numpy.bfloat16'
+model.decoder.transformer.layer.self_attention.attention.kv_cache.klass: 'axlearn.common.attention.KVCache'
+model.decoder.transformer.layer.self_attention.attention.num_heads: 96
+model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False
+model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear'
+model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert'
+model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp'
+model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq'
+model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model'
+model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None
+model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery'
+model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout'
+model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer'
+model.decoder.transformer.layer.self_attention.norm.eps: 1e-05
+model.decoder.transformer.layer.self_attention.norm.forward_dtype: None
+model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm'
+model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row'
+model.decoder.transformer.layer.self_attention.structure: 'prenorm'
+model.decoder.transformer.num_layers: 80
+model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex'
+model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*'
+model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat'
+model.decoder.vocab_size: 32768
+model.dtype: 'jax.numpy.float32'
+model.klass: 'axlearn.common.causal_lm.Model'
+model.metrics.klass: 'axlearn.common.causal_lm.CompositeLossMetrics'
+model.metrics.metrics['lm'].klass: 'axlearn.common.causal_lm.CrossEntropyLossMetrics'
+model.metrics.metrics['aux'].klass: 'axlearn.common.causal_lm.AuxLossMetrics'
+model.param_init.init_by_param_name['.*weight$'].distribution: 'normal'
+model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in'
+model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer'
+model.param_init.init_by_param_name['.*weight$'].scale: 1.0
+model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
+name: 'gpt_trainer'
+prune_empty_state_updates: True
+save_input_iterator: False
+start_trace_process_indices[0]: 0
+summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter'
+summary_writer.max_queue: 1000
+summary_writer.write_every_n_steps: 100
+train_dtype: 'jax.numpy.bfloat16'
\ No newline at end of file
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-fp8_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-fp8_init.txt
new file mode 100644
index 000000000..d3f162fe3
--- /dev/null
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-fp8_init.txt
@@ -0,0 +1,10 @@
+decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 12288], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0)
+decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 112, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=())
+decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 96, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=())
+decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0)
+decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(43008, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/output_norm/scale: constant(1.0)
+decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(32768, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
\ No newline at end of file
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-fp8_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-fp8_regularizer.txt
new file mode 100644
index 000000000..65733fb7f
--- /dev/null
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-fp8_regularizer.txt
@@ -0,0 +1,11 @@
+====================weight_decay_scale root.optimizer====================
+decoder/emb/token_emb/weight: 1
+decoder/lm_head/weight: 1
+decoder/output_norm/scale: 1
+decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1
+decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1
+decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1
+decoder/transformer/repeat/layer/feed_forward/norm/scale: 1
+decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1
+decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1
+decoder/transformer/repeat/layer/self_attention/norm/scale: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2.txt
new file mode 100644
index 000000000..95e954a3b
--- /dev/null
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2.txt
@@ -0,0 +1,280 @@
+batch_axis_names[0]: 'data'
+batch_axis_names[1]: 'expert'
+batch_axis_names[2]: 'fsdp'
+batch_axis_names[3]: 'seq'
+checkpointer.gc_loop_interval_seconds: 60
+checkpointer.keep_every_n_steps: 50000
+checkpointer.keep_last_n: 3
+checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer'
+checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy'
+checkpointer.save_policy.max_step: 524288
+checkpointer.save_policy.min_step: 1
+checkpointer.save_policy.n: 5000
+checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage'
+checkpointer.storage.timeout_secs: 3600
+evalers['train'].eval_dtype: 'jax.numpy.bfloat16'
+evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy'
+evalers['train'].eval_policy.max_step: 524288
+evalers['train'].eval_policy.min_step: 1
+evalers['train'].eval_policy.n: 5000
+evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch'
+evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn'
+evalers['train'].input.batcher.prefetch_buffer_size: -1
+evalers['train'].input.input_dispatcher.global_logical_batch_size: 1024
+evalers['train'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher'
+evalers['train'].input.is_training: False
+evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input'
+evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity'
+evalers['train'].input.source.dataset_name: 'c4/en:3.0.1'
+evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input'
+evalers['train'].input.source.is_training: False
+evalers['train'].input.source.max_sequence_length: 4096
+evalers['train'].input.source.replace_newlines_with: '\n'
+evalers['train'].input.source.split: 'train[:8192]'
+evalers['train'].input.source.train_shuffle_buffer_size: 16384
+evalers['train'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab'
+evalers['train'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model'
+evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler'
+evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator'
+evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator'
+evalers['train'].metric_calculator.model_method: 'forward'
+evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter'
+evalers['train'].summary_writer.write_every_n_steps: 1
+evalers['validation'].eval_dtype: 'jax.numpy.bfloat16'
+evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy'
+evalers['validation'].eval_policy.max_step: 524288
+evalers['validation'].eval_policy.min_step: 1
+evalers['validation'].eval_policy.n: 5000
+evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch'
+evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn'
+evalers['validation'].input.batcher.prefetch_buffer_size: -1
+evalers['validation'].input.input_dispatcher.global_logical_batch_size: 1024
+evalers['validation'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher'
+evalers['validation'].input.is_training: False
+evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input'
+evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity'
+evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1'
+evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input'
+evalers['validation'].input.source.is_training: False
+evalers['validation'].input.source.max_sequence_length: 4096
+evalers['validation'].input.source.replace_newlines_with: '\n'
+evalers['validation'].input.source.split: 'validation'
+evalers['validation'].input.source.train_shuffle_buffer_size: 16384
+evalers['validation'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab'
+evalers['validation'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model'
+evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler'
+evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator'
+evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator'
+evalers['validation'].metric_calculator.model_method: 'forward'
+evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter'
+evalers['validation'].summary_writer.write_every_n_steps: 1
+input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch'
+input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn'
+input.batcher.prefetch_buffer_size: -1
+input.input_dispatcher.global_logical_batch_size: 1024
+input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher'
+input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank'
+input.input_partitioner.path_rank_to_partition[(None, 1)][0][0]: 'data'
+input.input_partitioner.path_rank_to_partition[(None, 1)][0][1]: 'expert'
+input.input_partitioner.path_rank_to_partition[(None, 1)][0][2]: 'fsdp'
+input.input_partitioner.path_rank_to_partition[(None, 2)][0][0]: 'data'
+input.input_partitioner.path_rank_to_partition[(None, 2)][0][1]: 'expert'
+input.input_partitioner.path_rank_to_partition[(None, 2)][0][2]: 'fsdp'
+input.input_partitioner.path_rank_to_partition[(None, 2)][1]: 'seq'
+input.is_training: True
+input.klass: 'axlearn.common.input_tf_data.Input'
+input.processor.fn: 'axlearn.common.input_tf_data.identity'
+input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1'
+input.source.data_mixture_components[0]['weight']: 1.0
+input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192
+input.source.data_mixture_components[0]['split']: 'train'
+input.source.data_mixture_components[0]['info']: ''
+input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source'
+input.source.max_sequence_length: 4096
+input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor'
+input.source.preprocessor.max_padding_fraction: 0.5
+input.source.preprocessor.shuffle_buffer_size: 8192
+input.source.preprocessor.window_size: 128
+input.source.replace_newlines_with: ''
+input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab'
+input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model'
+klass: 'axlearn.common.trainer.SpmdTrainer'
+learner.ema.fn: 'axlearn.common.optimizers.param_ema'
+learner.enable_per_variable_summaries: False
+learner.klass: 'axlearn.common.learner.Learner'
+learner.optimizer.args[0].eps: 1e-08
+learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm'
+learner.optimizer.args[0].max_norm: 1
+learner.optimizer.args[1].b1: 0.9
+learner.optimizer.args[1].b2: 0.95
+learner.optimizer.args[1].eps: 1e-08
+learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer'
+learner.optimizer.args[1].learning_rate: 0.00015
+learner.optimizer.args[1].update_schedule.alpha: 0.1
+learner.optimizer.args[1].update_schedule.begin_value: 0.0
+learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup'
+learner.optimizer.args[1].update_schedule.max_step: 524288
+learner.optimizer.args[1].update_schedule.peak_lr: 1.0
+learner.optimizer.args[1].update_schedule.warmup_steps: 2000
+learner.optimizer.args[1].weight_decay: 0.1
+learner.optimizer.fn: 'axlearn.common.optimizers.chain'
+max_step: 524288
+mesh_axis_names[0]: 'pipeline'
+mesh_axis_names[1]: 'data'
+mesh_axis_names[2]: 'expert'
+mesh_axis_names[3]: 'fsdp'
+mesh_axis_names[4]: 'seq'
+mesh_axis_names[5]: 'model'
+mesh_rules[0][0]: 'tpu-v6e-256.*'
+mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1
+mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 64
+mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
+mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
+mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.common.utils.policy'
+mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
+mesh_rules[0][1].config_modifiers[2].tpu_block_size: 1024
+mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
+mesh_shape[0]: 1
+mesh_shape[1]: -1
+mesh_shape[2]: 1
+mesh_shape[3]: 64
+mesh_shape[4]: 1
+mesh_shape[5]: 4
+model.batch_axis_names: None
+model.decoder.attention_mask: None
+model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer'
+model.decoder.dim: 12288
+model.decoder.dropout_rate: 0.0
+model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout'
+model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings'
+model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding'
+model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal'
+model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out'
+model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer'
+model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0
+model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
+model.decoder.emb.token_emb.param_partition_spec[0]: None
+model.decoder.emb.token_emb.param_partition_spec[1]: 'model'
+model.decoder.eos_token_id: 1
+model.decoder.klass: 'axlearn.common.decoder.Decoder'
+model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead'
+model.decoder.lm_head.param_partition_spec[0]: None
+model.decoder.lm_head.param_partition_spec[1]: 'model'
+model.decoder.logits_partition_spec[0][0]: 'data'
+model.decoder.logits_partition_spec[0][1]: 'expert'
+model.decoder.logits_partition_spec[0][2]: 'fsdp'
+model.decoder.logits_partition_spec[1]: 'seq'
+model.decoder.logits_partition_spec[2]: 'model'
+model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout'
+model.decoder.output_norm.eps: 1e-05
+model.decoder.output_norm.forward_dtype: None
+model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm'
+model.decoder.pad_token_id: 0
+model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer'
+model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu'
+model.decoder.transformer.layer.feed_forward.activation[1]: 'linear'
+model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
+model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn'
+model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256
+model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5
+model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
+model.decoder.transformer.layer.feed_forward.linear1.bias: False
+model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
+model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data'
+model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert'
+model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp'
+model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq'
+model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model'
+model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert'
+model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp'
+model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq'
+model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model'
+model.decoder.transformer.layer.feed_forward.linear2.bias: False
+model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
+model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data'
+model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert'
+model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp'
+model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq'
+model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model'
+model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model'
+model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert'
+model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp'
+model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq'
+model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05
+model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None
+model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm'
+model.decoder.transformer.layer.feed_forward.residual_weight: 1.0
+model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row'
+model.decoder.transformer.layer.feed_forward.structure: 'prenorm'
+model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer'
+model.decoder.transformer.layer.remat_spec['prevent_cse']: False
+model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
+model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
+model.decoder.transformer.layer.self_attention.attention.causal: True
+model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
+model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear'
+model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False
+model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert'
+model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp'
+model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq'
+model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model'
+model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None
+model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8
+model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear'
+model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding'
+model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 10000.0
+model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False
+model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey'
+model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.attention.GroupedQueryAttention'
+model.decoder.transformer.layer.self_attention.attention.kv_cache.cache_dtype: 'jax.numpy.bfloat16'
+model.decoder.transformer.layer.self_attention.attention.kv_cache.klass: 'axlearn.common.attention.KVCache'
+model.decoder.transformer.layer.self_attention.attention.num_heads: 96
+model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False
+model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear'
+model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert'
+model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp'
+model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq'
+model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model'
+model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None
+model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery'
+model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout'
+model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer'
+model.decoder.transformer.layer.self_attention.norm.eps: 1e-05
+model.decoder.transformer.layer.self_attention.norm.forward_dtype: None
+model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm'
+model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row'
+model.decoder.transformer.layer.self_attention.structure: 'prenorm'
+model.decoder.transformer.num_layers: 80
+model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex'
+model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*'
+model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat'
+model.decoder.vocab_size: 32768
+model.dtype: 'jax.numpy.float32'
+model.klass: 'axlearn.common.causal_lm.Model'
+model.metrics.klass: 'axlearn.common.causal_lm.CompositeLossMetrics'
+model.metrics.metrics['lm'].klass: 'axlearn.common.causal_lm.CrossEntropyLossMetrics'
+model.metrics.metrics['aux'].klass: 'axlearn.common.causal_lm.AuxLossMetrics'
+model.param_init.init_by_param_name['.*weight$'].distribution: 'normal'
+model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in'
+model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer'
+model.param_init.init_by_param_name['.*weight$'].scale: 1.0
+model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
+name: 'gpt_trainer'
+prune_empty_state_updates: True
+save_input_iterator: False
+start_trace_process_indices[0]: 0
+summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter'
+summary_writer.max_queue: 1000
+summary_writer.write_every_n_steps: 100
+train_dtype: 'jax.numpy.bfloat16'
\ No newline at end of file
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2_init.txt
new file mode 100644
index 000000000..d3f162fe3
--- /dev/null
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2_init.txt
@@ -0,0 +1,10 @@
+decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 12288], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0)
+decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 112, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=())
+decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 96, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=())
+decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0)
+decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(43008, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/output_norm/scale: constant(1.0)
+decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(32768, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
\ No newline at end of file
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2_regularizer.txt
new file mode 100644
index 000000000..65733fb7f
--- /dev/null
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2_regularizer.txt
@@ -0,0 +1,11 @@
+====================weight_decay_scale root.optimizer====================
+decoder/emb/token_emb/weight: 1
+decoder/lm_head/weight: 1
+decoder/output_norm/scale: 1
+decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1
+decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1
+decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1
+decoder/transformer/repeat/layer/feed_forward/norm/scale: 1
+decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1
+decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1
+decoder/transformer/repeat/layer/self_attention/norm/scale: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-flash-fp8.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-flash-fp8.txt
new file mode 100644
index 000000000..b148649f5
--- /dev/null
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-flash-fp8.txt
@@ -0,0 +1,313 @@
+batch_axis_names[0]: 'data'
+batch_axis_names[1]: 'expert'
+batch_axis_names[2]: 'fsdp'
+batch_axis_names[3]: 'seq'
+checkpointer.gc_loop_interval_seconds: 60
+checkpointer.keep_every_n_steps: 50000
+checkpointer.keep_last_n: 3
+checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer'
+checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy'
+checkpointer.save_policy.max_step: 983040
+checkpointer.save_policy.min_step: 1
+checkpointer.save_policy.n: 5000
+checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage'
+checkpointer.storage.timeout_secs: 3600
+evalers['train'].eval_dtype: 'jax.numpy.bfloat16'
+evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy'
+evalers['train'].eval_policy.max_step: 983040
+evalers['train'].eval_policy.min_step: 1
+evalers['train'].eval_policy.n: 5000
+evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch'
+evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn'
+evalers['train'].input.batcher.prefetch_buffer_size: -1
+evalers['train'].input.input_dispatcher.global_logical_batch_size: 2048
+evalers['train'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher'
+evalers['train'].input.is_training: False
+evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input'
+evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity'
+evalers['train'].input.source.dataset_name: 'c4/en:3.0.1'
+evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input'
+evalers['train'].input.source.is_training: False
+evalers['train'].input.source.max_sequence_length: 8192
+evalers['train'].input.source.replace_newlines_with: '\n'
+evalers['train'].input.source.split: 'train[:8192]'
+evalers['train'].input.source.train_shuffle_buffer_size: 16384
+evalers['train'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab'
+evalers['train'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model'
+evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler'
+evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator'
+evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator'
+evalers['train'].metric_calculator.model_method: 'forward'
+evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter'
+evalers['train'].summary_writer.write_every_n_steps: 1
+evalers['validation'].eval_dtype: 'jax.numpy.bfloat16'
+evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy'
+evalers['validation'].eval_policy.max_step: 983040
+evalers['validation'].eval_policy.min_step: 1
+evalers['validation'].eval_policy.n: 5000
+evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch'
+evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn'
+evalers['validation'].input.batcher.prefetch_buffer_size: -1
+evalers['validation'].input.input_dispatcher.global_logical_batch_size: 2048
+evalers['validation'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher'
+evalers['validation'].input.is_training: False
+evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input'
+evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity'
+evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1'
+evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input'
+evalers['validation'].input.source.is_training: False
+evalers['validation'].input.source.max_sequence_length: 8192
+evalers['validation'].input.source.replace_newlines_with: '\n'
+evalers['validation'].input.source.split: 'validation'
+evalers['validation'].input.source.train_shuffle_buffer_size: 16384
+evalers['validation'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab'
+evalers['validation'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model'
+evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler'
+evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator'
+evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator'
+evalers['validation'].metric_calculator.model_method: 'forward'
+evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter'
+evalers['validation'].summary_writer.write_every_n_steps: 1
+input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch'
+input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn'
+input.batcher.prefetch_buffer_size: -1
+input.input_dispatcher.global_logical_batch_size: 2048
+input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher'
+input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank'
+input.input_partitioner.path_rank_to_partition[(None, 1)][0][0]: 'data'
+input.input_partitioner.path_rank_to_partition[(None, 1)][0][1]: 'expert'
+input.input_partitioner.path_rank_to_partition[(None, 1)][0][2]: 'fsdp'
+input.input_partitioner.path_rank_to_partition[(None, 2)][0][0]: 'data'
+input.input_partitioner.path_rank_to_partition[(None, 2)][0][1]: 'expert'
+input.input_partitioner.path_rank_to_partition[(None, 2)][0][2]: 'fsdp'
+input.input_partitioner.path_rank_to_partition[(None, 2)][1]: 'seq'
+input.is_training: True
+input.klass: 'axlearn.common.input_tf_data.Input'
+input.processor.fn: 'axlearn.common.input_tf_data.identity'
+input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1'
+input.source.data_mixture_components[0]['weight']: 1.0
+input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192
+input.source.data_mixture_components[0]['split']: 'train'
+input.source.data_mixture_components[0]['info']: ''
+input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source'
+input.source.max_sequence_length: 8192
+input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor'
+input.source.preprocessor.max_padding_fraction: 0.5
+input.source.preprocessor.shuffle_buffer_size: 8192
+input.source.preprocessor.window_size: 128
+input.source.replace_newlines_with: ''
+input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab'
+input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model'
+klass: 'axlearn.common.trainer.SpmdTrainer'
+learner.ema.fn: 'axlearn.common.optimizers.param_ema'
+learner.enable_per_variable_summaries: False
+learner.klass: 'axlearn.common.learner.Learner'
+learner.optimizer.args[0].eps: 1e-08
+learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm'
+learner.optimizer.args[0].max_norm: 1
+learner.optimizer.args[1].b1: 0.9
+learner.optimizer.args[1].b2: 0.95
+learner.optimizer.args[1].eps: 1e-08
+learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer'
+learner.optimizer.args[1].learning_rate: 0.00015
+learner.optimizer.args[1].update_schedule.alpha: 0.1
+learner.optimizer.args[1].update_schedule.begin_value: 0.0
+learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup'
+learner.optimizer.args[1].update_schedule.max_step: 983040
+learner.optimizer.args[1].update_schedule.peak_lr: 1.0
+learner.optimizer.args[1].update_schedule.warmup_steps: 2000
+learner.optimizer.args[1].weight_decay: 0.1
+learner.optimizer.fn: 'axlearn.common.optimizers.chain'
+max_step: 983040
+mesh_axis_names[0]: 'pipeline'
+mesh_axis_names[1]: 'data'
+mesh_axis_names[2]: 'expert'
+mesh_axis_names[3]: 'fsdp'
+mesh_axis_names[4]: 'seq'
+mesh_axis_names[5]: 'model'
+mesh_rules[0][0]: 'tpu-v6e-256.*'
+mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1
+mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 64
+mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
+mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False
+mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.common.utils.policy'
+mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier'
+mesh_rules[0][1].config_modifiers[2].tpu_block_size: 1024
+mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
+mesh_shape[0]: 1
+mesh_shape[1]: -1
+mesh_shape[2]: 1
+mesh_shape[3]: 64
+mesh_shape[4]: 1
+mesh_shape[5]: 4
+model.batch_axis_names: None
+model.decoder.attention_mask: None
+model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer'
+model.decoder.dim: 12288
+model.decoder.dropout_rate: 0.0
+model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout'
+model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings'
+model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding'
+model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal'
+model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out'
+model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer'
+model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0
+model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
+model.decoder.emb.token_emb.param_partition_spec[0]: None
+model.decoder.emb.token_emb.param_partition_spec[1]: 'model'
+model.decoder.eos_token_id: 1
+model.decoder.klass: 'axlearn.common.decoder.Decoder'
+model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead'
+model.decoder.lm_head.param_partition_spec[0]: None
+model.decoder.lm_head.param_partition_spec[1]: 'model'
+model.decoder.logits_partition_spec[0][0]: 'data'
+model.decoder.logits_partition_spec[0][1]: 'expert'
+model.decoder.logits_partition_spec[0][2]: 'fsdp'
+model.decoder.logits_partition_spec[1]: 'seq'
+model.decoder.logits_partition_spec[2]: 'model'
+model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout'
+model.decoder.output_norm.eps: 1e-05
+model.decoder.output_norm.forward_dtype: None
+model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm'
+model.decoder.pad_token_id: 0
+model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer'
+model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu'
+model.decoder.transformer.layer.feed_forward.activation[1]: 'linear'
+model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
+model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn'
+model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256
+model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5
+model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
+model.decoder.transformer.layer.feed_forward.linear1.bias: False
+model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
+model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data'
+model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert'
+model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp'
+model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq'
+model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model'
+model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert'
+model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp'
+model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq'
+model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model'
+model.decoder.transformer.layer.feed_forward.linear2.bias: False
+model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
+model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data'
+model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert'
+model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp'
+model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq'
+model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model'
+model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model'
+model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert'
+model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp'
+model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq'
+model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05
+model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None
+model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm'
+model.decoder.transformer.layer.feed_forward.residual_weight: 1.0
+model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row'
+model.decoder.transformer.layer.feed_forward.structure: 'prenorm'
+model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer'
+model.decoder.transformer.layer.remat_spec['prevent_cse']: False
+model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)'
+model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
+model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
+model.decoder.transformer.layer.self_attention.attention.causal: True
+model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
+model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear'
+model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False
False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention' +model.decoder.transformer.layer.self_attention.attention.kv_cache.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.kv_cache.klass: 'axlearn.common.attention.KVCache' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert' 
+model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.num_heads: 96 +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.attention.tpu_block_size: 512 +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 80 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' 
+model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 131072 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.metrics.klass: 'axlearn.common.causal_lm.CompositeLossMetrics' +model.metrics.metrics['lm'].klass: 'axlearn.common.causal_lm.CrossEntropyLossMetrics' +model.metrics.metrics['aux'].klass: 'axlearn.common.causal_lm.AuxLossMetrics' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-flash-fp8_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-flash-fp8_init.txt new file mode 100644 index 000000000..ce0614cbc --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-flash-fp8_init.txt @@ -0,0 +1,10 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 12288], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 112, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 96, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(43008, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) +decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(131072, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-flash-fp8_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-flash-fp8_regularizer.txt new file mode 100644 index 000000000..65733fb7f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-flash-fp8_regularizer.txt @@ -0,0 +1,11 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/lm_head/weight: 1 +decoder/output_norm/scale: 1 
+decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-flash.txt new file mode 100644 index 000000000..b148649f5 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-flash.txt @@ -0,0 +1,313 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 983040 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 983040 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.input_dispatcher.global_logical_batch_size: 2048 +evalers['train'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['train'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' 
+evalers['validation'].eval_policy.max_step: 983040 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.input_dispatcher.global_logical_batch_size: 2048 +evalers['validation'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['validation'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.input_dispatcher.global_logical_batch_size: 2048 +input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][1]: 'seq' +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 
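For scale, the input settings above pin down a fixed token budget per step. A back-of-envelope check in Python, assuming fully packed sequences (which the lm_text_preprocessor packing aims for):

    # Tokens per training step implied by the input config above.
    global_batch_size = 2048     # input.input_dispatcher.global_logical_batch_size
    max_sequence_length = 8192   # input.source.max_sequence_length
    tokens_per_step = global_batch_size * max_sequence_length
    print(f"{tokens_per_step:,}")  # 16,777,216 (~16.8M tokens per step)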
+input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.00015 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 983040 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 983040 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'tpu-v6e-256.*' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 64 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.common.utils.policy' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier' +mesh_rules[0][1].config_modifiers[2].tpu_block_size: 1024 +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 64 +mesh_shape[4]: 1 +mesh_shape[5]: 4 +model.batch_axis_names: None +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 12288 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 
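The learner.optimizer block above pairs adamw_decoupled_optimizer (learning_rate 1.5e-4) with a cosine_with_linear_warmup schedule whose peak_lr is 1.0, i.e. the schedule acts as a multiplier on the base learning rate. A minimal sketch of that shape, using generic cosine-decay math rather than the actual axlearn.common.schedule implementation:

    import math

    def lr_multiplier(step, peak_lr=1.0, begin_value=0.0,
                      warmup_steps=2000, max_step=983040, alpha=0.1):
        # Linear warmup from begin_value to peak_lr over warmup_steps ...
        if step < warmup_steps:
            return begin_value + (peak_lr - begin_value) * step / warmup_steps
        # ... then cosine decay from peak_lr down to alpha * peak_lr at max_step.
        frac = (step - warmup_steps) / (max_step - warmup_steps)
        return peak_lr * (alpha + (1 - alpha) * 0.5 * (1 + math.cos(math.pi * frac)))

    # Effective step size = 1.5e-4 * lr_multiplier(step):
    # peak ~1.5e-4 at step 2000, floor ~1.5e-5 at step 983040.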
+model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 1 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' +model.decoder.lm_head.param_partition_spec[0]: None +model.decoder.lm_head.param_partition_spec[1]: 'model' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 0 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256 +model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' 
+model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention' +model.decoder.transformer.layer.self_attention.attention.kv_cache.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.kv_cache.klass: 'axlearn.common.attention.KVCache' 
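The remat_spec above saves, rather than recomputes, any checkpointed activation whose name matches the regex '.*([qkvo]_proj|context)'. An illustrative check with hypothetical checkpoint names (the real matching happens inside axlearn's _save_and_offload_only_these_names_regex policy):

    import re

    pattern = re.compile(r".*([qkvo]_proj|context)")
    for name in ("q_proj", "k_proj", "v_proj", "o_proj", "context", "linear1_0"):
        print(name, bool(pattern.fullmatch(name)))
    # The four attention projections and the attention context are saved;
    # names like linear1_0 do not match and are rematerialized in the backward pass.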
+model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.num_heads: 96 +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False 
+model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.attention.tpu_block_size: 512 +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 80 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 131072 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.metrics.klass: 'axlearn.common.causal_lm.CompositeLossMetrics' +model.metrics.metrics['lm'].klass: 'axlearn.common.causal_lm.CrossEntropyLossMetrics' +model.metrics.metrics['aux'].klass: 'axlearn.common.causal_lm.AuxLossMetrics' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-flash_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-flash_init.txt new file mode 100644 index 000000000..ce0614cbc --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-flash_init.txt @@ -0,0 +1,10 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 12288], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 
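The token_emb row above uses fan_out initialization, matching the per-parameter override for '.*weight$' with fan 'fan_out' in the trainer config, while the projection weights that follow use the default fan_in rule. Assuming the printed normal(0, 1.0 / fan) denotes the variance (an assumption about the dump format, not something it states), the embedding standard deviation works out as:

    import math
    fan_out = 12288  # embedding dim; shape=[131072, 12288] with out_axis=-1
    std = math.sqrt(1.0 / fan_out)  # assumes 1.0 / fan is a variance
    print(f"{std:.5f}")  # ~0.00902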
+decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 112, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 96, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(43008, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) +decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(131072, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-flash_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-flash_regularizer.txt new file mode 100644 index 000000000..65733fb7f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-flash_regularizer.txt @@ -0,0 +1,11 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/lm_head/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-fp8.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-fp8.txt new file mode 100644 index 000000000..6f4a63859 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-fp8.txt @@ -0,0 +1,280 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 983040 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 983040 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' 
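The feed-forward shapes in the init file above follow from the hidden_dim rule in the trainer config: dim 12288 scaled by 3.5 and rounded up to a multiple of 256 gives 43008, and the gated ('nn.silu', 'linear') activation pair is why linear1 appears twice, as linear1_0 and linear1_1. A sketch of the rounding arithmetic (not the actual scale_fn implementation):

    import math

    model_dim = 12288   # model.decoder.dim
    scale = 3.5         # feed_forward.hidden_dim.scale
    multiple = 256      # feed_forward.hidden_dim.round_up_to_multiples_of

    hidden_dim = multiple * math.ceil(model_dim * scale / multiple)
    assert hidden_dim == 43008  # 12288 * 3.5 is already an exact multiple of 256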
+evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.input_dispatcher.global_logical_batch_size: 2048 +evalers['train'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['train'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 983040 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.input_dispatcher.global_logical_batch_size: 2048 +evalers['validation'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['validation'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 
'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.input_dispatcher.global_logical_batch_size: 2048 +input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][1]: 'seq' +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.00015 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 983040 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 983040 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'tpu-v6e-256.*' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1 
+mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 64 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.common.utils.policy' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier' +mesh_rules[0][1].config_modifiers[2].tpu_block_size: 1024 +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 64 +mesh_shape[4]: 1 +mesh_shape[5]: 4 +model.batch_axis_names: None +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 12288 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 1 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' +model.decoder.lm_head.param_partition_spec[0]: None +model.decoder.lm_head.param_partition_spec[1]: 'model' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 0 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256 +model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' 
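In the mesh_shape above, -1 marks the axis whose size is inferred from the device count, so on the tpu-v6e-256 slice targeted by mesh_rules[0] the data axis comes out to 1. A small sketch of that inference, assuming the usual convention that the -1 axis absorbs the remaining devices:

    mesh_axis_names = ("pipeline", "data", "expert", "fsdp", "seq", "model")
    mesh_shape = [1, -1, 1, 64, 1, 4]
    num_devices = 256  # tpu-v6e-256

    known = 1
    for dim in mesh_shape:
        if dim != -1:
            known *= dim
    mesh_shape[mesh_shape.index(-1)] = num_devices // known
    print(dict(zip(mesh_axis_names, mesh_shape)))
    # {'pipeline': 1, 'data': 1, 'expert': 1, 'fsdp': 64, 'seq': 1, 'model': 4}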
+model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' 
+model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.attention.GroupedQueryAttention' +model.decoder.transformer.layer.self_attention.attention.kv_cache.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.kv_cache.klass: 'axlearn.common.attention.KVCache' +model.decoder.transformer.layer.self_attention.attention.num_heads: 96 +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 80 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' 
+model.decoder.vocab_size: 131072 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.metrics.klass: 'axlearn.common.causal_lm.CompositeLossMetrics' +model.metrics.metrics['lm'].klass: 'axlearn.common.causal_lm.CrossEntropyLossMetrics' +model.metrics.metrics['aux'].klass: 'axlearn.common.causal_lm.AuxLossMetrics' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-fp8_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-fp8_init.txt new file mode 100644 index 000000000..ce0614cbc --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-fp8_init.txt @@ -0,0 +1,10 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 12288], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 112, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 96, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(43008, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) +decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(131072, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-fp8_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-fp8_regularizer.txt new file mode 100644 index 000000000..65733fb7f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-fp8_regularizer.txt @@ -0,0 +1,11 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/lm_head/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 
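The fused qkv_proj shape (12288, 112, 128) in the init file above is consistent with the attention config: per-head dim 12288 / 96 = 128, and FusedGroupedQKVLinear packs the 96 query heads together with 2 x 8 KV heads along the head axis (the q/k/v packing order is an assumption here, not something the dump shows). A quick check:

    model_dim = 12288    # model.decoder.dim
    num_heads = 96       # self_attention.attention.num_heads
    num_kv_heads = 8     # input_linear.input_linear.num_kv_heads

    per_head_dim = model_dim // num_heads        # 128
    fused_heads = num_heads + 2 * num_kv_heads   # 96 (q) + 8 (k) + 8 (v) = 112
    assert (model_dim, fused_heads, per_head_dim) == (12288, 112, 128)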
+decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-flash-fp8.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-flash-fp8.txt new file mode 100644 index 000000000..1389620a4 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-flash-fp8.txt @@ -0,0 +1,313 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 983040 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 983040 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.input_dispatcher.global_logical_batch_size: 2048 +evalers['train'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['train'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 983040 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 
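The eval policies above fire every n=5000 steps between min_step and max_step; the checkpointer's every_n_steps_and_last_policy additionally fires on the final step. A rough sketch of the predicate, assuming the policy is keyed on the current step (illustrative, not the axlearn implementation):

def every_n_steps_and_last(step, *, n=5000, min_step=1, max_step=983040):
    # Fire on every n-th step within [min_step, max_step], and on the last step.
    in_range = min_step <= step <= max_step
    return in_range and (step % n == 0 or step == max_step)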
+evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.input_dispatcher.global_logical_batch_size: 2048 +evalers['validation'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['validation'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.input_dispatcher.global_logical_batch_size: 2048 +input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][1]: 'seq' +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' 
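For scale: the training input above packs sequences of max_sequence_length 8192 into a global logical batch of 2048, so each step consumes roughly 16.8M tokens, and the full max_step run is on the order of 1.6e13 tokens. A back-of-envelope check (ignores padding and the max_padding_fraction=0.5 packing behavior):

# Back-of-envelope token budget for this trainer config (ignores padding).
batch_size = 2048            # input.input_dispatcher.global_logical_batch_size
seq_len = 8192               # input.source.max_sequence_length
max_step = 983040
tokens_per_step = batch_size * seq_len          # 16_777_216
total_tokens = tokens_per_step * max_step       # ~1.65e13
print(f"{tokens_per_step=:,} {total_tokens=:.3e}")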
+input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.00015 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 983040 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 983040 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'tpu-v6e-256.*' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 64 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.common.utils.policy' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier' +mesh_rules[0][1].config_modifiers[2].tpu_block_size: 1024 +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 64 +mesh_shape[4]: 1 +mesh_shape[5]: 4 +model.batch_axis_names: None +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 12288 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' 
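The tpu-v6e-256 mesh rule above pins (pipeline, data, expert, fsdp, seq, model) = (1, -1, 1, 64, 1, 4); the -1 on the data axis resolves from the device count, analogous to a numpy reshape. A sketch of the arithmetic for a 256-chip v6e slice (illustrative):

import numpy as np

mesh_axis_names = ("pipeline", "data", "expert", "fsdp", "seq", "model")
mesh_shape = (1, -1, 1, 64, 1, 4)   # from the tpu-v6e-256.* rule above
num_devices = 256                   # one v6e-256 slice
# -1 resolves to num_devices / product(other axes) = 256 / (64 * 4) = 1.
resolved = np.arange(num_devices).reshape(mesh_shape).shape
assert resolved == (1, 1, 1, 64, 1, 4)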
+model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 128001 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' +model.decoder.lm_head.param_partition_spec[0]: None +model.decoder.lm_head.param_partition_spec[1]: 'model' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 128004 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256 +model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 
1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention' +model.decoder.transformer.layer.self_attention.attention.kv_cache.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.kv_cache.klass: 'axlearn.common.attention.KVCache' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert' 
+model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.num_heads: 96 +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 
'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.attention.tpu_block_size: 512 +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 80 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 128256 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.metrics.klass: 'axlearn.common.causal_lm.CompositeLossMetrics' +model.metrics.metrics['lm'].klass: 'axlearn.common.causal_lm.CrossEntropyLossMetrics' +model.metrics.metrics['aux'].klass: 'axlearn.common.causal_lm.AuxLossMetrics' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-flash-fp8_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-flash-fp8_init.txt new file mode 100644 index 000000000..e4b66772f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-flash-fp8_init.txt @@ -0,0 +1,10 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 12288], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 112, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 
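The init entries above read normal(0, 1.0 / fan). Assuming the second argument denotes the variance (i.e. std = sqrt(scale / fan), as in standard variance-scaling initializers), the resulting standard deviations work out as follows (a sketch, not the axlearn code):

import math

# qkv_proj: weight (12288, 112, 128) with in_axis=0 -> fan_in = 12288.
qkv_std = math.sqrt(1.0 / 12288)          # ~0.00902
# o_proj: weight (12288, 96, 128) with in_axis=(1, 2) -> fan_in = 96 * 128 = 12288.
o_std = math.sqrt(1.0 / (96 * 128))       # same fan-in, same std
# token_emb uses fan_out instead: weight [128256, 12288], out_axis=-1 -> fan_out = 12288.
emb_std = math.sqrt(1.0 / 12288)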
+decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 96, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(43008, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) +decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(128256, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-flash-fp8_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-flash-fp8_regularizer.txt new file mode 100644 index 000000000..65733fb7f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-flash-fp8_regularizer.txt @@ -0,0 +1,11 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/lm_head/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-flash.txt new file mode 100644 index 000000000..1389620a4 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-flash.txt @@ -0,0 +1,313 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 983040 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 983040 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' 
+evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.input_dispatcher.global_logical_batch_size: 2048 +evalers['train'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['train'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 983040 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.input_dispatcher.global_logical_batch_size: 2048 +evalers['validation'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['validation'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' 
+evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.input_dispatcher.global_logical_batch_size: 2048 +input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][1]: 'seq' +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.00015 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 983040 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 983040 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'tpu-v6e-256.*' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 
+mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 64 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.common.utils.policy' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier' +mesh_rules[0][1].config_modifiers[2].tpu_block_size: 1024 +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 64 +mesh_shape[4]: 1 +mesh_shape[5]: 4 +model.batch_axis_names: None +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 12288 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 128001 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' +model.decoder.lm_head.param_partition_spec[0]: None +model.decoder.lm_head.param_partition_spec[1]: 'model' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 128004 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256 +model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' 
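The feed_forward.hidden_dim above is derived from the decoder dim by scale_fn: scale 3.5, rounded up to a multiple of 256. For dim 12288 this gives 43008, matching the linear1_* shapes (12288, 43008) in the *_init.txt files. A sketch, assuming round-up-to-multiple semantics:

def hidden_dim(model_dim, *, scale=3.5, multiple=256):
    # Round scale * model_dim up to the next multiple of `multiple`.
    raw = int(scale * model_dim)
    return -(-raw // multiple) * multiple   # ceil-division round-up

assert hidden_dim(12288) == 43008           # 3.5 * 12288 is already a multiple of 256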
+model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' 
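The remat_spec above saves only activations whose names match '.*([qkvo]_proj|context)' and offloads nothing (names_which_can_be_offloaded is None). A quick regex check of which checkpoint names that pattern keeps (illustrative names; whether the library uses full- or prefix-matching is an implementation detail, and this sketch uses fullmatch):

import re

pattern = re.compile(r".*([qkvo]_proj|context)")
names = ["q_proj", "k_proj", "v_proj", "o_proj", "context", "linear1_0", "residual"]
saved = [n for n in names if pattern.fullmatch(n)]
# saved == ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'context']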
+model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention' +model.decoder.transformer.layer.self_attention.attention.kv_cache.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.kv_cache.klass: 'axlearn.common.attention.KVCache' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None 
+model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.num_heads: 96 +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.attention.tpu_block_size: 512 +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 80 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 128256 +model.dtype: 'jax.numpy.float32' 
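The mha_dim_to_partition_spec entries above map FlashAttention's logical dims to mesh axes; for 'btnh' (batch, time, num_heads, per_head) the batch dim shards over (data, expert, fsdp) and the head dim over (seq, model). Expressed as a jax PartitionSpec (a sketch of the same mapping, not axlearn code):

from jax.sharding import PartitionSpec as P

# 'btnh': batch over (data, expert, fsdp), time replicated,
# heads over (seq, model), per-head dim replicated.
btnh = P(("data", "expert", "fsdp"), None, ("seq", "model"), None)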
+model.klass: 'axlearn.common.causal_lm.Model' +model.metrics.klass: 'axlearn.common.causal_lm.CompositeLossMetrics' +model.metrics.metrics['lm'].klass: 'axlearn.common.causal_lm.CrossEntropyLossMetrics' +model.metrics.metrics['aux'].klass: 'axlearn.common.causal_lm.AuxLossMetrics' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-flash_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-flash_init.txt new file mode 100644 index 000000000..e4b66772f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-flash_init.txt @@ -0,0 +1,10 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 12288], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 112, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 96, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(43008, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) +decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(128256, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-flash_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-flash_regularizer.txt new file mode 100644 index 000000000..65733fb7f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-flash_regularizer.txt @@ -0,0 +1,11 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/lm_head/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 
+decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-fp8.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-fp8.txt new file mode 100644 index 000000000..1d766fae4 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-fp8.txt @@ -0,0 +1,280 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 983040 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 983040 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.input_dispatcher.global_logical_batch_size: 2048 +evalers['train'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['train'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 983040 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 
+evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.input_dispatcher.global_logical_batch_size: 2048 +evalers['validation'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['validation'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.input_dispatcher.global_logical_batch_size: 2048 +input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][1]: 'seq' +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' 
+input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.00015 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 983040 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 983040 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'tpu-v6e-256.*' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 64 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.common.utils.policy' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier' +mesh_rules[0][1].config_modifiers[2].tpu_block_size: 1024 +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 64 +mesh_shape[4]: 1 +mesh_shape[5]: 4 +model.batch_axis_names: None +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 12288 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' 
+model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 128001 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' +model.decoder.lm_head.param_partition_spec[0]: None +model.decoder.lm_head.param_partition_spec[1]: 'model' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 128004 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256 +model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 
1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.attention.GroupedQueryAttention' +model.decoder.transformer.layer.self_attention.attention.kv_cache.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.kv_cache.klass: 'axlearn.common.attention.KVCache' +model.decoder.transformer.layer.self_attention.attention.num_heads: 96 +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 
'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 80 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 128256 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.metrics.klass: 'axlearn.common.causal_lm.CompositeLossMetrics' +model.metrics.metrics['lm'].klass: 'axlearn.common.causal_lm.CrossEntropyLossMetrics' +model.metrics.metrics['aux'].klass: 'axlearn.common.causal_lm.AuxLossMetrics' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-fp8_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-fp8_init.txt new file mode 100644 index 000000000..e4b66772f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-fp8_init.txt @@ -0,0 +1,10 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 12288], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 112, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) 
+decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 96, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(43008, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) +decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(128256, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-fp8_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-fp8_regularizer.txt new file mode 100644 index 000000000..65733fb7f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-fp8_regularizer.txt @@ -0,0 +1,11 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/lm_head/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken.txt new file mode 100644 index 000000000..1d766fae4 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken.txt @@ -0,0 +1,280 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 983040 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 983040 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 
+evalers['train'].input.input_dispatcher.global_logical_batch_size: 2048 +evalers['train'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['train'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 983040 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.input_dispatcher.global_logical_batch_size: 2048 +evalers['validation'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['validation'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 
'axlearn.common.input_tf_data.per_feed_batch' +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.input_dispatcher.global_logical_batch_size: 2048 +input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][1]: 'seq' +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.00015 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 983040 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 983040 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'tpu-v6e-256.*' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 64 
+mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.common.utils.policy' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier' +mesh_rules[0][1].config_modifiers[2].tpu_block_size: 1024 +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 64 +mesh_shape[4]: 1 +mesh_shape[5]: 4 +model.batch_axis_names: None +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 12288 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 128001 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' +model.decoder.lm_head.param_partition_spec[0]: None +model.decoder.lm_head.param_partition_spec[1]: 'model' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 128004 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256 +model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' 
+model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' 
+model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.attention.GroupedQueryAttention' +model.decoder.transformer.layer.self_attention.attention.kv_cache.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.kv_cache.klass: 'axlearn.common.attention.KVCache' +model.decoder.transformer.layer.self_attention.attention.num_heads: 96 +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 80 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 128256 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' 
+model.metrics.klass: 'axlearn.common.causal_lm.CompositeLossMetrics' +model.metrics.metrics['lm'].klass: 'axlearn.common.causal_lm.CrossEntropyLossMetrics' +model.metrics.metrics['aux'].klass: 'axlearn.common.causal_lm.AuxLossMetrics' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken_init.txt new file mode 100644 index 000000000..e4b66772f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken_init.txt @@ -0,0 +1,10 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 12288], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 112, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 96, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(43008, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) +decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(128256, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken_regularizer.txt new file mode 100644 index 000000000..65733fb7f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken_regularizer.txt @@ -0,0 +1,11 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/lm_head/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 
+decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3.txt new file mode 100644 index 000000000..6f4a63859 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3.txt @@ -0,0 +1,280 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 983040 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 983040 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.input_dispatcher.global_logical_batch_size: 2048 +evalers['train'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['train'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 983040 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['validation'].input.batcher.pad_example_fn: 
'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.input_dispatcher.global_logical_batch_size: 2048 +evalers['validation'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['validation'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.input_dispatcher.global_logical_batch_size: 2048 +input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][1]: 'seq' +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model' +klass: 
'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.00015 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 983040 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 983040 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'tpu-v6e-256.*' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 64 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.common.utils.policy' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier' +mesh_rules[0][1].config_modifiers[2].tpu_block_size: 1024 +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 64 +mesh_shape[4]: 1 +mesh_shape[5]: 4 +model.batch_axis_names: None +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 12288 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 1 +model.decoder.klass: 'axlearn.common.decoder.Decoder' 
+model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' +model.decoder.lm_head.param_partition_spec[0]: None +model.decoder.lm_head.param_partition_spec[1]: 'model' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 0 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256 +model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' 
+model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.attention.GroupedQueryAttention' +model.decoder.transformer.layer.self_attention.attention.kv_cache.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.kv_cache.klass: 'axlearn.common.attention.KVCache' +model.decoder.transformer.layer.self_attention.attention.num_heads: 96 +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' 
+model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 80 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 131072 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.metrics.klass: 'axlearn.common.causal_lm.CompositeLossMetrics' +model.metrics.metrics['lm'].klass: 'axlearn.common.causal_lm.CrossEntropyLossMetrics' +model.metrics.metrics['aux'].klass: 'axlearn.common.causal_lm.AuxLossMetrics' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3_init.txt new file mode 100644 index 000000000..ce0614cbc --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3_init.txt @@ -0,0 +1,10 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 12288], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 112, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 96, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 
+decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0)
+decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(43008, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/output_norm/scale: constant(1.0)
+decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(131072, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
\ No newline at end of file
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3_regularizer.txt
new file mode 100644
index 000000000..65733fb7f
--- /dev/null
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3_regularizer.txt
@@ -0,0 +1,11 @@
+====================weight_decay_scale root.optimizer====================
+decoder/emb/token_emb/weight: 1
+decoder/lm_head/weight: 1
+decoder/output_norm/scale: 1
+decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1
+decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1
+decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1
+decoder/transformer/repeat/layer/feed_forward/norm/scale: 1
+decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1
+decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1
+decoder/transformer/repeat/layer/self_attention/norm/scale: 1
diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py
index 57c910c70..f34e52269 100644
--- a/axlearn/experiments/text/gpt/fuji.py
+++ b/axlearn/experiments/text/gpt/fuji.py
@@ -67,7 +67,7 @@
 from axlearn.experiments.text.gpt.common import scaled_hidden_dim
 from axlearn.experiments.trainer_config_utils import TrainerConfigFn, V6eFlashConfigModifier
 
-MODEL_SIZES = ("test", "1B", "3B", "7B", "8B", "70B")
+MODEL_SIZES = ("test", "1B", "3B", "7B", "8B", "70B", "150B")
 
 
 class Version(enum.Enum):
@@ -113,6 +113,7 @@ class Version(enum.Enum):
         "test": 2 * (1024**4),  # 2T tokens
         "7B": 2 * (1024**4),  # 2T tokens
         "70B": 2 * (1024**4),  # 2T tokens
+        "150B": 2 * (1024**4),  # 2T tokens
     },
     Version.V3: {
         "test": 15 * (1024**4),  # 15T tokens
@@ -120,6 +121,7 @@
         "3B": 15 * (1024**4),  # 15T tokens
         "7B": 15 * (1024**4),  # 15T tokens
         "70B": 15 * (1024**4),  # 15T tokens
+        "150B": 15 * (1024**4),  # 15T tokens
     },
     Version.V3_TIKTOKEN: {
         "test": 15 * (1024**4),  # 15T tokens
@@ -127,6 +129,7 @@
         "3B": 15 * (1024**4),  # 15T tokens
         "8B": 15 * (1024**4),  # 15T tokens
         "70B": 15 * (1024**4),  # 15T tokens
+        "150B": 15 * (1024**4),  # 15T tokens
     },
 }
@@ -809,6 +812,54 @@
                 ),
             ),
         )
+    elif model_size == "150B":
+        trainer_kwargs = dict(
+            model_kwargs=dict(
+                num_layers=80,
+                hidden_dim=128 * 96,
+                num_heads=96,
+                # V1 models have no GQA support, so num_kv_heads is left as None (the same as num_heads).
+                num_kv_heads=None if version == Version.V1 else 8,
+                ffn_dim=scaled_hidden_dim(scale=3.5, round_up_to_multiples_of=256),
+                rope_theta=rope_theta,
+                shared_lm_head=False,
+                flash_attention=flash_attention,
+            ),
+            learner_kwargs=dict(peak_lr=1.5e-4, weight_decay=0.1),
+            max_sequence_length=max_sequence_length,
+            train_batch_size=train_batch_size,
+            max_step=max_step,
+            mesh_shape=mesh_shape_from_axes(data=-1, fsdp=64, model=4),
+            mesh_rules=(
+                (
+                    # Target per-device token count = 4k:
+                    # per-device batch size (PDBS) = 0.5 at 8k context,
+                    # so each 256-chip slice can train a batch size of 128.
+                    "tpu-v6e-256.*",
+                    ChainConfigModifier.default_config().set(
+                        config_modifiers=[
+                            MeshShapeModifier.default_config().set(
+                                mesh_shape=mesh_shape_from_axes(data=-1, fsdp=64, model=4)
+                            ),
+                            RematSpecModifier.default_config().set(
+                                remat_policies={
+                                    "model.decoder.transformer.layer": RematSpec(
+                                        prevent_cse=False,
+                                        policy=save_and_offload_only_these_names_regex(
+                                            names_which_can_be_offloaded=".*input",
+                                            names_which_can_be_saved=None,
+                                            offload_src="device",
+                                            offload_dst="pinned_host",
+                                        ),
+                                    ),
+                                }
+                            ),
+                            V6eFlashConfigModifier.default_config(),
+                        ],
+                    ),
+                ),
+            ),
+        )
     else:
         raise NotImplementedError(f"Unknown model size {model_size}.")
     model_kwargs = trainer_kwargs.pop("model_kwargs")
diff --git a/patches/shard_map.py.patch b/patches/shard_map.py.patch
new file mode 100644
index 000000000..e6e10104f
--- /dev/null
+++ b/patches/shard_map.py.patch
@@ -0,0 +1,27 @@
+--- shard_map_orig.py 2025-06-18 01:27:00.782665547 +0000
++++ shard_map.py 2025-06-18 01:26:06.798346281 +0000
+@@ -1793,10 +1793,10 @@
+ ) -> tuple[core.JaxprEqn, core.JaxprEqn, Sequence[bool], Sequence[bool],
+            list[core.Var]]:
+   jaxpr, mesh = eqn.params['jaxpr'], eqn.params['mesh']
+-  auto = eqn.params['auto']
+-  with _extend_axis_env(mesh, auto):
++  manual_axes = frozenset()
++  with (_extend_axis_env(mesh, manual_axes), use_abstract_mesh(_as_manual_mesh(mesh, manual_axes))):
+     jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = \
+-        pe.partial_eval_jaxpr_custom(jaxpr, unks_in, inst_in, False, False, saveable)
++        pe.partial_eval_jaxpr_custom(jaxpr, unks_in, inst_in, False, False, saveable)
+   num_out_primals = len(jaxpr_known.outvars) - num_res
+   in_fwd = pe._jaxpr_forwarding(jaxpr_known)[num_out_primals:]
+   out_vars, res_vars = split_list(jaxpr_known.outvars, [num_out_primals])
+@@ -1804,8 +1804,8 @@
+   out_fwd = [idx_map.get(id(v)) for v in res_vars]
+   which = [f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)]
+   mesh = eqn.params['mesh']
+-  with (_extend_axis_env(mesh, auto),
+-        use_abstract_mesh(_as_manual_mesh(mesh, auto))):
++  with (_extend_axis_env(mesh, manual_axes),
++        use_abstract_mesh(_as_manual_mesh(mesh, frozenset()))):
+     jaxpr_known = pe.prune_jaxpr_outputs(jaxpr_known, [True] * num_out_primals + which)
+     jaxpr_known, jaxpr_staged = _add_reshapes(which, jaxpr_known, jaxpr_staged)
+     jaxpr_known = core.remove_named_axis_effects(jaxpr_known, mesh.axis_names)
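
Editor's note (not part of the patch): the 150B geometry in get_trainer_kwargs can be
cross-checked against the shapes recorded in fuji-150B-v3_init.txt above. The sketch
below is a minimal re-derivation; round_up is a hypothetical helper that only mirrors
what scaled_hidden_dim(scale=3.5, round_up_to_multiples_of=256) is assumed to compute,
while every constant comes from the diff itself.

    # Re-derive the 150B shapes from the config values in this patch.
    def round_up(x: int, multiple: int) -> int:
        return ((x + multiple - 1) // multiple) * multiple

    num_heads, per_head_dim, num_kv_heads = 96, 128, 8
    hidden_dim = num_heads * per_head_dim           # 12288, i.e. hidden_dim=128 * 96.
    ffn_dim = round_up(int(hidden_dim * 3.5), 256)  # 43008, the linear1_0 shape (12288, 43008).
    fused_qkv = num_heads + 2 * num_kv_heads        # 112, the qkv_proj shape (12288, 112, 128).
    assert (hidden_dim, ffn_dim, fused_qkv) == (12288, 43008, 112)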
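
The batch-size comments in the tpu-v6e-256 mesh rule check out the same way. The mesh
axes and PDBS come from the diff; the arithmetic below is the editor's, assuming 256
chips per v6e-256 slice and the stated 8k context length.

    chips_per_slice = 256                   # tpu-v6e-256
    fsdp, model = 64, 4                     # mesh_shape_from_axes(data=-1, fsdp=64, model=4)
    assert fsdp * model == chips_per_slice  # the data axis (-1) then spans slices
    pdbs = 0.5                              # per-device batch size, sequences per chip
    assert pdbs * 8192 == 4096              # ~4k tokens per device at 8k context
    assert pdbs * chips_per_slice == 128    # batch size each slice can train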