diff --git a/Dockerfile b/Dockerfile index 29db664d3..35bb6c2b2 100644 --- a/Dockerfile +++ b/Dockerfile @@ -88,7 +88,15 @@ ARG EXTRAS= ENV UV_FIND_LINKS=https://storage.googleapis.com/jax-releases/libtpu_releases.html # Ensure we install the TPU version, even if building locally. # Jax will fallback to CPU when run on a machine without TPU. -RUN uv pip install --prerelease=allow .[core,tpu] && uv cache clean +COPY libtpu.so /root/libtpu.so +RUN uv pip install --prerelease=allow .[core,gcp,tpu] && uv cache clean +RUN uv pip install libtpu==0.0.14 + +# Add this line to print the installed version of libtpu. +RUN pip show libtpu | grep Version +RUN pip show jax | grep Version +RUN pip show jaxlib | grep Version + RUN if [ -n "$EXTRAS" ]; then uv pip install .[$EXTRAS] && uv cache clean; fi COPY . . diff --git a/axlearn/cloud/gcp/jobs/launch.py b/axlearn/cloud/gcp/jobs/launch.py index 24b74cb68..6f968dcc7 100644 --- a/axlearn/cloud/gcp/jobs/launch.py +++ b/axlearn/cloud/gcp/jobs/launch.py @@ -780,6 +780,6 @@ def _wrapped_usage( if __name__ == "__main__": - configure_logging(logging.INFO) + configure_logging(logging.DEBUG) _private_flags() app.run(main) diff --git a/axlearn/cloud/gcp/jobset_utils.py b/axlearn/cloud/gcp/jobset_utils.py index 710f96701..03a4dd0c0 100644 --- a/axlearn/cloud/gcp/jobset_utils.py +++ b/axlearn/cloud/gcp/jobset_utils.py @@ -4,7 +4,6 @@ import io import logging -import math import os from dataclasses import dataclass from typing import Any, Optional, Sequence @@ -27,10 +26,7 @@ ) from axlearn.cloud.gcp.config import gcp_settings from axlearn.cloud.gcp.node_pool import PRE_PROVISIONER_LABEL -from axlearn.cloud.gcp.system_characteristics import ( - GCE_MACHINE_TYPE_TO_MEMORY_CHARACTERISTICS, - USER_FACING_NAME_TO_SYSTEM_CHARACTERISTICS, -) +from axlearn.cloud.gcp.system_characteristics import USER_FACING_NAME_TO_SYSTEM_CHARACTERISTICS from axlearn.cloud.gcp.tpu import get_default_env, infer_tpu_workers from axlearn.cloud.gcp.utils import validate_jobset_name from axlearn.common.compiler_options import infer_tpu_type @@ -451,15 +447,17 @@ def _build_container(self) -> Nested[Any]: if cfg.enable_tpu_ici_resiliency is not None: env_vars["ENABLE_ICI_RESILIENCY"] = str(cfg.enable_tpu_ici_resiliency).lower() + env_vars["TPU_LIBRARY_PATH"] = "/root/libtpu.so" + resources = {"limits": {"google.com/tpu": system.chips_per_vm}} - # Set request memory by host machine type. - machine_memory_gi = GCE_MACHINE_TYPE_TO_MEMORY_CHARACTERISTICS.get( - system.gce_machine_type, None - ) - if machine_memory_gi is not None: - request_memory_gi = machine_memory_gi * _MEMORY_REQUEST_PERCENTAGE - resources["limits"]["memory"] = f"{machine_memory_gi}Gi" - resources["requests"] = {"memory": f"{math.floor(request_memory_gi)}Gi"} + # # Set request memory by host machine type. 
+ # machine_memory_gi = GCE_MACHINE_TYPE_TO_MEMORY_CHARACTERISTICS.get( + # system.gce_machine_type, None + # ) + # if machine_memory_gi is not None: + # request_memory_gi = machine_memory_gi * _MEMORY_REQUEST_PERCENTAGE + # resources["limits"]["memory"] = f"{machine_memory_gi}Gi" + # resources["requests"] = {"memory": f"{math.floor(request_memory_gi)}Gi"} k8s_env_vars = [dict(name=k, value=str(v)) for k, v in env_vars.items()] k8s_env_vars.append( @@ -509,8 +507,8 @@ def _build_uploader_container( interval_s = 60 sync_command = f"while true; do gsutil -m rsync -r {src} {dst}; sleep {interval_s}; done" resources = { - "requests": {"cpu": "100m", "memory": "128Mi"}, - "limits": {"cpu": "500m", "memory": "256Mi"}, + # "requests": {"cpu": "100m", "memory": "128Mi"}, + # "limits": {"cpu": "500m", "memory": "256Mi"}, } return dict( name="output-uploader", diff --git a/axlearn/cloud/gcp/measurement.py b/axlearn/cloud/gcp/measurement.py index 0d4ce0069..0eb226e6f 100644 --- a/axlearn/cloud/gcp/measurement.py +++ b/axlearn/cloud/gcp/measurement.py @@ -2,6 +2,9 @@ """Measurement utils for GCP. + For detailed documentation and advanced usage, please refer to: + axlearn/docs/05-Goodput-Monitoring.md + Example: # Enable Goodput when launching an AXLearn training job @@ -13,10 +16,14 @@ --recorder_spec=name=my-run-with-goodput \ --recorder_spec=upload_dir=my-output-directory/summaries \ --recorder_spec=upload_interval=30 \ - --recorder_spec=step_deviation_interval_seconds=30 + --recorder_spec=rolling_window_size=86400,604800 """ +import contextlib +import os +from typing import Optional, Sequence + import jax from absl import flags, logging from ml_goodput_measurement import goodput @@ -38,13 +45,19 @@ class Config(measurement.Recorder.Config): Attributes: upload_dir: Directory to store metrics for the monitor. upload_interval: Time interval (seconds) for monitoring uploads. - step_deviation_interval_seconds: Time interval (seconds) for step deviation metrics - uploads. -1 to disable step deviation uploads. + See "How to Monitor Cumulative Goodput Metrics" in + docs/05-Goodput-Monitoring.md for more details. + rolling_window_size: A sequence of integers defining the rolling window sizes in + seconds. + See "How to Monitor Rolling Window Goodput Metrics" in + docs/05-Goodput-Monitoring.md for more details. + jax_backend: Jax backend type to infer Pathways environment. """ upload_dir: Required[str] = REQUIRED upload_interval: Required[int] = REQUIRED - step_deviation_interval_seconds: int = 30 # Default to 30 seconds + rolling_window_size: Sequence[int] = [] + jax_backend: Optional[str] = None @classmethod def from_flags(cls, fv: flags.FlagValues) -> "GoodputRecorder": @@ -53,68 +66,78 @@ def from_flags(cls, fv: flags.FlagValues) -> "GoodputRecorder": `fv.recorder_spec` will be interpreted as a list of `key=value` pairs; config names corresponding to keys will be set to the corresponding values. A GoodputRecorder can additionally take in following Tensorboard configs in the recorder_spec: - - upload_dir: The directory to write Tensorboard data to. - - upload_interval: The time interval in seconds at which to query and upload data - to Tensorboard. - - step_deviation_interval_seconds: Time interval (seconds) for step deviation metrics - uploads. Set to less than or equal to 0 to disable step deviation uploads. + - upload_dir: The directory to write Tensorboard data to. + - upload_interval: The time interval in seconds at which to query and upload data + to Tensorboard. 
+ - rolling_window_size: Comma-separated list of integers representing rolling window + sizes in seconds. + - jax_backend: The type of jax backend. """ cfg: measurement.Recorder.Config = cls.default_config() - cfg = maybe_set_config(cfg, **parse_kv_flags(fv.recorder_spec, delimiter="=")) - return cfg.instantiate() + parsed_flags = parse_kv_flags(fv.recorder_spec, delimiter="=") + if "upload_interval" in parsed_flags: + parsed_flags["upload_interval"] = int(parsed_flags["upload_interval"]) + if "rolling_window_size" in parsed_flags and isinstance( + parsed_flags["rolling_window_size"], str + ): + parsed_flags["rolling_window_size"] = [ + int(x) for x in parsed_flags["rolling_window_size"].split(",") + ] + return maybe_set_config(cfg, **parsed_flags).instantiate() def __init__(self, cfg): super().__init__(cfg) - cfg: GoodputRecorder.Config = self.config - self._recorder = None - self._monitor = None - - def record(self, event: measurement.Event, *args, **kwargs): - # Lazily instantiate the recorder. This avoids invoking jax before setup is complete. + self._recorder: Optional[goodput.GoodputRecorder] = None + self._monitor: Optional[goodput_monitoring.GoodputMonitor] = None + self._rolling_window_monitor: Optional[goodput_monitoring.GoodputMonitor] = None + self._job_name = cfg.name + self._logger_name = f"goodput_logger_{cfg.name}" + + @contextlib.contextmanager + def record_event(self, event: measurement.Event, *args, **kwargs): + """Records a goodput event using a context manager.""" + # Lazily instantiate the recorder if it hasn't been already. if self._recorder is None: - cfg: GoodputRecorder.Config = self.config + if jax.process_index() == 0: + logging.info("Lazily instantiating goodput recorder.") self._recorder = goodput.GoodputRecorder( - job_name=cfg.name, - logger_name=f"goodput_logger_{cfg.name}", + job_name=self._job_name, + logger_name=self._logger_name, logging_enabled=(jax.process_index() == 0), ) - if event == measurement.Event.START_JOB: - self._recorder.record_job_start_time(*args, **kwargs) - elif event == measurement.Event.END_JOB: - self._recorder.record_job_end_time(*args, **kwargs) - elif event == measurement.Event.START_STEP: - self._recorder.record_step_start_time(*args, **kwargs) - elif event == measurement.Event.START_ACCELERATOR_INIT: - self._recorder.record_tpu_init_start_time(*args, **kwargs) - elif event == measurement.Event.END_ACCELERATOR_INIT: - self._recorder.record_tpu_init_end_time(*args, **kwargs) - elif event == measurement.Event.START_TRAINING_PREPARATION: - self._recorder.record_training_preparation_start_time(*args, **kwargs) - elif event == measurement.Event.END_TRAINING_PREPARATION: - self._recorder.record_training_preparation_end_time(*args, **kwargs) - elif event == measurement.Event.START_DATA_LOADING: - self._recorder.record_data_loading_start_time(*args, **kwargs) - elif event == measurement.Event.END_DATA_LOADING: - self._recorder.record_data_loading_end_time(*args, **kwargs) - elif event == measurement.Event.START_CUSTOM_BADPUT_EVENT: - self._recorder.record_custom_badput_event_start_time(*args, **kwargs) - elif event == measurement.Event.END_CUSTOM_BADPUT_EVENT: - self._recorder.record_custom_badput_event_end_time(*args, **kwargs) - else: - logging.log_first_n( - logging.WARNING, - "Ignoring unknown event %s", - 1, - event, + start_method_name = f"record_{event.value}_start_time" + end_method_name = f"record_{event.value}_end_time" + + record_event_start = getattr(self._recorder, start_method_name, None) + record_event_end = 
getattr(self._recorder, end_method_name, None) + + try: + if record_event_start: + record_event_start(*args, **kwargs) + except RuntimeError as e: + logging.warning( + "Failed to record start of event %s. Error: %s", event.value, e, exc_info=True ) - def start_monitoring(self, *args, **kwargs): - """Starts Monitoring of Goodput. + try: + yield + finally: + try: + if record_event_end: + record_event_end(*args, **kwargs) + except RuntimeError as e: + logging.warning( + "Failed to record end of event %s. Error: %s", event.value, e, exc_info=True + ) + + @contextlib.contextmanager + def _maybe_monitor_goodput(self, *args, **kwargs): + """Monitor cumulative goodput if enabled. Instantiate ml-goodput-measurement's GoodputMonitor to asynchronously calculate - Goodput and Badput at the upload_interval and upload to the specified TensorBoard - directory. + Goodput, Badput, Step & Disruption Information at the upload_interval to the + specified TensorBoard directory and Google Cloud Monitoring. Note: This function requires initialization of distributed JAX before it is called. If there are internal GCP errors from querying and uploading data, these will be logged without affecting the workload. GoodputMonitor logs will provide further @@ -123,33 +146,68 @@ def start_monitoring(self, *args, **kwargs): Default behavior is to push metrics to Google Cloud Monitoring. This behavior can be overridden by configuring `goodput_monitoring.GCPOptions` """ - cfg: GoodputRecorder.Config = self.config - include_step_deviation = True - if jax.process_index() == 0: + if jax.process_index() != 0: + yield + return + try: if self._monitor is None: - if int(cfg.step_deviation_interval_seconds) <= 0: - include_step_deviation = False - - gcp_options = goodput_monitoring.GCPOptions( - enable_gcp_goodput_metrics=True, - enable_gcp_step_deviation_metrics=include_step_deviation, - ) self._monitor = goodput_monitoring.GoodputMonitor( - job_name=cfg.name, - logger_name=f"goodput_logger_{cfg.name}", - tensorboard_dir=cfg.upload_dir, - upload_interval=int(cfg.upload_interval), + job_name=self._job_name, + logger_name=self._logger_name, + tensorboard_dir=self.config.upload_dir, + upload_interval=self.config.upload_interval, monitoring_enabled=True, + pathway_enabled=self.config.jax_backend == "proxy", include_badput_breakdown=True, - include_step_deviation=include_step_deviation, - step_deviation_interval_seconds=int(cfg.step_deviation_interval_seconds), - gcp_options=gcp_options, ) self._monitor.start_goodput_uploader(*args, **kwargs) logging.info("Started Goodput upload to Tensorboard & GCM in the background!") - if include_step_deviation: - self._monitor.start_step_deviation_uploader(*args, **kwargs) + yield + finally: + if self._monitor: + self._monitor.stop_goodput_uploader() + logging.info("Flushed final metrics and safe exited from Goodput monitoring.") + + @contextlib.contextmanager + def _maybe_monitor_rolling_window_goodput(self): + """Monitor rolling window goodput if enabled.""" + if not self.config.rolling_window_size or jax.process_index() != 0: + yield + return + try: + if self._rolling_window_monitor is None: + rolling_window_tensorboard_dir = os.path.join( + self.config.upload_dir, f"rolling_window_{self.config.name}" + ) + self._rolling_window_monitor = goodput_monitoring.GoodputMonitor( + job_name=self._job_name, + logger_name=self._logger_name, + tensorboard_dir=rolling_window_tensorboard_dir, + upload_interval=self.config.upload_interval, + monitoring_enabled=True, + pathway_enabled=self.config.jax_backend == 
"proxy", + include_badput_breakdown=True, + ) + self._rolling_window_monitor.start_rolling_window_goodput_uploader( + self.config.rolling_window_size + ) + logging.info("Started Rolling Window Goodput monitoring in the background!") + yield + finally: + if self._rolling_window_monitor: + self._rolling_window_monitor.stop_rolling_window_goodput_uploader() logging.info( - "Started Step Deviation upload to Tensorboard & GCM in the background!" + "Flushed final metrics and safe exited from Rolling Window Goodput monitoring." ) + + def maybe_monitor_all_goodput(self): + goodput_monitor_manager = self._maybe_monitor_goodput() + rolling_goodput_monitor_manager = self._maybe_monitor_rolling_window_goodput() + + @contextlib.contextmanager + def monitor_goodput(): + with goodput_monitor_manager, rolling_goodput_monitor_manager: + yield + + return monitor_goodput() diff --git a/axlearn/cloud/gcp/measurement_test.py b/axlearn/cloud/gcp/measurement_test.py index e14fc16c4..e944a262c 100644 --- a/axlearn/cloud/gcp/measurement_test.py +++ b/axlearn/cloud/gcp/measurement_test.py @@ -3,191 +3,373 @@ """Tests measurement utils for GCP.""" # pylint: disable=protected-access -import contextlib from unittest import mock -from absl import flags +from absl import flags, logging from absl.testing import parameterized from axlearn.cloud.gcp.measurement import GoodputRecorder from axlearn.common import measurement +from axlearn.common.config import RequiredFieldMissingError class GoodputRecorderTest(parameterized.TestCase): """Tests GoodputRecorder.""" @parameterized.parameters( - (None,), (["name=test-name", "upload_dir=/test/path/to/upload", "upload_interval=15"],) - ) - def test_from_flags(self, spec): - fv = flags.FlagValues() - measurement.define_flags(flag_values=fv) - if spec is not None: - fv.set_default("recorder_spec", spec) - fv.mark_as_parsed() - - if spec is None: - ctx = self.assertRaisesRegex(ValueError, "name") - else: - ctx = contextlib.nullcontext() - - with ctx: - recorder = GoodputRecorder.from_flags(fv) - # Recorder is not instantiated until first event. 
- self.assertIsNone(recorder._recorder) - - def test_record_and_monitor(self): - fv = flags.FlagValues() - measurement.define_flags(flag_values=fv) - fv.set_default( - "recorder_spec", - ["name=test-name", "upload_dir=/test/path/to/upload", "upload_interval=15"], - ) - fv.mark_as_parsed() - - recorder = GoodputRecorder.from_flags(fv) - recorder._recorder = mock.MagicMock() - recorder.record(measurement.Event.START_JOB) - self.assertTrue(recorder._recorder.record_job_start_time.called) - - def test_start_goodput_monitoring(self): - fv = flags.FlagValues() - measurement.define_flags(flag_values=fv) - fv.set_default( - "recorder_spec", - [ + dict( + recorder_spec=[ "name=test-name", - "upload_dir=/test/path/to/upload", + "upload_dir=/test/path", "upload_interval=15", - "step_deviation_interval_seconds=-1", ], - ) - fv.mark_as_parsed() - - recorder = GoodputRecorder.from_flags(fv) - self.assertIsNone(recorder._monitor) # Ensure _monitor is initially None - - with mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor") as mock_goodput_monitor: - with mock.patch("ml_goodput_measurement.monitoring.GCPOptions") as mock_gcp_options: - mock_monitor_instance = mock_goodput_monitor.return_value - recorder.start_monitoring() - mock_gcp_options.assert_called_once_with( - enable_gcp_goodput_metrics=True, - enable_gcp_step_deviation_metrics=False, - ) - mock_gcp_options_instance = mock_gcp_options.return_value - - # Check that GoodputMonitor was instantiated - mock_goodput_monitor.assert_called_once_with( - job_name="test-name", - logger_name="goodput_logger_test-name", - tensorboard_dir="/test/path/to/upload", - upload_interval=15, - monitoring_enabled=True, - include_badput_breakdown=True, - include_step_deviation=False, - step_deviation_interval_seconds=-1, - gcp_options=mock_gcp_options_instance, - ) - - # Ensure that start_goodput_uploader is called on the monitor instance - mock_monitor_instance.start_goodput_uploader.assert_called_once() - self.assertIsNotNone(recorder._monitor) - - def test_start_goodput_and_step_deviation_monitoring(self): - fv = flags.FlagValues() - measurement.define_flags(flag_values=fv) - fv.set_default( - "recorder_spec", - [ + expected_rolling_window_size=[], + expected_jax_backend=None, + ), + dict( + recorder_spec=[ "name=test-name", - "upload_dir=/test/path/to/upload", + "upload_dir=/test/path", "upload_interval=15", - "step_deviation_interval_seconds=30", + "rolling_window_size=1,2,3", + "jax_backend=proxy", ], + expected_rolling_window_size=[1, 2, 3], + expected_jax_backend="proxy", + ), + ) + def test_from_flags( + self, + recorder_spec, + expected_rolling_window_size, + expected_jax_backend, + ): + """Tests that flags are correctly parsed into the config.""" + mock_fv = mock.MagicMock(spec=flags.FlagValues) + mock_fv.recorder_spec = recorder_spec + mock_fv.jax_backend = "tpu" + + recorder = GoodputRecorder.from_flags(mock_fv) + + self.assertEqual("test-name", recorder.config.name) + self.assertEqual("/test/path", recorder.config.upload_dir) + self.assertEqual(15, recorder.config.upload_interval) + self.assertEqual(expected_rolling_window_size, recorder.config.rolling_window_size) + self.assertEqual(expected_jax_backend, recorder.config.jax_backend) + + def test_from_flags_missing_required(self): + """Tests that missing required flags raise an error.""" + mock_fv = mock.MagicMock(spec=flags.FlagValues) + mock_fv.recorder_spec = ["name=test-name"] # Missing upload_dir/interval + mock_fv.jax_backend = "tpu" + with self.assertRaisesRegex(RequiredFieldMissingError, 
"upload_dir"): + GoodputRecorder.from_flags(mock_fv) + + @parameterized.parameters( + dict( + event=measurement.Event.JOB, + expected_start="record_job_start_time", + expected_end="record_job_end_time", + args=(), + kwargs={}, + expect_end_call=True, + ), + dict( + event=measurement.Event.STEP, + expected_start="record_step_start_time", + expected_end=None, + args=(123,), + kwargs={}, + expect_end_call=False, + ), + dict( + event=measurement.Event.ACCELERATOR_INIT, + expected_start="record_tpu_init_start_time", + expected_end="record_tpu_init_end_time", + args=(), + kwargs={}, + expect_end_call=True, + ), + dict( + event=measurement.Event.TRAINING_PREPARATION, + expected_start="record_training_preparation_start_time", + expected_end="record_training_preparation_end_time", + args=(), + kwargs={}, + expect_end_call=True, + ), + dict( + event=measurement.Event.DATA_LOADING, + expected_start="record_data_loading_start_time", + expected_end="record_data_loading_end_time", + args=(), + kwargs={}, + expect_end_call=True, + ), + dict( + event=measurement.Event.CUSTOM_BADPUT_EVENT, + expected_start="record_custom_badput_event_start_time", + expected_end="record_custom_badput_event_end_time", + args=(), + kwargs={"custom_badput_event_type": "TEST_TYPE"}, + expect_end_call=True, + ), + ) + @mock.patch("jax.process_index", return_value=0) + def test_record_event_context_manager_success( + self, _, event, expected_start, expected_end, args, kwargs, expect_end_call + ): + """Tests that record_event calls correct start and end methods with args and kwargs.""" + cfg = GoodputRecorder.default_config().set( + name="test", + upload_dir="/tmp/test", + upload_interval=1, ) - fv.mark_as_parsed() - - recorder = GoodputRecorder.from_flags(fv) - self.assertIsNone(recorder._monitor) # Ensure _monitor is initially None - - with mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor") as mock_goodput_monitor: - with mock.patch("ml_goodput_measurement.monitoring.GCPOptions") as mock_gcp_options: - mock_monitor_instance = mock_goodput_monitor.return_value - recorder.start_monitoring() - mock_gcp_options.assert_called_once_with( - enable_gcp_goodput_metrics=True, - enable_gcp_step_deviation_metrics=True, - ) - mock_gcp_options_instance = mock_gcp_options.return_value - - # Check that GoodputMonitor was instantiated - mock_goodput_monitor.assert_called_once_with( - job_name="test-name", - logger_name="goodput_logger_test-name", - tensorboard_dir="/test/path/to/upload", - upload_interval=15, - monitoring_enabled=True, - include_badput_breakdown=True, - include_step_deviation=True, - step_deviation_interval_seconds=30, - gcp_options=mock_gcp_options_instance, - ) + recorder = GoodputRecorder(cfg) - # Ensure that start_goodput_uploader and start_step_deviation_uploader is called on - # the monitor instance - mock_monitor_instance.start_goodput_uploader.assert_called_once() - mock_monitor_instance.start_step_deviation_uploader.assert_called_once() - self.assertIsNotNone(recorder._monitor) - - def test_missing_required_flags(self): - fv = flags.FlagValues() - measurement.define_flags(flag_values=fv) - # Missing 'upload_dir' and 'upload_interval' from recorder_spec - fv.set_default("recorder_spec", ["name=test-name"]) # Incomplete config - fv.mark_as_parsed() - - # Expecting ValueError since 'upload_dir' and 'upload_interval' are required - with self.assertRaises(ValueError): - GoodputRecorder.from_flags(fv) - - def test_monitoring_initialization_failure(self): - fv = flags.FlagValues() - 
measurement.define_flags(flag_values=fv) - fv.set_default( - "recorder_spec", - ["name=test-name", "upload_dir=/test/path/to/upload", "upload_interval=15"], + with mock.patch("ml_goodput_measurement.goodput.GoodputRecorder") as mock_recorder_cls: + mock_instance = mock_recorder_cls.return_value + + start_mock = mock.MagicMock() + setattr(mock_instance, expected_start, start_mock) + if expect_end_call and expected_end: + end_mock = mock.MagicMock() + setattr(mock_instance, expected_end, end_mock) + + with recorder.record_event(event, *args, **kwargs): + pass + + mock_recorder_cls.assert_called_once() + start_mock.assert_called_once_with(*args, **kwargs) + if expect_end_call and expected_end: + end_mock.assert_called_once_with(*args, **kwargs) + + def test_record_event_context_manager_handles_runtime_error(self): + cfg = GoodputRecorder.default_config().set( + name="test", + upload_dir="/tmp/test", + upload_interval=1, + ) + recorder = GoodputRecorder(cfg) + + with mock.patch("jax.process_index", return_value=0): + with mock.patch( + "ml_goodput_measurement.goodput.GoodputRecorder" + ) as mock_recorder_cls, mock.patch.object(logging, "warning") as mock_warning: + mock_instance = mock_recorder_cls.return_value + + def raise_runtime_error(*args, **kwargs): + raise RuntimeError("mocked error") + + mock_instance.record_job_start_time.side_effect = raise_runtime_error + mock_instance.record_job_end_time.side_effect = raise_runtime_error + # Should not crash here. + with recorder.record_event(measurement.Event.JOB): + pass + + # Assert warnings were logged for start and end failures + assert mock_warning.call_count == 2 + start_call = mock_warning.call_args_list[0] + end_call = mock_warning.call_args_list[1] + + assert "Failed to record" in start_call.args[0] + assert "Failed to record" in end_call.args[0] + + @parameterized.parameters( + dict(is_pathways_job=False, mock_jax_backend="tpu"), + dict(is_pathways_job=True, mock_jax_backend="proxy"), + dict(is_pathways_job=False, mock_jax_backend=None), + ) + @mock.patch("jax.process_index", return_value=0) + def test_maybe_monitor_goodput(self, _, is_pathways_job, mock_jax_backend): + """Tests the _maybe_monitor_goodput context manager.""" + cfg = GoodputRecorder.default_config().set( + name="test-monitor", + upload_dir="/test", + upload_interval=30, + jax_backend=mock_jax_backend, ) - fv.mark_as_parsed() - - recorder = GoodputRecorder.from_flags(fv) - self.assertIsNone(recorder._monitor) - - # Mock a failure in initializing the GoodputMonitor - with mock.patch( - "ml_goodput_measurement.monitoring.GoodputMonitor", - side_effect=Exception("Failed to initialize GoodputMonitor"), - ): - with self.assertRaises(Exception): - recorder.start_monitoring() - self.assertIsNone(recorder._monitor) - - def test_non_zero_process_index(self): - fv = flags.FlagValues() - measurement.define_flags(flag_values=fv) - fv.set_default( - "recorder_spec", - ["name=test-name", "upload_dir=/test/path/to/upload", "upload_interval=15"], + recorder = GoodputRecorder(cfg) + + with mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor") as mock_monitor_cls: + mock_monitor_instance = mock_monitor_cls.return_value + with recorder._maybe_monitor_goodput(): + pass + + # Verify that GoodputMonitor was instantiated with the correct parameters. 
+ mock_monitor_cls.assert_called_once_with( + job_name="test-monitor", + logger_name="goodput_logger_test-monitor", + tensorboard_dir="/test", + upload_interval=30, + monitoring_enabled=True, + pathway_enabled=is_pathways_job, + include_badput_breakdown=True, + ) + mock_monitor_instance.start_goodput_uploader.assert_called_once() + mock_monitor_instance.stop_goodput_uploader.assert_called_once() + + @parameterized.parameters( + dict( + is_rolling_window_enabled=True, + rolling_window_size=[10, 20], + is_pathways_job=False, + mock_jax_backend="tpu", + ), + dict( + is_rolling_window_enabled=False, + rolling_window_size=[], + is_pathways_job=False, + mock_jax_backend="tpu", + ), + dict( + is_rolling_window_enabled=True, + rolling_window_size=[50], + is_pathways_job=True, + mock_jax_backend="proxy", + ), + ) + @mock.patch("jax.process_index", return_value=0) + def test_maybe_monitor_rolling_window( + self, + mock_process_index, + is_rolling_window_enabled, + rolling_window_size, + is_pathways_job, + mock_jax_backend, + ): # pylint: disable=unused-argument + """Tests the rolling window monitoring.""" + cfg = GoodputRecorder.default_config().set( + name="test-rolling", + upload_dir="/test", + upload_interval=30, + rolling_window_size=rolling_window_size, + jax_backend=mock_jax_backend, ) - fv.mark_as_parsed() + recorder = GoodputRecorder(cfg) - recorder = GoodputRecorder.from_flags(fv) - self.assertIsNone(recorder._monitor) + with mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor") as mock_monitor_cls: + mock_monitor_instance = mock_monitor_cls.return_value + if not is_rolling_window_enabled: + with recorder._maybe_monitor_rolling_window_goodput(): + pass + mock_monitor_cls.assert_not_called() + return + with recorder._maybe_monitor_rolling_window_goodput(): + pass - with mock.patch("jax.process_index") as mock_process_index: - mock_process_index.return_value = 1 # Simulate a non-zero process index + mock_monitor_cls.assert_called_once_with( + job_name="test-rolling", + logger_name="goodput_logger_test-rolling", + tensorboard_dir="/test/rolling_window_test-rolling", + upload_interval=30, + monitoring_enabled=True, + pathway_enabled=is_pathways_job, + include_badput_breakdown=True, + ) + + mock_monitor_instance.start_rolling_window_goodput_uploader.assert_called_with( + rolling_window_size + ) + mock_monitor_instance.stop_rolling_window_goodput_uploader.assert_called_once() + + @mock.patch("jax.process_index", return_value=1) + def test_non_zero_process_index_skips_monitoring( + self, mock_process_index + ): # pylint: disable=unused-argument + """Tests that monitoring is skipped on non-zero process indices.""" + cfg = GoodputRecorder.default_config().set( + name="test", upload_dir="/test", upload_interval=30 + ) + recorder = GoodputRecorder(cfg) - try: - recorder.start_monitoring() - except AttributeError: - self.fail("AttributeError was raised unexpectedly.") + with mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor") as mock_monitor_cls: + # Test cumulative goodput monitoring. 
+ with recorder._maybe_monitor_goodput(): + pass + mock_monitor_cls.assert_not_called() + + cfg_rolling = GoodputRecorder.default_config().set( + name="test-rolling-skip", + upload_dir="/test", + upload_interval=30, + rolling_window_size=[10, 20], + ) + recorder_rolling = GoodputRecorder(cfg_rolling) + with recorder_rolling._maybe_monitor_rolling_window_goodput(): + pass + mock_monitor_cls.assert_not_called() + + @parameterized.parameters( + dict( + rolling_window_size=[5, 10], + jax_backend="tpu", + expected_monitor_calls=2, # Cumulative & Rolling Window + expect_rolling=True, + expect_cumulative=True, + ), + dict( + rolling_window_size=[], + jax_backend="tpu", + expected_monitor_calls=1, # Cumulative only + expect_rolling=False, + expect_cumulative=True, + ), + dict( + rolling_window_size=[5, 10], + jax_backend=None, # Disables Pathways + expected_monitor_calls=2, + expect_rolling=True, + expect_cumulative=True, + ), + dict( + rolling_window_size=[], + jax_backend=None, + expected_monitor_calls=1, + expect_rolling=False, + expect_cumulative=True, + ), + ) + @mock.patch("jax.process_index", return_value=0) + def test_maybe_monitor_all_goodput( + self, + _, + rolling_window_size, + jax_backend, + expected_monitor_calls, + expect_rolling, + expect_cumulative, + ): + """Tests all goodput monitoring with various configs.""" + cfg = GoodputRecorder.default_config().set( + name="test-all", + upload_dir="/test", + upload_interval=30, + rolling_window_size=rolling_window_size, + jax_backend=jax_backend, + ) + recorder = GoodputRecorder(cfg) + + with mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor") as mock_monitor_cls: + mock_monitor_instance = mock_monitor_cls.return_value + + with recorder.maybe_monitor_all_goodput(): + pass + + self.assertEqual(mock_monitor_cls.call_count, expected_monitor_calls) + + if expect_cumulative: + mock_monitor_instance.start_goodput_uploader.assert_called_once() + mock_monitor_instance.stop_goodput_uploader.assert_called_once() + else: + mock_monitor_instance.start_goodput_uploader.assert_not_called() + mock_monitor_instance.stop_goodput_uploader.assert_not_called() + + if expect_rolling: + mock_monitor_instance.start_rolling_window_goodput_uploader.assert_called_once_with( + rolling_window_size + ) + mock_monitor_instance.stop_rolling_window_goodput_uploader.assert_called_once() + else: + mock_monitor_instance.start_rolling_window_goodput_uploader.assert_not_called() + mock_monitor_instance.stop_rolling_window_goodput_uploader.assert_not_called() diff --git a/axlearn/cloud/gcp/tpu.py b/axlearn/cloud/gcp/tpu.py index 85e1de028..77e043be3 100644 --- a/axlearn/cloud/gcp/tpu.py +++ b/axlearn/cloud/gcp/tpu.py @@ -13,12 +13,14 @@ def get_default_env(*, tpu_type: str, num_tpu_slices: int, job_name: str) -> dict[str, Any]: """Gets the default environment for TPU pods.""" + del job_name # Unused. return dict( # Use a large refresh to mitigate DNS timeout issues until tf>2.12 upgrade. GCS_RESOLVE_REFRESH_SECS=600, TPU_TYPE=tpu_type, NUM_TPU_SLICES=num_tpu_slices, - XLA_FLAGS=f"--xla_dump_to=/output/{job_name}/xla", + XLA_FLAGS="", + # XLA_FLAGS=f"--xla_dump_to=/output/{job_name}/xla", TF_CPP_MIN_LOG_LEVEL=0, # Necessary for surfacing FATAL TPU errors. 
TPU_STDERR_LOG_LEVEL=0, diff --git a/axlearn/common/array_serialization.py b/axlearn/common/array_serialization.py index 9ba0bbf81..2b8ee3026 100644 --- a/axlearn/common/array_serialization.py +++ b/axlearn/common/array_serialization.py @@ -559,21 +559,34 @@ def serialize( # pylint: disable-next=redefined-outer-name async def _run_serializer(): + logging.info( + "******* DEBUG GlobalAsyncCheckpointManager _run_serializer " + "with number of commit_futures: %s", + len(commit_futures), + ) future_writer = jax.tree.map( serialization.async_serialize, arrays, tensorstore_specs, commit_futures ) + logging.info("******* DEBUG GlobalAsyncCheckpointManager _run_serializer Completed") return await asyncio.gather(*future_writer) + # Is this the problem? + logging.info("******* DEBUG Starting to run _run_serializer") + # Note: We need to run the coroutine in another event loop driven by a separate thread. # The current event loop might be already running an async function when `serialize` is # invoked from a coroutine, in which case asyncio.get_running_loop().run_until_complete() # would not be able to execute another coroutine to completion. asyncio.run_coroutine_threadsafe(_run_serializer(), self._loop).result() + logging.info("******* DEBUG Starting to run _run_serializer") + self._add_futures( jax.tree_util.tree_flatten(commit_futures)[0] + (additional_futures or []) ) + logging.info("******* DEBUG Starting to run async_commit") + # Used in wait_until_finished to check on process != 0, if the checkpoint # has finished writing. self._start_async_commit(on_commit_callback) diff --git a/axlearn/common/checkpointer.py b/axlearn/common/checkpointer.py index 358f43037..452e7a06e 100644 --- a/axlearn/common/checkpointer.py +++ b/axlearn/common/checkpointer.py @@ -175,11 +175,14 @@ def async_save_tf_savables( When this call returns, `value_map` can be safely mutated, but saving to `dir` will not complete unless the returned future is set. """ + logging.info("******* DEBUG Saving TF savables to %s async", dir) # pylint: disable-next=consider-using-with f = tempfile.TemporaryDirectory() for path, value in utils.flatten_items(value_map): tf_checkpoint = tf.train.Checkpoint(value) + logging.info("******* DEBUG Writing %s to path %s", f.name, path) tf_checkpoint.write(os.path.join(f.name, path)) + logging.info("******* DEBUG Done writing %s to path %s", f.name, path) return executor.submit(_upload_dir, f, dst_dir=dir) @@ -399,6 +402,7 @@ def __init__(self, cfg: Config): # TODO(markblee): Consider making BoundedDataShardedAsyncCheckpointManager # the default once stable. if cfg.max_concurrent_gb is not None or cfg.max_data_shard_degree: + logging.info("******* DEBUG Using BoundedDataShardedAsyncCheckpointManager") self._manager = BoundedDataShardedAsyncCheckpointManager( max_concurrent_gb=cfg.max_concurrent_gb, timeout_secs=cfg.timeout_secs, @@ -411,6 +415,7 @@ def __init__(self, cfg: Config): f"shard_threshold_bytes is set to {cfg.shard_threshold_bytes}, but " "max_data_shard_degree is not set. It will not take any effect." 
) + logging.info("******* DEBUG Using GlobalAsyncCheckpointManager") self._manager = GlobalAsyncCheckpointManager(timeout_secs=cfg.timeout_secs) if cfg.max_concurrent_restore_gb is not None and cfg.max_concurrent_restore_gb <= 0: raise ValueError( @@ -514,8 +519,12 @@ def save_to_dir( logging.info("Creating directories: %s", dirs) list(self._executor.map(fs.makedirs, dirs)) logging.info("All directories created") + + logging.info("******* DEBUG starting sync_global_devices") # Wait for directory and index creation. multihost_utils.sync_global_devices(ckpt_dir) + logging.info("******* DEBUG finished sync_global_devices") + # Each worker writes its tf checkpoints under a different path. save_tf_future = async_save_tf_savables( spec.tf_ckpt_map, @@ -527,6 +536,7 @@ def save_to_dir( ) def commit(): + logging.info("******* DEBUG starting on_commit_callback") on_commit_callback(ckpt_dir=ckpt_dir, index=spec.index) logging.info( "Serialization of %s completed in %s seconds.", @@ -538,6 +548,9 @@ def commit(): logging.debug( "array_values=%s tensorstore=%s", utils.shapes(spec.gda_values), spec.tensorstore_specs ) + logging.info( + "array_values=%s tensorstore=%s", utils.shapes(spec.gda_values), spec.tensorstore_specs + ) self._manager.serialize( spec.gda_values, spec.tensorstore_specs, diff --git a/axlearn/common/compiler_options.py b/axlearn/common/compiler_options.py index 06239e54b..a8c48532b 100644 --- a/axlearn/common/compiler_options.py +++ b/axlearn/common/compiler_options.py @@ -59,6 +59,12 @@ def default_xla_options( # further if you see "Allocator failed to allocate". A feature # to dynamically allocate may come later: b/380514965 megascale_grpc_premap_memory_bytes=17179869184, + # DEBUGGING ONLY: RapidEye output directory for debugging purposes, + megascale_rapid_eye_error_digest_log_path="/output/rapideye/", + # megascale_jax_offset_launch_id_by_module_name="false", + # megascale_jax_use_device_set_based_launch_id="false", + # enable megascale debug port. + megascale_debug_port=8081, # Flag controlling the maximum number of overlapping host offloadings. xla_tpu_host_transfer_overlap_limit=24, # Flag controlling the maximum number of overlapping cross-DCN send/recv. @@ -149,12 +155,20 @@ def default_xla_options( # Similar to megascale_error_reporter_abort_on_hang but for unrecoverable errors. megascale_error_reporter_abort_on_error="true", # Increase the timeout at which a hang is detected/reported, default is 5m. - megascale_graph_hang_threshold="10m", + megascale_graph_hang_threshold="60m", # Similar to megascale_graph_hang_threshold but specific to within a launch_id. # Default is 1m. - megascale_graph_within_launch_hang_threshold="10m", + megascale_graph_within_launch_hang_threshold="60m", # TODO(ethanli): temporary workaround to avoid memory leak in megascale. megascale_grpc_enable_xor_tracer="false", + # # The duration of missing heartbeats before shutting down. + # jax_heartbeat_timeout="100s", + # # JAX gRPC timeout duration. + # jax_rpc_timeout="120s", + # # JAX distributed initialization timeout. + # jax_distributed_initialization_timeout="3600s", + # # JAX shutdown timeout duration + # jax_distributed_shutdown_timeout="5m", ) # Validate options. Will never fail if this function is implemented correctly. 
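# --- Illustrative aside (not part of the patch): the hunk below widens the
# allow-list used when validating non-integer option values added above
# ("60m", "/output/rapideye/", etc.). A minimal standalone sketch of that
# validation pattern, using a hypothetical `options` dict in place of the real
# default_xla_options() output:
options = {"megascale_graph_hang_threshold": "60m", "megascale_debug_port": 8081}
allowed = {True, False, "true", "false", "megachip_tccontrol", "10m", "60m",
           "100s", "120s", "3600s", "5m", "/output/rapideye/"}
for k, v in options.items():
    try:
        int(v)  # integer-valued options (ports, byte counts, ...) always pass
        continue
    except ValueError:
        assert v in allowed, (k, v)  # everything else must be explicitly allow-listed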
@@ -163,7 +177,20 @@ def default_xla_options( int(v) continue except ValueError: - assert v in [True, False, "true", "false", "megachip_tccontrol", "10m"], (k, v) + assert v in [ + True, + False, + "true", + "false", + "megachip_tccontrol", + "10m", + "60m", + "100s", + "120s", + "3600s", + "5m", + "/output/rapideye/", + ], (k, v) return options @@ -302,6 +329,7 @@ def infer_xla_performance_flags( if current_configuration in mesh_configurations_for_sparse_core_offloading: flags = dict( # Must disable continuation fusion to enable sparse core offloading. + # AXLEARN TESTING NOTE: We are disabling this to test for SparseCore related issues. xla_tpu_enable_async_collective_fusion_fuse_all_gather="false", xla_tpu_enable_async_collective_fusion_fuse_all_reduce="false", xla_tpu_enable_async_collective_fusion_fuse_reduce_scatter="false", diff --git a/axlearn/common/launch_trainer.py b/axlearn/common/launch_trainer.py index bba28533e..7470ad66c 100644 --- a/axlearn/common/launch_trainer.py +++ b/axlearn/common/launch_trainer.py @@ -2,6 +2,7 @@ """Utilities to launch a trainer.""" +import contextlib import json import os from typing import Any, Optional @@ -128,8 +129,8 @@ def get_trainer_config( return trainer_config -def run_trainer(trainer_config: SpmdTrainer.Config) -> Any: - measurement.record_event(measurement.Event.START_JOB) +def _run_trainer_impl(trainer_config: SpmdTrainer.Config) -> Any: + """Instantiates and runs the trainer.""" trainer_config_debug_string = trainer_config.debug_string() logging.info("Trainer config:\n%s", trainer_config_debug_string) if jax.process_index() == 0: @@ -149,6 +150,13 @@ def run_trainer(trainer_config: SpmdTrainer.Config) -> Any: trainer: SpmdTrainer = trainer_config.instantiate(parent=None) prng_key = jax.random.PRNGKey(seed=FLAGS.trainer_prng_seed) - output = trainer.run(prng_key) - measurement.record_event(measurement.Event.END_JOB) - return output + return trainer.run(prng_key) + + +def run_trainer(trainer_config: SpmdTrainer.Config) -> Any: + recorder = measurement.global_recorder + job_events_manager = ( + recorder.record_event(measurement.Event.JOB) if recorder else contextlib.nullcontext() + ) + with job_events_manager: + return _run_trainer_impl(trainer_config) diff --git a/axlearn/common/launch_trainer_main.py b/axlearn/common/launch_trainer_main.py index 2f617b4cd..8d170a950 100644 --- a/axlearn/common/launch_trainer_main.py +++ b/axlearn/common/launch_trainer_main.py @@ -13,7 +13,6 @@ def main(_): launch.setup() trainer_config = launch_trainer.get_trainer_config() trainer_config.set(recorder=config_for_function(lambda: measurement.global_recorder)) - measurement.start_monitoring() launch_trainer.run_trainer(trainer_config) diff --git a/axlearn/common/measurement.py b/axlearn/common/measurement.py index b0a40a85f..1d2a9dea7 100644 --- a/axlearn/common/measurement.py +++ b/axlearn/common/measurement.py @@ -2,6 +2,7 @@ """A library to measure e2e metrics like goodput.""" +import contextlib import enum import importlib from typing import Optional, TypeVar @@ -15,30 +16,20 @@ class Event(enum.Enum): """Event to be recorded. Attributes: - START_JOB: Start of job. - END_JOB: End of job. - START_STEP: Start of a training step. Should be recorded with `step` as a positional arg. - START_ACCELERATOR_INIT: Start of accelerator mesh initialization. - END_ACCELERATOR_INIT: End of accelerator mesh initialization. - START_TRAINING_PREPARATION: Start of training preparation. - END_TRAINING_PREPARATION: End of training preparation. 
- START_DATA_LOADING: Start of data loading. - END_DATA_LOADING: End of data loading. - START_CUSTOM_BADPUT_EVENT: Start of custom badput event. - END_CUSTOM_BADPUT_EVENT: End of custom badput event. + JOB: Start and end of the job. + STEP: Start of a training step. Should be recorded with `step` as a positional arg. + ACCELERATOR_INIT: Start and end of accelerator mesh initialization. + TRAINING_PREPARATION: Start and end of training preparation. + DATA_LOADING: Start and end of data loading. + CUSTOM_BADPUT_EVENT: Start and end of custom badput events. """ - START_JOB = "START_JOB" - END_JOB = "END_JOB" - START_STEP = "START_STEP" - START_ACCELERATOR_INIT = "START_ACCELERATOR_INIT" - END_ACCELERATOR_INIT = "END_ACCELERATOR_INIT" - START_TRAINING_PREPARATION = "START_TRAINING_PREPARATION" - END_TRAINING_PREPARATION = "END_TRAINING_PREPARATION" - START_DATA_LOADING = "START_DATA_LOADING" - END_DATA_LOADING = "END_DATA_LOADING" - START_CUSTOM_BADPUT_EVENT = "START_CUSTOM_BADPUT_EVENT" - END_CUSTOM_BADPUT_EVENT = "END_CUSTOM_BADPUT_EVENT" + JOB = "job" + STEP = "step" + ACCELERATOR_INIT = "tpu_init" + TRAINING_PREPARATION = "training_preparation" + DATA_LOADING = "data_loading" + CUSTOM_BADPUT_EVENT = "custom_badput_event" class Recorder(Configurable): @@ -59,9 +50,15 @@ def from_flags(cls, fv: Optional[flags.FlagValues]) -> "Recorder": """Converts flags to a recorder.""" raise NotImplementedError(cls) - def record(self, event: Event, *args, **kwargs): - """Records an event with the given name.""" - raise NotImplementedError(type(self)) + @contextlib.contextmanager + def record_event(self, event: Event, *args, **kwargs): + """A context manager to record the start and end of an event.""" + # pylint: disable=unnecessary-pass + # pylint: disable=unused-argument + try: + yield + finally: + pass def start_monitoring(self, **kwargs): """Starts computing and uploading metrics at some configured interval in the background.""" @@ -134,14 +131,6 @@ def initialize(fv: flags.FlagValues): ) -def record_event(event: Event): - """Records a global event.""" - if global_recorder is None: - logging.log_first_n(logging.INFO, "No recorder configured, ignoring events.", 1) - else: - global_recorder.record(event) - - def start_monitoring(): """Begins monitoring events as per global monitor functionality.""" if global_recorder is None: diff --git a/axlearn/common/measurement_test.py b/axlearn/common/measurement_test.py index c9043f20b..d36605f29 100644 --- a/axlearn/common/measurement_test.py +++ b/axlearn/common/measurement_test.py @@ -3,24 +3,30 @@ """Tests measurement utils.""" # pylint: disable=protected-access +import contextlib from unittest import mock from absl import flags from absl.testing import parameterized from axlearn.common import measurement +from axlearn.experiments.testdata.axlearn_common_measurement_test.dummy_recorder import ( + DummyRecorder as RealDummyRecorder, +) class UtilsTest(parameterized.TestCase): """Tests utils.""" def setUp(self): + super().setUp() self._orig_recorder = measurement.global_recorder - self._orig_recorders = measurement._recorders + self._orig_recorders = measurement._recorders.copy() measurement.global_recorder = None measurement._recorders = {} def tearDown(self): + super().tearDown() measurement.global_recorder = self._orig_recorder measurement._recorders = self._orig_recorders @@ -33,32 +39,25 @@ class DummyRecorder(measurement.Recorder): self.assertEqual(DummyRecorder, measurement._recorders.get("test")) - # Registering twice should fail. 
with self.assertRaisesRegex(ValueError, "already registered"): measurement.register_recorder("test")(DummyRecorder) @parameterized.parameters( - # No-op if no recorder_type provided. - dict( - recorder_type=None, - expected=None, - ), - dict( - recorder_type="test", - expected="Mock", - ), - # Try initializing from another module. + dict(recorder_type=None), + dict(recorder_type="test"), dict( recorder_type=( - f"axlearn.experiments.testdata.{__name__.replace('.', '_')}.dummy_recorder:" + "axlearn.experiments.testdata.axlearn_common_measurement_test.dummy_recorder:" "dummy_recorder" - ), - expected="DummyRecorder", + ) ), ) - def test_initialize(self, recorder_type, expected): - mock_recorder = mock.MagicMock() - measurement.register_recorder("test")(mock_recorder) + def test_initialize(self, recorder_type): + mock_recorder_cls = mock.MagicMock() + mock_recorder_instance = mock_recorder_cls.from_flags.return_value + mock_recorder_instance.record_event.return_value = contextlib.nullcontext() + measurement.register_recorder("test")(mock_recorder_cls) + measurement.register_recorder("dummy_recorder")(RealDummyRecorder) fv = flags.FlagValues() measurement.define_flags(flag_values=fv) @@ -69,24 +68,17 @@ def test_initialize(self, recorder_type, expected): measurement.initialize(fv) if recorder_type is None: - # global_recorder should not be initialized, and record_event should be no-op. self.assertIsNone(measurement.global_recorder) - measurement.record_event(measurement.Event.START_JOB) return recorder_name = recorder_type.split(":", 1)[-1] if recorder_name == "test": - self.assertTrue(mock_recorder.from_flags.called) - - self.assertIn(expected, str(measurement._recorders.get(recorder_name, None))) - self.assertIn(expected, str(measurement.global_recorder)) - - # Ensure that record_event does not fail. - with mock.patch.object(measurement.global_recorder, "record") as mock_record: - measurement.record_event(measurement.Event.START_JOB) - self.assertIn(measurement.Event.START_JOB, mock_record.call_args[0]) + self.assertEqual(mock_recorder_instance, measurement.global_recorder) + mock_recorder_cls.from_flags.assert_called_once() + elif recorder_name == "dummy_recorder": + self.assertIsNotNone(measurement.global_recorder) + self.assertIsInstance(measurement.global_recorder, RealDummyRecorder) - # Ensure that start_monitoring does not fail. with mock.patch.object( measurement.global_recorder, "start_monitoring" ) as mock_start_monitoring: diff --git a/axlearn/common/trainer.py b/axlearn/common/trainer.py index d24dbee69..081380100 100644 --- a/axlearn/common/trainer.py +++ b/axlearn/common/trainer.py @@ -241,118 +241,121 @@ def __init__( self._device_monitor = maybe_instantiate(cfg.device_monitor) self._recorder = maybe_instantiate(cfg.recorder) self._is_initialized: bool = False - self._maybe_record_event(measurement.Event.START_ACCELERATOR_INIT) + # Accelerator initialization. + with self._record_event(measurement.Event.ACCELERATOR_INIT): + if cfg.model.dtype is None: + raise ValueError(f"dtype must be explicitly specified for {self.path()}.model") + if cfg.model.param_init is None: + cfg.model.param_init = DefaultInitializer.default_config() + logging.info( + "model.param_init is not specified. 
Default to DefaultInitializer: %s", + cfg.model.param_init, + ) - if cfg.model.dtype is None: - raise ValueError(f"dtype must be explicitly specified for {self.path()}.model") - if cfg.model.param_init is None: - cfg.model.param_init = DefaultInitializer.default_config() - logging.info( - "model.param_init is not specified. Default to DefaultInitializer: %s", - cfg.model.param_init, + self._per_param_train_dtype = maybe_instantiate( + canonicalize_per_param_dtype(cfg.train_dtype) ) - self._per_param_train_dtype = maybe_instantiate( - canonicalize_per_param_dtype(cfg.train_dtype) - ) - - # Create the device mesh. - if devices is None: - self._step_log( - "Devices: global=%s local=%s %s", - jax.device_count(), - jax.local_device_count(), - [device.platform for device in jax.local_devices()], - ) - else: - local_devices = [d for d in devices.flatten() if d.process_index == jax.process_index()] - self._step_log( - "Devices: global=%s local=%s %s", - len(devices), - len(local_devices), - [device.platform for device in local_devices], - ) - self._step_log("Mesh shape: %s", cfg.mesh_shape) - devices = ( - utils.create_device_mesh(mesh_shape=cfg.mesh_shape) if devices is None else devices - ) - mesh = jax.sharding.Mesh(devices, cfg.mesh_axis_names) - self._step_log("Global mesh: %s", mesh) - self._mesh = mesh - self._context_manager: Callable[[], ContextManager] = ( - maybe_instantiate(cfg.context_manager) or contextlib.nullcontext - ) - xsc_check_policy = None - if cfg.xsc_check_policy: - if jax.default_backend() != "tpu": - # XSC is currently only supported on TPU XLA backend. - logging.warning( - "xsc_check_policy was set for non-TPU XLA backend. Running without XSC." + # Create the device mesh. + if devices is None: + self._step_log( + "Devices: global=%s local=%s %s", + jax.device_count(), + jax.local_device_count(), + [device.platform for device in jax.local_devices()], ) else: - xsc_check_policy = maybe_instantiate(cfg.xsc_check_policy) - self._xsc_check_policy: Optional[Callable[[int], bool]] = xsc_check_policy - self._compiled_train_step: Optional[jax.stages.Compiled] = None - - # Create all children within the mesh context so that utils.input_partition_spec() works - # properly. - with self.mesh(): - if cfg.batch_axis_names is not None: - cfg.input = maybe_set_config( - cfg.input, partition_spec=PartitionSpec(cfg.batch_axis_names) + local_devices = [ + d for d in devices.flatten() if d.process_index == jax.process_index() + ] + self._step_log( + "Devices: global=%s local=%s %s", + len(devices), + len(local_devices), + [device.platform for device in local_devices], ) - self.input: Input = self._add_child( - "input", maybe_set_config(cfg.input, is_training=True) - ) - # Start from the beginning of the input dataset by default. 
- self._input_iter = iter(self.input.dataset()) - cfg.summary_writer.dir = cfg.summary_writer.dir or os.path.join( - cfg.dir, "summaries", "train_train" + self._step_log("Mesh shape: %s", cfg.mesh_shape) + devices = ( + utils.create_device_mesh(mesh_shape=cfg.mesh_shape) if devices is None else devices ) - self._add_child("summary_writer", cfg.summary_writer) - self._add_child("model", cfg.model) - self._add_child("learner", cfg.learner) - cfg.checkpointer.dir = cfg.checkpointer.dir or os.path.join(cfg.dir, "checkpoints") - self._add_child("checkpointer", cfg.checkpointer) - if cfg.init_state_builder is not None: - self._add_child("init_state_builder", cfg.init_state_builder) - - self._model_param_specs = self.model.create_parameter_specs_recursively() - model_param_partition_specs = jax.tree.map( - lambda spec: spec.mesh_axes, self._model_param_specs + mesh = jax.sharding.Mesh(devices, cfg.mesh_axis_names) + self._step_log("Global mesh: %s", mesh) + self._mesh = mesh + self._context_manager: Callable[[], ContextManager] = ( + maybe_instantiate(cfg.context_manager) or contextlib.nullcontext ) - for name, spec in utils.flatten_items(self._model_param_specs): - self._step_log("Model param spec: %s=%s", name, spec) - self._learner_state_partition_specs = self.learner.create_state_partition_specs( - self._model_param_specs - ) - for name, spec in utils.flatten_items(self._learner_state_partition_specs): - self._step_log("Learner state spec: %s=%s", name, spec) - self._trainer_state_specs = TrainerState( - prng_key=ParameterSpec(dtype=jnp.uint32, shape=[4], mesh_axes=PartitionSpec(None)), - model=self._model_param_specs, - learner=self._learner_state_partition_specs, - ) - self._trainer_state_partition_specs: TrainerState = jax.tree.map( - lambda spec: spec.sharding, self._trainer_state_specs - ) - # Create evalers, which depend on model_param_partition_specs. - self._evalers = {} - for evaler_name, evaler_cfg in cfg.evalers.items(): - evaler_cfg.summary_writer.dir = evaler_cfg.summary_writer.dir or os.path.join( - cfg.dir, "summaries", evaler_name - ) + xsc_check_policy = None + if cfg.xsc_check_policy: + if jax.default_backend() != "tpu": + # XSC is currently only supported on TPU XLA backend. + logging.warning( + "xsc_check_policy was set for non-TPU XLA backend. Running without XSC." + ) + else: + xsc_check_policy = maybe_instantiate(cfg.xsc_check_policy) + self._xsc_check_policy: Optional[Callable[[int], bool]] = xsc_check_policy + self._compiled_train_step: Optional[jax.stages.Compiled] = None + + # Create all children within the mesh context so that utils.input_partition_spec() works + # properly. + with self.mesh(): if cfg.batch_axis_names is not None: - maybe_set_config( - evaler_cfg.input, partition_spec=PartitionSpec(cfg.batch_axis_names) + cfg.input = maybe_set_config( + cfg.input, partition_spec=PartitionSpec(cfg.batch_axis_names) ) - self._evalers[evaler_name] = self._add_child( - evaler_name, - evaler_cfg, - model=self.model, - model_param_partition_specs=model_param_partition_specs, + self.input: Input = self._add_child( + "input", maybe_set_config(cfg.input, is_training=True) + ) + # Start from the beginning of the input dataset by default. 
+ self._input_iter = iter(self.input.dataset()) + cfg.summary_writer.dir = cfg.summary_writer.dir or os.path.join( + cfg.dir, "summaries", "train_train" + ) + self._add_child("summary_writer", cfg.summary_writer) + self._add_child("model", cfg.model) + self._add_child("learner", cfg.learner) + cfg.checkpointer.dir = cfg.checkpointer.dir or os.path.join(cfg.dir, "checkpoints") + self._add_child("checkpointer", cfg.checkpointer) + if cfg.init_state_builder is not None: + self._add_child("init_state_builder", cfg.init_state_builder) + + self._model_param_specs = self.model.create_parameter_specs_recursively() + model_param_partition_specs = jax.tree.map( + lambda spec: spec.mesh_axes, self._model_param_specs + ) + for name, spec in utils.flatten_items(self._model_param_specs): + self._step_log("Model param spec: %s=%s", name, spec) + self._learner_state_partition_specs = self.learner.create_state_partition_specs( + self._model_param_specs ) - self._maybe_record_event(measurement.Event.END_ACCELERATOR_INIT) + for name, spec in utils.flatten_items(self._learner_state_partition_specs): + self._step_log("Learner state spec: %s=%s", name, spec) + self._trainer_state_specs = TrainerState( + prng_key=ParameterSpec( + dtype=jnp.uint32, shape=[4], mesh_axes=PartitionSpec(None) + ), + model=self._model_param_specs, + learner=self._learner_state_partition_specs, + ) + self._trainer_state_partition_specs: TrainerState = jax.tree.map( + lambda spec: spec.sharding, self._trainer_state_specs + ) + # Create evalers, which depend on model_param_partition_specs. + self._evalers = {} + for evaler_name, evaler_cfg in cfg.evalers.items(): + evaler_cfg.summary_writer.dir = evaler_cfg.summary_writer.dir or os.path.join( + cfg.dir, "summaries", evaler_name + ) + if cfg.batch_axis_names is not None: + maybe_set_config( + evaler_cfg.input, partition_spec=PartitionSpec(cfg.batch_axis_names) + ) + self._evalers[evaler_name] = self._add_child( + evaler_name, + evaler_cfg, + model=self.model, + model_param_partition_specs=model_param_partition_specs, + ) @property def step(self): @@ -370,6 +373,15 @@ def trainer_state_specs(self): def trainer_state_partition_specs(self): return self._trainer_state_partition_specs + @contextlib.contextmanager + def _record_event(self, event: measurement.Event, *args, **kwargs): + """A helper to record an event if a recorder is configured.""" + if self._recorder: + with self._recorder.record_event(event, *args, **kwargs) as event_manager: + yield event_manager + else: + yield + def _train_step_input_partition_specs(self): # Note that subclasses may override this method to set a partition spec for pjit which is # different from that of the input partition spec. @@ -527,10 +539,6 @@ def _should_force_run_evals( ) return force_run_evals - def _maybe_record_event(self, event: measurement.Event, *args, **kwargs): - if self._recorder is not None: - self._recorder.record(event, *args, **kwargs) - # pylint: disable-next=too-many-statements,too-many-branches def run( self, prng_key: Tensor, *, return_evaler_summaries: Optional[Union[bool, set[str]]] = None @@ -556,6 +564,7 @@ def run( different types of values such as WeightedScalar, Tensor, or string, depending on the specific `metric_calculator` config of the evaler. 
""" + with ( ( self._device_monitor.start_monitoring() @@ -566,6 +575,7 @@ def run( self.mesh(), jax.log_compiles(self.vlog_is_on(1)), self._context_manager(), + self._recorder.maybe_monitor_all_goodput(), ): cfg = self.config # Check if need to force run evals at the last training step. @@ -574,8 +584,9 @@ def run( ) # Prepare training. - if not self._prepare_training(prng_key): - return None + with self._record_event(measurement.Event.TRAINING_PREPARATION): + if not self._prepare_training(prng_key): + return None self._is_initialized = True @@ -588,10 +599,10 @@ def run( input_iterator = self.input.batches(self._input_iter) while True: - self._maybe_record_event(measurement.Event.START_DATA_LOADING) try: - input_batch = next(input_iterator) - self._maybe_record_event(measurement.Event.END_DATA_LOADING) + with self._record_event(measurement.Event.DATA_LOADING): + input_batch = next(input_iterator) + logging.log_first_n( logging.INFO, "host_input_batch=%s", 3, utils.shapes(input_batch) ) @@ -601,21 +612,21 @@ def run( self._step = self._step + 1 self.vlog(3, "Start step %s", self.step) - self._maybe_record_event(measurement.Event.START_STEP, self._step) - output = self._run_step( - utils.host_to_global_array( - input_batch, - partition=self._train_step_input_partition_specs(), - ), - force_run_evals=( - force_run_eval_sets_at_max_step - if self.step >= cfg.max_step - else None - ), - ) + with self._record_event(measurement.Event.STEP, self._step): + output = self._run_step( + utils.host_to_global_array( + input_batch, + partition=self._train_step_input_partition_specs(), + ), + force_run_evals=( + force_run_eval_sets_at_max_step + if self.step >= cfg.max_step + else None + ), + ) self.vlog(3, "Done step %s", self.step) num_steps += 1 - if num_steps % 100 == 0: + if num_steps % 10 == 0: now = time.perf_counter() average_step_time = (now - start_time) / num_steps self._step_log("Average step time: %s seconds", average_step_time) @@ -626,9 +637,6 @@ def run( self._step_log("Reached max_step=%s. Stopping", cfg.max_step) break except StopIteration: - # Add END_DATA_LOADING event here to close the unpaired START_DATA_LOADING - # event. - self._maybe_record_event(measurement.Event.END_DATA_LOADING) break if self.step < cfg.max_step: self._step_log("Reached end of inputs. Stopping") @@ -869,7 +877,6 @@ def _prepare_training(self, prng_key: Tensor) -> bool: A boolean indicating whether the model training should start. If not, return None from the `run` function. """ - self._maybe_record_event(measurement.Event.START_TRAINING_PREPARATION) cfg = self.config # Attempt to restore the latest checkpoint, which may contain a saved `_input_iter`. 
@@ -902,7 +909,6 @@ def _prepare_training(self, prng_key: Tensor) -> bool: return False self._jit_train_step = self._pjit_train_step() - self._maybe_record_event(measurement.Event.END_TRAINING_PREPARATION) return True def restore_checkpoint(self, restore_step: Optional[int] = None) -> Optional[int]: @@ -1043,36 +1049,29 @@ def _get_compiled_train_step_fn( mesh_shape=cfg.mesh_shape, mesh_axis_names=cfg.mesh_axis_names, device_kind=device_kind ) if not with_xsc: - self._maybe_record_event( - measurement.Event.START_CUSTOM_BADPUT_EVENT, - custom_badput_event_type="COMPILATION_NO_XSC", - ) - self._compiled_train_step = self.compile_train_step( - trainer_state=trainer_state, input_batch=input_batch, compiler_options=options - ) - self._maybe_record_event( - measurement.Event.END_CUSTOM_BADPUT_EVENT, + with self._record_event( + measurement.Event.CUSTOM_BADPUT_EVENT, custom_badput_event_type="COMPILATION_NO_XSC", - ) + ): + self._compiled_train_step = self.compile_train_step( + trainer_state=trainer_state, input_batch=input_batch, compiler_options=options + ) return self._compiled_train_step + logging.log_first_n(logging.INFO, "Compiling XSC train step.", 1) - self._maybe_record_event( - measurement.Event.START_CUSTOM_BADPUT_EVENT, - custom_badput_event_type="COMPILATION_WITH_XSC", - ) - compiled_jit_train_step_fn = self.compile_train_step( - trainer_state=trainer_state, - input_batch=input_batch, - compiler_options=options - | infer_xsc_compiler_options( - halt_on_detection=True, repeat_count=1, device_kind=device_kind - ), - ) - self._maybe_record_event( - measurement.Event.END_CUSTOM_BADPUT_EVENT, + with self._record_event( + measurement.Event.CUSTOM_BADPUT_EVENT, custom_badput_event_type="COMPILATION_WITH_XSC", - ) + ): + compiled_jit_train_step_fn = self.compile_train_step( + trainer_state=trainer_state, + input_batch=input_batch, + compiler_options=options + | infer_xsc_compiler_options( + halt_on_detection=True, repeat_count=1, device_kind=device_kind + ), + ) return compiled_jit_train_step_fn def _run_step( @@ -1100,7 +1099,7 @@ def _run_step( # Run the compiled function. self._trainer_state, outputs = compiled_train_step_fn(self.trainer_state, input_batch) - if self.step % 100 == 0 or 0 <= self.step <= 5: + if self.step % 10 == 0 or 0 <= self.step <= 5: self._step_log( "loss=%s aux=%s", outputs["loss"], @@ -1130,26 +1129,23 @@ def _run_eval( force_runs: Optional[set[str]] = None, ) -> dict[str, Any]: """Runs evaluations and returns the corresponding summaries.""" - self._maybe_record_event( - measurement.Event.START_CUSTOM_BADPUT_EVENT, custom_badput_event_type="EVAL" - ) - evaler_summaries = {} - # Note: we will use the same eval key as the training keys of the future step, - # which should be okay. - prng_key = self._trainer_state.prng_key - for evaler_name, evaler in self._evalers.items(): - prng_key, summaries, _ = evaler.eval_step( - self.step, - prng_key=prng_key, - model_params=self.model_params_for_eval(), - train_summaries=train_summaries, - force_run=bool(force_runs is not None and evaler_name in force_runs), - ) - evaler_summaries[evaler_name] = summaries - self._maybe_record_event( - measurement.Event.END_CUSTOM_BADPUT_EVENT, custom_badput_event_type="EVAL" - ) - return evaler_summaries + with self._record_event( + measurement.Event.CUSTOM_BADPUT_EVENT, custom_badput_event_type="EVAL" + ): + evaler_summaries = {} + # Note: we will use the same eval key as the training keys of the future step, + # which should be okay. 
+ prng_key = self._trainer_state.prng_key + for evaler_name, evaler in self._evalers.items(): + prng_key, summaries, _ = evaler.eval_step( + self.step, + prng_key=prng_key, + model_params=self.model_params_for_eval(), + train_summaries=train_summaries, + force_run=bool(force_runs is not None and evaler_name in force_runs), + ) + evaler_summaries[evaler_name] = summaries + return evaler_summaries def _pjit_train_step(self) -> jax.stages.Wrapped: return pjit( diff --git a/axlearn/common/utils_spmd.py b/axlearn/common/utils_spmd.py index 1917e99a3..4870bd1c5 100644 --- a/axlearn/common/utils_spmd.py +++ b/axlearn/common/utils_spmd.py @@ -88,6 +88,16 @@ def setup( coordinator_address=distributed_coordinator, num_processes=num_processes, process_id=process_id, + # The duration of missing heartbeats before shutting down. + heartbeat_timeout="120s", + # JAX distributed initialization timeout. + initialization_timeout="3600s", + # JAX distributed shutdown timeout. + shutdown_timeout="3600s", + # RPC timeout. + rpc_timeout="3600s", + # RPC timeout for heartbeat. + coordinator_rpc_timeout="3600s", ) if jax_backend == "gpu": # jax 0.4.34 introduced a change to cluster auto-detection behavior, supplying diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-flash-fp8.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-flash-fp8.txt new file mode 100644 index 000000000..30aaebc9c --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-flash-fp8.txt @@ -0,0 +1,313 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 524288 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 524288 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.input_dispatcher.global_logical_batch_size: 1024 +evalers['train'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 4096 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.fn: 
'axlearn.experiments.text.common.vocab' +evalers['train'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 524288 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.input_dispatcher.global_logical_batch_size: 1024 +evalers['validation'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 4096 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['validation'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.input_dispatcher.global_logical_batch_size: 1024 +input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][1]: 'seq' 
+input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 4096 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.00015 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 524288 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 524288 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'tpu-v6e-256.*' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 64 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.common.utils.policy' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier' +mesh_rules[0][1].config_modifiers[2].tpu_block_size: 1024 +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 64 +mesh_shape[4]: 1 +mesh_shape[5]: 4 +model.batch_axis_names: None +model.decoder.attention_mask: None +model.decoder.decoding.klass: 
'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 12288 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 1 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' +model.decoder.lm_head.param_partition_spec[0]: None +model.decoder.lm_head.param_partition_spec[1]: 'model' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 0 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256 +model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' 
+model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' 
+model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 10000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention' +model.decoder.transformer.layer.self_attention.attention.kv_cache.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.kv_cache.klass: 'axlearn.common.attention.KVCache' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.num_heads: 96 +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data' 
+model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.attention.tpu_block_size: 512 +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 80 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 32768 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.metrics.klass: 'axlearn.common.causal_lm.CompositeLossMetrics' +model.metrics.metrics['lm'].klass: 'axlearn.common.causal_lm.CrossEntropyLossMetrics' +model.metrics.metrics['aux'].klass: 'axlearn.common.causal_lm.AuxLossMetrics' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git 
a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-flash-fp8_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-flash-fp8_init.txt new file mode 100644 index 000000000..d3f162fe3 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-flash-fp8_init.txt @@ -0,0 +1,10 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 12288], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 112, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 96, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(43008, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) +decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(32768, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-flash-fp8_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-flash-fp8_regularizer.txt new file mode 100644 index 000000000..65733fb7f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-flash-fp8_regularizer.txt @@ -0,0 +1,11 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/lm_head/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-flash.txt new file mode 100644 index 000000000..30aaebc9c --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-flash.txt @@ -0,0 +1,313 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 
'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 524288 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 524288 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.input_dispatcher.global_logical_batch_size: 1024 +evalers['train'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 4096 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['train'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 524288 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.input_dispatcher.global_logical_batch_size: 1024 +evalers['validation'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 4096 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' 
+evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['validation'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.input_dispatcher.global_logical_batch_size: 1024 +input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][1]: 'seq' +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 4096 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.00015 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 524288 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 
+learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 524288 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'tpu-v6e-256.*' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 64 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.common.utils.policy' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier' +mesh_rules[0][1].config_modifiers[2].tpu_block_size: 1024 +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 64 +mesh_shape[4]: 1 +mesh_shape[5]: 4 +model.batch_axis_names: None +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 12288 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 1 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' +model.decoder.lm_head.param_partition_spec[0]: None +model.decoder.lm_head.param_partition_spec[1]: 'model' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 0 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' 
+model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256 +model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' 
+model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 10000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention' +model.decoder.transformer.layer.self_attention.attention.kv_cache.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.kv_cache.klass: 'axlearn.common.attention.KVCache' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq' 
+model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.num_heads: 96 +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.attention.tpu_block_size: 512 +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 
'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 80 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 32768 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.metrics.klass: 'axlearn.common.causal_lm.CompositeLossMetrics' +model.metrics.metrics['lm'].klass: 'axlearn.common.causal_lm.CrossEntropyLossMetrics' +model.metrics.metrics['aux'].klass: 'axlearn.common.causal_lm.AuxLossMetrics' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-flash_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-flash_init.txt new file mode 100644 index 000000000..d3f162fe3 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-flash_init.txt @@ -0,0 +1,10 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 12288], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 112, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 96, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(43008, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) +decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(32768, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-flash_regularizer.txt 
b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-flash_regularizer.txt new file mode 100644 index 000000000..65733fb7f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-flash_regularizer.txt @@ -0,0 +1,11 @@
+====================weight_decay_scale root.optimizer====================
+decoder/emb/token_emb/weight: 1
+decoder/lm_head/weight: 1
+decoder/output_norm/scale: 1
+decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1
+decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1
+decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1
+decoder/transformer/repeat/layer/feed_forward/norm/scale: 1
+decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1
+decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1
+decoder/transformer/repeat/layer/self_attention/norm/scale: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-fp8.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-fp8.txt new file mode 100644 index 000000000..95e954a3b --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-fp8.txt @@ -0,0 +1,280 @@
+batch_axis_names[0]: 'data'
+batch_axis_names[1]: 'expert'
+batch_axis_names[2]: 'fsdp'
+batch_axis_names[3]: 'seq'
+checkpointer.gc_loop_interval_seconds: 60
+checkpointer.keep_every_n_steps: 50000
+checkpointer.keep_last_n: 3
+checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer'
+checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy'
+checkpointer.save_policy.max_step: 524288
+checkpointer.save_policy.min_step: 1
+checkpointer.save_policy.n: 5000
+checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage'
+checkpointer.storage.timeout_secs: 3600
+evalers['train'].eval_dtype: 'jax.numpy.bfloat16'
+evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy'
+evalers['train'].eval_policy.max_step: 524288
+evalers['train'].eval_policy.min_step: 1
+evalers['train'].eval_policy.n: 5000
+evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch'
+evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn'
+evalers['train'].input.batcher.prefetch_buffer_size: -1
+evalers['train'].input.input_dispatcher.global_logical_batch_size: 1024
+evalers['train'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher'
+evalers['train'].input.is_training: False
+evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input'
+evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity'
+evalers['train'].input.source.dataset_name: 'c4/en:3.0.1'
+evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input'
+evalers['train'].input.source.is_training: False
+evalers['train'].input.source.max_sequence_length: 4096
+evalers['train'].input.source.replace_newlines_with: '\n'
+evalers['train'].input.source.split: 'train[:8192]'
+evalers['train'].input.source.train_shuffle_buffer_size: 16384
+evalers['train'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab'
+evalers['train'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model'
+evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler'
+evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator'
+evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 524288 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.input_dispatcher.global_logical_batch_size: 1024 +evalers['validation'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 4096 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['validation'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.input_dispatcher.global_logical_batch_size: 1024 +input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][1]: 'seq' +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 
+input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 4096 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.00015 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 524288 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 524288 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'tpu-v6e-256.*' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 64 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.common.utils.policy' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier' +mesh_rules[0][1].config_modifiers[2].tpu_block_size: 1024 +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 64 +mesh_shape[4]: 1 +mesh_shape[5]: 4 +model.batch_axis_names: None +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 12288 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 
'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 1 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' +model.decoder.lm_head.param_partition_spec[0]: None +model.decoder.lm_head.param_partition_spec[1]: 'model' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 0 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256 +model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' 
+model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 10000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 
'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.attention.GroupedQueryAttention' +model.decoder.transformer.layer.self_attention.attention.kv_cache.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.kv_cache.klass: 'axlearn.common.attention.KVCache' +model.decoder.transformer.layer.self_attention.attention.num_heads: 96 +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 80 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 32768 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.metrics.klass: 'axlearn.common.causal_lm.CompositeLossMetrics' +model.metrics.metrics['lm'].klass: 'axlearn.common.causal_lm.CrossEntropyLossMetrics' +model.metrics.metrics['aux'].klass: 'axlearn.common.causal_lm.AuxLossMetrics' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-fp8_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-fp8_init.txt new file mode 100644 index 
000000000..d3f162fe3 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-fp8_init.txt @@ -0,0 +1,10 @@
+decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 12288], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0)
+decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 112, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=())
+decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 96, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=())
+decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0)
+decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(43008, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/output_norm/scale: constant(1.0)
+decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(32768, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
\ No newline at end of file
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-fp8_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-fp8_regularizer.txt new file mode 100644 index 000000000..65733fb7f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2-fp8_regularizer.txt @@ -0,0 +1,11 @@
+====================weight_decay_scale root.optimizer====================
+decoder/emb/token_emb/weight: 1
+decoder/lm_head/weight: 1
+decoder/output_norm/scale: 1
+decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1
+decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1
+decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1
+decoder/transformer/repeat/layer/feed_forward/norm/scale: 1
+decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1
+decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1
+decoder/transformer/repeat/layer/self_attention/norm/scale: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2.txt new file mode 100644 index 000000000..95e954a3b --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2.txt @@ -0,0 +1,280 @@
+batch_axis_names[0]: 'data'
+batch_axis_names[1]: 'expert'
+batch_axis_names[2]: 'fsdp'
+batch_axis_names[3]: 'seq'
+checkpointer.gc_loop_interval_seconds: 60
+checkpointer.keep_every_n_steps: 50000
+checkpointer.keep_last_n: 3
+checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer'
+checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy'
+checkpointer.save_policy.max_step: 524288
+checkpointer.save_policy.min_step: 1
+checkpointer.save_policy.n: 5000
+checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage'
+checkpointer.storage.timeout_secs: 3600
+evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 524288 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.input_dispatcher.global_logical_batch_size: 1024 +evalers['train'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 4096 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['train'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 524288 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.input_dispatcher.global_logical_batch_size: 1024 +evalers['validation'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 4096 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['validation'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' 
+evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.input_dispatcher.global_logical_batch_size: 1024 +input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][1]: 'seq' +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 4096 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.00015 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 524288 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 524288 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 
'seq' +mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'tpu-v6e-256.*' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 64 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.common.utils.policy' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier' +mesh_rules[0][1].config_modifiers[2].tpu_block_size: 1024 +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 64 +mesh_shape[4]: 1 +mesh_shape[5]: 4 +model.batch_axis_names: None +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 12288 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 1 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' +model.decoder.lm_head.param_partition_spec[0]: None +model.decoder.lm_head.param_partition_spec[1]: 'model' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 0 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256 
+model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 
'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 10000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.attention.GroupedQueryAttention' +model.decoder.transformer.layer.self_attention.attention.kv_cache.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.kv_cache.klass: 'axlearn.common.attention.KVCache' +model.decoder.transformer.layer.self_attention.attention.num_heads: 96 +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' 
+model.decoder.transformer.num_layers: 80 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 32768 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.metrics.klass: 'axlearn.common.causal_lm.CompositeLossMetrics' +model.metrics.metrics['lm'].klass: 'axlearn.common.causal_lm.CrossEntropyLossMetrics' +model.metrics.metrics['aux'].klass: 'axlearn.common.causal_lm.AuxLossMetrics' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2_init.txt new file mode 100644 index 000000000..d3f162fe3 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2_init.txt @@ -0,0 +1,10 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 12288], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 112, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 96, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(43008, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) +decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(32768, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2_regularizer.txt new file mode 100644 index 000000000..65733fb7f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v2_regularizer.txt @@ -0,0 +1,11 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 
1
+decoder/lm_head/weight: 1
+decoder/output_norm/scale: 1
+decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1
+decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1
+decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1
+decoder/transformer/repeat/layer/feed_forward/norm/scale: 1
+decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1
+decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1
+decoder/transformer/repeat/layer/self_attention/norm/scale: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-flash-fp8.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-flash-fp8.txt new file mode 100644 index 000000000..b148649f5 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-flash-fp8.txt @@ -0,0 +1,313 @@
+batch_axis_names[0]: 'data'
+batch_axis_names[1]: 'expert'
+batch_axis_names[2]: 'fsdp'
+batch_axis_names[3]: 'seq'
+checkpointer.gc_loop_interval_seconds: 60
+checkpointer.keep_every_n_steps: 50000
+checkpointer.keep_last_n: 3
+checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer'
+checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy'
+checkpointer.save_policy.max_step: 983040
+checkpointer.save_policy.min_step: 1
+checkpointer.save_policy.n: 5000
+checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage'
+checkpointer.storage.timeout_secs: 3600
+evalers['train'].eval_dtype: 'jax.numpy.bfloat16'
+evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy'
+evalers['train'].eval_policy.max_step: 983040
+evalers['train'].eval_policy.min_step: 1
+evalers['train'].eval_policy.n: 5000
+evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch'
+evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn'
+evalers['train'].input.batcher.prefetch_buffer_size: -1
+evalers['train'].input.input_dispatcher.global_logical_batch_size: 2048
+evalers['train'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher'
+evalers['train'].input.is_training: False
+evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input'
+evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity'
+evalers['train'].input.source.dataset_name: 'c4/en:3.0.1'
+evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input'
+evalers['train'].input.source.is_training: False
+evalers['train'].input.source.max_sequence_length: 8192
+evalers['train'].input.source.replace_newlines_with: '\n'
+evalers['train'].input.source.split: 'train[:8192]'
+evalers['train'].input.source.train_shuffle_buffer_size: 16384
+evalers['train'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab'
+evalers['train'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model'
+evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler'
+evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator'
+evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator'
+evalers['train'].metric_calculator.model_method: 'forward'
+evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter'
+evalers['train'].summary_writer.write_every_n_steps: 1
+evalers['validation'].eval_dtype: 'jax.numpy.bfloat16'
+evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 983040 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.input_dispatcher.global_logical_batch_size: 2048 +evalers['validation'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['validation'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.input_dispatcher.global_logical_batch_size: 2048 +input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][1]: 'seq' +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' 
+input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.00015 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 983040 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 983040 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'tpu-v6e-256.*' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 64 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.common.utils.policy' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier' +mesh_rules[0][1].config_modifiers[2].tpu_block_size: 1024 +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 64 +mesh_shape[4]: 1 +mesh_shape[5]: 4 +model.batch_axis_names: None +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 12288 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' 
+model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 1 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' +model.decoder.lm_head.param_partition_spec[0]: None +model.decoder.lm_head.param_partition_spec[1]: 'model' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 0 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256 +model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' 
+model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention' +model.decoder.transformer.layer.self_attention.attention.kv_cache.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.kv_cache.klass: 'axlearn.common.attention.KVCache' 
+model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.num_heads: 96 +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False 
+model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.attention.tpu_block_size: 512 +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 80 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 131072 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.metrics.klass: 'axlearn.common.causal_lm.CompositeLossMetrics' +model.metrics.metrics['lm'].klass: 'axlearn.common.causal_lm.CrossEntropyLossMetrics' +model.metrics.metrics['aux'].klass: 'axlearn.common.causal_lm.AuxLossMetrics' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-flash-fp8_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-flash-fp8_init.txt new file mode 100644 index 000000000..ce0614cbc --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-flash-fp8_init.txt @@ -0,0 +1,10 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 12288], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 
+decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 112, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 96, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(43008, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) +decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(131072, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-flash-fp8_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-flash-fp8_regularizer.txt new file mode 100644 index 000000000..65733fb7f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-flash-fp8_regularizer.txt @@ -0,0 +1,11 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/lm_head/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-flash.txt new file mode 100644 index 000000000..b148649f5 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-flash.txt @@ -0,0 +1,313 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 983040 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 983040 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' 
+evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.input_dispatcher.global_logical_batch_size: 2048 +evalers['train'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['train'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 983040 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.input_dispatcher.global_logical_batch_size: 2048 +evalers['validation'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['validation'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 
'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.input_dispatcher.global_logical_batch_size: 2048 +input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][1]: 'seq' +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.00015 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 983040 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 983040 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'tpu-v6e-256.*' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1 
+mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 64 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.common.utils.policy' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier' +mesh_rules[0][1].config_modifiers[2].tpu_block_size: 1024 +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 64 +mesh_shape[4]: 1 +mesh_shape[5]: 4 +model.batch_axis_names: None +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 12288 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 1 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' +model.decoder.lm_head.param_partition_spec[0]: None +model.decoder.lm_head.param_partition_spec[1]: 'model' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 0 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256 +model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' 
+model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' 
+model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention' +model.decoder.transformer.layer.self_attention.attention.kv_cache.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.kv_cache.klass: 'axlearn.common.attention.KVCache' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None 
+model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.num_heads: 96 +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.attention.tpu_block_size: 512 +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 80 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 131072 +model.dtype: 'jax.numpy.float32' 
+model.klass: 'axlearn.common.causal_lm.Model' +model.metrics.klass: 'axlearn.common.causal_lm.CompositeLossMetrics' +model.metrics.metrics['lm'].klass: 'axlearn.common.causal_lm.CrossEntropyLossMetrics' +model.metrics.metrics['aux'].klass: 'axlearn.common.causal_lm.AuxLossMetrics' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-flash_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-flash_init.txt new file mode 100644 index 000000000..ce0614cbc --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-flash_init.txt @@ -0,0 +1,10 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 12288], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 112, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 96, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(43008, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) +decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(131072, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-flash_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-flash_regularizer.txt new file mode 100644 index 000000000..65733fb7f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-flash_regularizer.txt @@ -0,0 +1,11 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/lm_head/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 
+decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-fp8.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-fp8.txt new file mode 100644 index 000000000..6f4a63859 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-fp8.txt @@ -0,0 +1,280 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 983040 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 983040 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.input_dispatcher.global_logical_batch_size: 2048 +evalers['train'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['train'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 983040 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' 
+evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.input_dispatcher.global_logical_batch_size: 2048 +evalers['validation'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['validation'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.input_dispatcher.global_logical_batch_size: 2048 +input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][1]: 'seq' +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +input.source.vocab_cfg.sentencepiece_model_name: 
'bpe_128k_c4.model' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.00015 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 983040 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 983040 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'tpu-v6e-256.*' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 64 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.common.utils.policy' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier' +mesh_rules[0][1].config_modifiers[2].tpu_block_size: 1024 +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 64 +mesh_shape[4]: 1 +mesh_shape[5]: 4 +model.batch_axis_names: None +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 12288 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 1 +model.decoder.klass: 
'axlearn.common.decoder.Decoder' +model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' +model.decoder.lm_head.param_partition_spec[0]: None +model.decoder.lm_head.param_partition_spec[1]: 'model' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 0 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256 +model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' 
+model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.attention.GroupedQueryAttention' +model.decoder.transformer.layer.self_attention.attention.kv_cache.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.kv_cache.klass: 'axlearn.common.attention.KVCache' +model.decoder.transformer.layer.self_attention.attention.num_heads: 96 +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' 
+model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 80 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 131072 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.metrics.klass: 'axlearn.common.causal_lm.CompositeLossMetrics' +model.metrics.metrics['lm'].klass: 'axlearn.common.causal_lm.CrossEntropyLossMetrics' +model.metrics.metrics['aux'].klass: 'axlearn.common.causal_lm.AuxLossMetrics' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-fp8_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-fp8_init.txt new file mode 100644 index 000000000..ce0614cbc --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-fp8_init.txt @@ -0,0 +1,10 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 12288], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 112, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 96, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) 
+decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(43008, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) +decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(131072, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-fp8_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-fp8_regularizer.txt new file mode 100644 index 000000000..65733fb7f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-fp8_regularizer.txt @@ -0,0 +1,11 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/lm_head/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-flash-fp8.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-flash-fp8.txt new file mode 100644 index 000000000..1389620a4 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-flash-fp8.txt @@ -0,0 +1,313 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 983040 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 983040 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.input_dispatcher.global_logical_batch_size: 2048 +evalers['train'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' 
+evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['train'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 983040 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.input_dispatcher.global_logical_batch_size: 2048 +evalers['validation'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['validation'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 
+input.input_dispatcher.global_logical_batch_size: 2048 +input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][1]: 'seq' +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.00015 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 983040 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 983040 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'tpu-v6e-256.*' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 64 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 
'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.common.utils.policy' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier' +mesh_rules[0][1].config_modifiers[2].tpu_block_size: 1024 +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 64 +mesh_shape[4]: 1 +mesh_shape[5]: 4 +model.batch_axis_names: None +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 12288 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 128001 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' +model.decoder.lm_head.param_partition_spec[0]: None +model.decoder.lm_head.param_partition_spec[1]: 'model' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 128004 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256 +model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' 
+model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' 
+model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention' +model.decoder.transformer.layer.self_attention.attention.kv_cache.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.kv_cache.klass: 'axlearn.common.attention.KVCache' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None 
+model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.num_heads: 96 +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.attention.tpu_block_size: 512 +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 80 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 128256 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.metrics.klass: 
'axlearn.common.causal_lm.CompositeLossMetrics' +model.metrics.metrics['lm'].klass: 'axlearn.common.causal_lm.CrossEntropyLossMetrics' +model.metrics.metrics['aux'].klass: 'axlearn.common.causal_lm.AuxLossMetrics' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-flash-fp8_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-flash-fp8_init.txt new file mode 100644 index 000000000..e4b66772f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-flash-fp8_init.txt @@ -0,0 +1,10 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 12288], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 112, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 96, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(43008, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) +decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(128256, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-flash-fp8_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-flash-fp8_regularizer.txt new file mode 100644 index 000000000..65733fb7f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-flash-fp8_regularizer.txt @@ -0,0 +1,11 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/lm_head/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 
+decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-flash.txt new file mode 100644 index 000000000..1389620a4 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-flash.txt @@ -0,0 +1,313 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 983040 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 983040 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.input_dispatcher.global_logical_batch_size: 2048 +evalers['train'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['train'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 983040 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 
+evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.input_dispatcher.global_logical_batch_size: 2048 +evalers['validation'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['validation'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.input_dispatcher.global_logical_batch_size: 2048 +input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][1]: 'seq' +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' 
+input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.00015 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 983040 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 983040 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'tpu-v6e-256.*' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 64 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.common.utils.policy' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier' +mesh_rules[0][1].config_modifiers[2].tpu_block_size: 1024 +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 64 +mesh_shape[4]: 1 +mesh_shape[5]: 4 +model.batch_axis_names: None +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 12288 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' 
+model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 128001 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' +model.decoder.lm_head.param_partition_spec[0]: None +model.decoder.lm_head.param_partition_spec[1]: 'model' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 128004 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256 +model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 
1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention' +model.decoder.transformer.layer.self_attention.attention.kv_cache.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.kv_cache.klass: 'axlearn.common.attention.KVCache' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert' 
+model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.num_heads: 96 +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 
'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.attention.tpu_block_size: 512 +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 80 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 128256 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.metrics.klass: 'axlearn.common.causal_lm.CompositeLossMetrics' +model.metrics.metrics['lm'].klass: 'axlearn.common.causal_lm.CrossEntropyLossMetrics' +model.metrics.metrics['aux'].klass: 'axlearn.common.causal_lm.AuxLossMetrics' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-flash_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-flash_init.txt new file mode 100644 index 000000000..e4b66772f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-flash_init.txt @@ -0,0 +1,10 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 12288], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 112, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 
normal(0, 1.0 / fan_in), shape=(12288, 96, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(43008, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) +decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(128256, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-flash_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-flash_regularizer.txt new file mode 100644 index 000000000..65733fb7f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-flash_regularizer.txt @@ -0,0 +1,11 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/lm_head/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-fp8.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-fp8.txt new file mode 100644 index 000000000..1d766fae4 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-fp8.txt @@ -0,0 +1,280 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 983040 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 983040 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.input_dispatcher.global_logical_batch_size: 2048 
+evalers['train'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['train'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 983040 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.input_dispatcher.global_logical_batch_size: 2048 +evalers['validation'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['validation'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +input.batcher.pad_example_fn: 
'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.input_dispatcher.global_logical_batch_size: 2048 +input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][1]: 'seq' +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.00015 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 983040 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 983040 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'tpu-v6e-256.*' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 64 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 
+mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.common.utils.policy' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier' +mesh_rules[0][1].config_modifiers[2].tpu_block_size: 1024 +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 64 +mesh_shape[4]: 1 +mesh_shape[5]: 4 +model.batch_axis_names: None +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 12288 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 128001 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' +model.decoder.lm_head.param_partition_spec[0]: None +model.decoder.lm_head.param_partition_spec[1]: 'model' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 128004 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256 +model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' 
+model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' 
+model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.attention.GroupedQueryAttention' +model.decoder.transformer.layer.self_attention.attention.kv_cache.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.kv_cache.klass: 'axlearn.common.attention.KVCache' +model.decoder.transformer.layer.self_attention.attention.num_heads: 96 +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 80 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 128256 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' 
+model.metrics.klass: 'axlearn.common.causal_lm.CompositeLossMetrics' +model.metrics.metrics['lm'].klass: 'axlearn.common.causal_lm.CrossEntropyLossMetrics' +model.metrics.metrics['aux'].klass: 'axlearn.common.causal_lm.AuxLossMetrics' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-fp8_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-fp8_init.txt new file mode 100644 index 000000000..e4b66772f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-fp8_init.txt @@ -0,0 +1,10 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 12288], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 112, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 96, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(43008, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) +decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(128256, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-fp8_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-fp8_regularizer.txt new file mode 100644 index 000000000..65733fb7f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken-fp8_regularizer.txt @@ -0,0 +1,11 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/lm_head/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 
+decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken.txt new file mode 100644 index 000000000..1d766fae4 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken.txt @@ -0,0 +1,280 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 983040 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 983040 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.input_dispatcher.global_logical_batch_size: 2048 +evalers['train'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['train'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 983040 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' 
+evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.input_dispatcher.global_logical_batch_size: 2048 +evalers['validation'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['validation'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.input_dispatcher.global_logical_batch_size: 2048 +input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][1]: 'seq' +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +input.source.vocab_cfg.klass: 
'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.00015 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 983040 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 983040 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'tpu-v6e-256.*' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 64 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.common.utils.policy' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier' +mesh_rules[0][1].config_modifiers[2].tpu_block_size: 1024 +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 64 +mesh_shape[4]: 1 +mesh_shape[5]: 4 +model.batch_axis_names: None +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 12288 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' 
+model.decoder.eos_token_id: 128001 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' +model.decoder.lm_head.param_partition_spec[0]: None +model.decoder.lm_head.param_partition_spec[1]: 'model' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 128004 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256 +model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None 
+model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.attention.GroupedQueryAttention' +model.decoder.transformer.layer.self_attention.attention.kv_cache.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.kv_cache.klass: 'axlearn.common.attention.KVCache' +model.decoder.transformer.layer.self_attention.attention.num_heads: 96 +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' 
+model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 80 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 128256 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.metrics.klass: 'axlearn.common.causal_lm.CompositeLossMetrics' +model.metrics.metrics['lm'].klass: 'axlearn.common.causal_lm.CrossEntropyLossMetrics' +model.metrics.metrics['aux'].klass: 'axlearn.common.causal_lm.AuxLossMetrics' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken_init.txt new file mode 100644 index 000000000..e4b66772f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken_init.txt @@ -0,0 +1,10 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 12288], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 112, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 
normal(0, 1.0 / fan_in), shape=(12288, 96, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(43008, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) +decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(128256, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken_regularizer.txt new file mode 100644 index 000000000..65733fb7f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3-tiktoken_regularizer.txt @@ -0,0 +1,11 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/lm_head/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3.txt new file mode 100644 index 000000000..6f4a63859 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3.txt @@ -0,0 +1,280 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 983040 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 983040 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.input_dispatcher.global_logical_batch_size: 2048 +evalers['train'].input.input_dispatcher.klass: 
'axlearn.common.input_dispatch.InputDispatcher' +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['train'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 983040 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.input_dispatcher.global_logical_batch_size: 2048 +evalers['validation'].input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +evalers['validation'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.per_feed_batch' +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 
+input.input_dispatcher.global_logical_batch_size: 2048 +input.input_dispatcher.klass: 'axlearn.common.input_dispatch.InputDispatcher' +input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 1)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][0]: 'data' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][1]: 'expert' +input.input_partitioner.path_rank_to_partition[(None, 2)][0][2]: 'fsdp' +input.input_partitioner.path_rank_to_partition[(None, 2)][1]: 'seq' +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' +input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.00015 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 983040 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 983040 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'tpu-v6e-256.*' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 64 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' 
+mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: False +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'axlearn.common.utils.policy' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.experiments.trainer_config_utils.V6eFlashConfigModifier' +mesh_rules[0][1].config_modifiers[2].tpu_block_size: 1024 +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 64 +mesh_shape[4]: 1 +mesh_shape[5]: 4 +model.batch_axis_names: None +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 12288 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 1 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' +model.decoder.lm_head.param_partition_spec[0]: None +model.decoder.lm_head.param_partition_spec[1]: 'model' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 0 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256 +model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' 
+model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' 
+model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.attention.GroupedQueryAttention' +model.decoder.transformer.layer.self_attention.attention.kv_cache.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.kv_cache.klass: 'axlearn.common.attention.KVCache' +model.decoder.transformer.layer.self_attention.attention.num_heads: 96 +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 80 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 131072 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.metrics.klass: 'axlearn.common.causal_lm.CompositeLossMetrics' +model.metrics.metrics['lm'].klass: 
'axlearn.common.causal_lm.CrossEntropyLossMetrics' +model.metrics.metrics['aux'].klass: 'axlearn.common.causal_lm.AuxLossMetrics' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3_init.txt new file mode 100644 index 000000000..ce0614cbc --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3_init.txt @@ -0,0 +1,10 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 12288], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 112, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(12288, 96, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(12288, 43008), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(43008, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) +decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(131072, 12288), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3_regularizer.txt new file mode 100644 index 000000000..65733fb7f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-150B-v3_regularizer.txt @@ -0,0 +1,11 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/lm_head/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 
+decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py index 57c910c70..4a617d1f6 100644 --- a/axlearn/experiments/text/gpt/fuji.py +++ b/axlearn/experiments/text/gpt/fuji.py @@ -15,6 +15,8 @@ import itertools from typing import Any, List, NamedTuple, Optional, Union +import jax +from absl import logging from jax.ad_checkpoint import checkpoint_policies as jax_remat_policies from axlearn.common import causal_lm, config @@ -67,7 +69,7 @@ from axlearn.experiments.text.gpt.common import scaled_hidden_dim from axlearn.experiments.trainer_config_utils import TrainerConfigFn, V6eFlashConfigModifier -MODEL_SIZES = ("test", "1B", "3B", "7B", "8B", "70B") +MODEL_SIZES = ("test", "1B", "3B", "7B", "8B", "70B", "150B") class Version(enum.Enum): @@ -113,6 +115,7 @@ class Version(enum.Enum): "test": 2 * (1024**4), # 2T tokens "7B": 2 * (1024**4), # 2T tokens "70B": 2 * (1024**4), # 2T tokens + "150B": 2 * (1024**4), # 2T tokens }, Version.V3: { "test": 15 * (1024**4), # 15T tokens @@ -120,6 +123,7 @@ class Version(enum.Enum): "3B": 15 * (1024**4), # 15T tokens "7B": 15 * (1024**4), # 15T tokens "70B": 15 * (1024**4), # 15T tokens + "150B": 15 * (1024**4), # 15T tokens }, Version.V3_TIKTOKEN: { "test": 15 * (1024**4), # 15T tokens @@ -127,6 +131,7 @@ class Version(enum.Enum): "3B": 15 * (1024**4), # 15T tokens "8B": 15 * (1024**4), # 15T tokens "70B": 15 * (1024**4), # 15T tokens + "150B": 15 * (1024**4), # 15T tokens }, } @@ -249,7 +254,6 @@ def get_trainer_kwargs( max_step = TOTAL_TOKENS[version][model_size] // tokens_per_batch max_sequence_length = MAX_SEQUENCE_LENGTH[version] train_batch_size = tokens_per_batch // max_sequence_length - # Whether to use grouped query attention. num_kv_heads = None if version in (Version.V3, Version.V3_TIKTOKEN): @@ -809,6 +813,116 @@ def get_trainer_kwargs( ), ), ) + elif model_size == "150B": + ################################################################################## + max_sequence_length = MAX_SEQUENCE_LENGTH[Version.V2] # 4096 + + # model_parallelism * fsdp == num_chips_in_trillium (256) + model_parallelism = 4 + fsdp = 64 + + current_pdbs = 0.5 + train_batch_size = int(current_pdbs * len(jax.devices())) + + # 16k * 4096 = 64M + tokens_per_batch = int(train_batch_size * max_sequence_length) + + # 32M tokens is the max global tokens we can train on. + # We must modify either the pdbs or the model sharding to accommodate 128 slices. + if tokens_per_batch > 32 * (1024**2): + tokens_per_batch = 32 * (1024**2) + # if we want to modify the pdbs: + current_pdbs = 0.25 + + # otherwise we can modify the model sharding. + # model_parallelism = 8 + # fsdp = 32 + + # 32M tokens is the max global tokens we can train on. + assert tokens_per_batch <= 32 * (1024**2) + assert fsdp * model_parallelism == 256 + + # 1 / model_parallelism = 1 / 4 = 0.25 + min_pdbs = 1 / model_parallelism + max_pdbs = 1 + + # More than 1 pdbs causes an OOM. 
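+            # Worked example of the arithmetic above (illustrative only; assumes a single
+            # v6e-256 slice, i.e. len(jax.devices()) == 256):
+            #   train_batch_size = 0.5 * 256  = 128 sequences
+            #   tokens_per_batch = 128 * 4096 = 524288 tokens, well under the 32M cap.
+            # At 128 slices (32768 chips), pdbs 0.5 gives 16384 * 4096 = 64M tokens, which is
+            # why the cap above forces pdbs down to 0.25 (8192 * 4096 = 32M tokens).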
+ assert current_pdbs < max_pdbs + assert current_pdbs >= min_pdbs + + # maximum number of devices we can use this config on = + # train_batch_size // min_pdbs = 4096 / 0.25 = 16384 + max_devices = int(train_batch_size // min_pdbs) + + assert isinstance(train_batch_size, int) + assert isinstance(tokens_per_batch, int) + + logging.info( + ( + "******* DEBUGGING: max_sequence_length: %s, model_parallelism: %s," + " fsdp: %s, current_pdbs: %s, train_batch_size: %s," + " tokens_per_batch: %s, min_pdbs: %s, max_pdbs: %s, max_devices: %s" + ), + max_sequence_length, + model_parallelism, + fsdp, + current_pdbs, + train_batch_size, + tokens_per_batch, + min_pdbs, + max_pdbs, + max_devices, + ) + ################################################################################## + + trainer_kwargs = dict( + model_kwargs=dict( + num_layers=80, + hidden_dim=128 * 96, + num_heads=96, + # No GQA support in V1 models, so num_kv_heads is the same as num_heads. + num_kv_heads=None if version == Version.V1 else 8, + ffn_dim=scaled_hidden_dim(scale=3.5, round_up_to_multiples_of=256), + rope_theta=rope_theta, + shared_lm_head=False, + flash_attention=flash_attention, + ), + learner_kwargs=dict(peak_lr=1.5e-4, weight_decay=0.1), + max_sequence_length=max_sequence_length, + train_batch_size=train_batch_size, + max_step=100_000, # max_step, + save_every_n_steps=20, + mesh_shape=mesh_shape_from_axes(data=-1, fsdp=fsdp, model=model_parallelism), + mesh_rules=( + ( + # Target per-device token count = 2k. + # PDBS = 0.5 at 4k context. + # Each slice can train a batch size of 128. + "tpu-v6e-256.*", + ChainConfigModifier.default_config().set( + config_modifiers=[ + MeshShapeModifier.default_config().set( + mesh_shape=mesh_shape_from_axes(data=-1, fsdp=64, model=4) + ), + RematSpecModifier.default_config().set( + remat_policies={ + "model.decoder.transformer.layer": RematSpec( + prevent_cse=False, + policy=save_and_offload_only_these_names_regex( + names_which_can_be_offloaded=".*input", + names_which_can_be_saved=None, + offload_src="device", + offload_dst="pinned_host", + ), + ), + } + ), + V6eFlashConfigModifier.default_config(), + ], + ), + ), + ), + ) else: raise NotImplementedError(f"Unknown model size {model_size}.") model_kwargs = trainer_kwargs.pop("model_kwargs") @@ -920,6 +1034,12 @@ def trainer_configs( if model_size not in TOTAL_TOKENS[version]: # This combination does not exist. continue vocab_size = VOCAB_SIZE[version] + logging.info( + "******* DEBUGGING: version: %s, model_size: %s, flash_attention: %s", + version, + model_size, + flash_attention, + ) config_name = make_config_name( arch=arch, model_size=model_size, diff --git a/disrupt_nodes.sh b/disrupt_nodes.sh new file mode 100755 index 000000000..e647aba96 --- /dev/null +++ b/disrupt_nodes.sh @@ -0,0 +1,85 @@ +#!/bin/bash +# --- !!! WARNING: This script performs destructive operations continuously. Use with extreme caution. !!! --- +# USAGE: ./disrupt_nodes.sh <jobset_name> +# +# Continuously finds a pod whose name starts with <jobset_name> and disrupts it by killing its main process. +# WARNING: This script runs in a destructive loop. Use with extreme caution. +# +# Example: ./disrupt_nodes.sh my-job-worker- +# --- + +# Check if a pod prefix was passed as an argument +if [ -z "$1" ]; then + echo "Usage: $0 <jobset_name>" + echo "Please provide the jobset name as the pod prefix to identify the target pods."
+ exit 1 +fi + +# The prefix for the pods you want to target (passed as the first argument) +POD_PREFIX="$1" + +# Set your GCP Project ID and Zone here +# GCP_PROJECT_ID="cloud-tpu-best-effort-colo" +# GCP_ZONE="us-east5-a" + +GCP_PROJECT_ID="tpu-prod-env-one-vm" +GCP_ZONE="southamerica-west1-a" + +# Log file path +LOG_FILE="./disrupt_nodes.log" + +# Sleep duration in seconds (e.g., 3600 for 1 hour) +SLEEP_SECONDS=7200 + +# --- Script Starts --- + +# Function for logging with timestamps +log() { + echo "$(date '+%Y-%m-%d %H:%M:%S') - $1" | tee -a "${LOG_FILE}" +} + +# Ensure gcloud/kubectl are in PATH +# export PATH="/snap/bin:/usr/local/bin:/usr/bin:/bin:$PATH" + +log "--- Script initialized. Target pod prefix: '${POD_PREFIX}'. Loop starts now. ---" + +# Main loop +while true; do + log "--- New iteration started ---" + + # 1. Get the name of the first pod that matches the specified prefix. + # Using awk is slightly more robust than `grep | head | awk`. + # It finds the first line where the first column starts with the prefix, prints the name, and exits. + POD_NAME=$(kubectl get pods --no-headers=true | awk -v prefix="${POD_PREFIX}" '$1 ~ "^"prefix {print $1; exit}') + + if [ -z "${POD_NAME}" ]; then + log "No running pod found with prefix '${POD_PREFIX}'. Skipping disruption for this cycle." + else + log "Identified target pod: ${POD_NAME}" + + # 2. Get the node name (VM instance name) where the pod is running. + NODE_NAME=$(kubectl get pod "${POD_NAME}" -o=jsonpath='{.spec.nodeName}' 2>/dev/null) + + if [ -z "${NODE_NAME}" ]; then + log "Could not determine node for pod ${POD_NAME}. It might be terminating. Skipping." + else + log "Pod ${POD_NAME} is on node: ${NODE_NAME}" + + # 3. Disrupt the workload by sending SIGILL to the main process (PID 1) of the pod's container. + log "Attempting to disrupt pod ${POD_NAME} on node ${NODE_NAME} (zone ${GCP_ZONE})" + + # Killing PID 1 crashes the container's main process, simulating a disruption of this worker. + kubectl exec -it ${POD_NAME} -c ${POD_PREFIX} -- sh -c "kill -s SIGILL 1" 2>&1 | tee -a "${LOG_FILE}" + + # Check the exit code of the kubectl command + if [ ${PIPESTATUS[0]} -eq 0 ]; then + log "Successfully disrupted pod ${POD_NAME} on node: ${NODE_NAME}" + else + log "ERROR: Failed to disrupt pod ${POD_NAME}. See kubectl output above." + fi + fi + fi + + log "--- Iteration finished. Sleeping for ${SLEEP_SECONDS} seconds... ---" + sleep "${SLEEP_SECONDS}" +done diff --git a/docs/05-Goodput-Monitoring.md b/docs/05-Goodput-Monitoring.md index ca1452c19..cb17f6989 100644 --- a/docs/05-Goodput-Monitoring.md +++ b/docs/05-Goodput-Monitoring.md @@ -1,10 +1,14 @@ # ML Goodput Monitoring -AXLearn supports automatic measurement and upload of workload metrics such as -Goodput, Badput Breakdown and Step Time Deviation using the ML Goodput -Measurement library. +AXLearn supports automatic measurement and upload of a wide range of workload +metrics using the **ML Goodput Measurement** library. This includes: +* **Goodput** and **Badput Breakdown** +* **Step Metrics** (Ideal Step Time, Step Time Deviation, Last Productive Step etc.) +* **Workload Hang Metrics** (Disruption Count, Step Info) +* **Rolling Window Goodput & Badput Breakdown** The [ML Goodput Measurement](https://github.com/AI-Hypercomputer/ml-goodput-measurement) library currently supports monitoring workloads running on Google Cloud Platform. For more information on details of the library, visit the Github page or the [ml-goodput-measurement](https://pypi.org/project/ml-goodput-measurement/) PyPI package documentation.
+ ### What is Goodput Goodput is the metric that measures the efficiency of model training jobs, i.e. productive time spent on training progress proportional to the total time spent @@ -15,12 +19,26 @@ improve to get the most value from their accelerators. Badput is the metric that measures time that a workload spent on anything that is not productive training proportional to the total time spent by the workload. For example, the time spent in accelerator initialization, training preparation, -program startup, data loading, portions of checkpointing, disruptions and -wasted progress since the last checkpoint etc. all contribute to Badput. +program startup, data loading, portions of checkpointing, recovering from +disruptions, wasted progress since the last checkpoint etc. all contribute to Badput. + +The ML Goodput Measurement library exposes Badput Breakdown. Further details of +each bucket can be found [here](https://github.com/AI-Hypercomputer/ml-goodput-measurement?tab=readme-ov-file#badput-breakdown-details) + +## What is Rolling Window Goodput & Badput +The ML Goodput Measurement library allows users to monitor goodput and badput +breakdown metrics within specific, moving time windows. You can specify a list +of rolling window interval sizes in seconds, and the library will asynchronously +query and upload metrics calculated only within the context of those windows. +This is useful for understanding workload performance over recent, specific +durations (e.g., the last 24 hours). -The ML Goodput Measurement library exposes Badput Breakdown. Further details of each bucket can be found [here](https://github.com/AI-Hypercomputer/ml-goodput-measurement?tab=readme-ov-file#badput-breakdown-details) +If the workload's actual runtime timeline is shorter than a requested window size, +the entire runtime timeline of the workload is used for the metrics computation. -### What is Step Time Deviation +> **Note**: Both the standard (cumulative) and rolling window query APIs can be enabled simultaneously to get a complete picture of your workload's performance. + +### What are Ideal Step Time and Step Time Deviation Step Time Deviation is the metric that measures deviation of step time (in seconds) from ideal step time. It is the difference between the actual time @@ -33,8 +51,8 @@ The formula for step deviation is: Ideal step time is equal to the user-configured `ideal_step_time` if it is provided. If the user has not specified an ideal step time, then the ideal step -time is calculated as the average of the "normal" step times recorded for the -workload, where a "normal" step is defined as having a duration less than or +time is calculated as a weighted average of the "normal" step times recorded for +the workload, where a "normal" step is defined as having a duration less than or equal to `median + median absolute deviation * 3` of the sample space of step times. This computation requires at least 10 recorded steps. @@ -77,7 +95,7 @@ project, then do the following: Please use a unique workload name, unless you intend to monitor cumulative Goodput/Badput metrics of a previous workload along with your current workload. -### How to Monitor Goodput and Badput +### How to Monitor Cumulative Goodput Metrics To enable Goodput recording and monitoring on AXLearn, follow the example below. @@ -94,24 +112,22 @@ To enable Goodput recording and monitoring on AXLearn, follow the example below. 
--recorder_spec=upload_interval=30 \ ``` -### How to Monitor Step Time Deviation +### How to Monitor Rolling Window Goodput Metrics -AXLearn enables step time deviation monitoring by default. You can configure -the upload frequency by setting -`--recorder_spec=step_deviation_interval_seconds=30`. To disable step deviation -set `--recorder_spec=step_deviation_interval_seconds=-1`. +To enable rolling window metrics, set `enable_rolling_window_goodput_monitoring` to `True` +and provide a list of interval sizes for `rolling_window_size` in seconds: ```bash - axlearn gcp launch run --instance_type=tpu-v5litepod-16 \ +axlearn gcp launch run --instance_type=tpu-v5litepod-16 \ --bundler_type=artifactregistry --bundler_spec=image=tpu \ --bundler_spec=dockerfile=Dockerfile \ - --name= \ - -- python3 -m ...training-config... \ + -- python3 -m my_training_job \ --recorder_type=axlearn.cloud.gcp.measurement:goodput \ --recorder_spec=name= \ --recorder_spec=upload_dir=my-output-directory/summaries \ --recorder_spec=upload_interval=30 \ - --recorder_spec=step_deviation_interval_seconds=30 \ + --recorder_spec=enable_rolling_window_goodput_monitoring=True \ + --recorder_spec=rolling_window_size=86400,259200,432000 ``` ### Visualize on Tensorboard @@ -121,12 +137,16 @@ set `--recorder_spec=step_deviation_interval_seconds=-1`. ### Enabling Google Cloud Monitoring -AXLearn has an additional option of pushing goodput, badput and step time -deviation metrics to Google Cloud Monitoring. By default if goodput monitoring -is enabled, the data gets published to Google Cloud Monitoring. Set the variables -`enable_gcp_goodput_metrics` and `enable_gcp_step_deviation_metrics` to `False` in -`goodput_monitoring.GCPOptions` in `cloud/gcp/measurement.py` to disable goodput and step_deviation -uploads to GCM respectively. +By default, when Goodput monitoring is enabled via the recorder, AXLearn automatically pushes metrics to Google Cloud Monitoring. + +- **Cumulative Metrics** are enabled by default when you specify the `recorder_type`. + To disable this, you would need to set `enable_gcp_goodput_metrics` to `False` in + `goodput_monitoring.GCPOptions` within the `cloud/gcp/measurement.py` file. +- **Rolling Window Metrics** can be explicitly enabled by setting + `enable_rolling_window_goodput_monitoring` to `True` and providing window sizes + via `rolling_window_size`. + +You can enable either cumulative monitoring, rolling window monitoring, or both simultaneously. ```bash axlearn gcp launch run --instance_type=tpu-v5litepod-16 \ @@ -138,7 +158,8 @@ uploads to GCM respectively. --recorder_spec=name= \ --recorder_spec=upload_dir=my-output-directory/summaries \ --recorder_spec=upload_interval=30 \ - --recorder_spec=step_deviation_interval_seconds=30 \ + --recorder_spec=enable_rolling_window_goodput_monitoring=True \ + --recorder_spec=rolling_window_size=86400,604800 ``` #### Visualization in Google Cloud Monitoring @@ -159,3 +180,38 @@ To visualize the collected metrics within Google Cloud Monitoring: c. [**Performance:**](https://cloud.google.com/monitoring/api/metrics_gcp#:~:text=workload/performance) Represents the workload's performance metric, specifically step deviation in this context, measured by `compute.googleapis.com/workload/performance`. + +#### Google Cloud Monitoring Dashboard: Goodput Monitor + +Following are instructions for deploying a custom dashboard `goodput_dashboard.json` +to your Google Cloud project's Monitoring console. 
This dashboard +offers a comprehensive view of "Goodput" metrics, helping you monitor +your workloads and set up custom alerts for "events" such as performance degradation. + + +#### Deployment Steps + +Follow these steps to create a new custom dashboard using the provided JSON +configuration: + +1. **Navigate to the Monitoring Console**: In your Google Cloud project, + go to the **Monitoring** section. From the left-hand navigation menu, + select **Dashboards**. + +2. **Create Custom Dashboard**: Click the **Create Custom Dashboard** button. + +3. **Use JSON Editor**: In the new dashboard interface, select the + **JSON editor** option. + +4. **Copy and Save Configuration**: Open the [goodput_dashboard.json](https://github.com/AI-Hypercomputer/ml-goodput-measurement/blob/main/ml_goodput_measurement/dashboards/goodput_dashboard.json) file. + Copy its entire content and paste it into the JSON editor. Once pasted, + click **Save**. + + +Your "Goodput Monitor" dashboard should now be visible and operational within +your custom dashboards list. + +> **_NOTE:_** This dashboard is intended to be a starting point for your +> monitoring needs. We recommend customizing it to meet your specific requirements. +> Please refer to the [Monitoring Dashboard documentation](https://cloud.google.com/monitoring/dashboards) +> for further guidance and customization options. diff --git a/patches/shard_map.py.patch b/patches/shard_map.py.patch new file mode 100644 index 000000000..e6e10104f --- /dev/null +++ b/patches/shard_map.py.patch @@ -0,0 +1,27 @@ +--- shard_map_orig.py 2025-06-18 01:27:00.782665547 +0000 ++++ shard_map.py 2025-06-18 01:26:06.798346281 +0000 +@@ -1793,10 +1793,10 @@ + ) -> tuple[core.JaxprEqn, core.JaxprEqn, Sequence[bool], Sequence[bool], + list[core.Var]]: + jaxpr, mesh = eqn.params['jaxpr'], eqn.params['mesh'] +- auto = eqn.params['auto'] +- with _extend_axis_env(mesh, auto): ++ manual_axes = frozenset() ++ with (_extend_axis_env(mesh, manual_axes), use_abstract_mesh(_as_manual_mesh(mesh, manual_axes))): + jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = \ +- pe.partial_eval_jaxpr_custom(jaxpr, unks_in, inst_in, False, False, saveable) ++ pe.partial_eval_jaxpr_custom(jaxpr, unks_in, inst_in, False, False, saveable) + num_out_primals = len(jaxpr_known.outvars) - num_res + in_fwd = pe._jaxpr_forwarding(jaxpr_known)[num_out_primals:] + out_vars, res_vars = split_list(jaxpr_known.outvars, [num_out_primals]) +@@ -1804,8 +1804,8 @@ + out_fwd = [idx_map.get(id(v)) for v in res_vars] + which = [f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)] + mesh = eqn.params['mesh'] +- with (_extend_axis_env(mesh, auto), +- use_abstract_mesh(_as_manual_mesh(mesh, auto))): ++ with (_extend_axis_env(mesh, manual_axes), ++ use_abstract_mesh(_as_manual_mesh(mesh, frozenset()))): + jaxpr_known = pe.prune_jaxpr_outputs(jaxpr_known, [True] * num_out_primals + which) + jaxpr_known, jaxpr_staged = _add_reshapes(which, jaxpr_known, jaxpr_staged) + jaxpr_known = core.remove_named_axis_effects(jaxpr_known, mesh.axis_names) diff --git a/pyproject.toml b/pyproject.toml index 57e415e7e..4c454fb97 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,7 +98,7 @@ gcp = [ "google-cloud-compute==1.19.2", # Needed for region discovery for CloudBuild API access. "google-cloud-core==2.3.3", "google-cloud-build==3.24.1", - "ml-goodput-measurement==0.0.10", + "ml-goodput-measurement==0.0.14", "pika==1.3.2", # used by event queue "pyOpenSSL>=22.1.0", # compat with cryptography version.
"tpu-info==0.2.0", # For TPU monitoring from libtpu. https://github.com/AI-Hypercomputer/cloud-accelerator-diagnostics/tree/main/tpu_info