Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,15 @@ ARG EXTRAS=
ENV UV_FIND_LINKS=https://storage.googleapis.com/jax-releases/libtpu_releases.html
# Ensure we install the TPU version, even if building locally.
# Jax will fallback to CPU when run on a machine without TPU.
RUN uv pip install --prerelease=allow .[core,tpu] && uv cache clean
COPY libtpu.so /root/libtpu.so
RUN uv pip install --prerelease=allow .[core,gcp,tpu] && uv cache clean
RUN uv pip install libtpu==0.0.14

# Add this line to print the installed version of libtpu.
RUN pip show libtpu | grep Version
RUN pip show jax | grep Version
RUN pip show jaxlib | grep Version

RUN if [ -n "$EXTRAS" ]; then uv pip install .[$EXTRAS] && uv cache clean; fi
COPY . .

Expand Down
2 changes: 1 addition & 1 deletion axlearn/cloud/gcp/jobs/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,6 @@ def _wrapped_usage(


if __name__ == "__main__":
configure_logging(logging.INFO)
configure_logging(logging.DEBUG)
_private_flags()
app.run(main)
28 changes: 13 additions & 15 deletions axlearn/cloud/gcp/jobset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import io
import logging
import math
import os
from dataclasses import dataclass
from typing import Any, Optional, Sequence
Expand All @@ -27,10 +26,7 @@
)
from axlearn.cloud.gcp.config import gcp_settings
from axlearn.cloud.gcp.node_pool import PRE_PROVISIONER_LABEL
from axlearn.cloud.gcp.system_characteristics import (
GCE_MACHINE_TYPE_TO_MEMORY_CHARACTERISTICS,
USER_FACING_NAME_TO_SYSTEM_CHARACTERISTICS,
)
from axlearn.cloud.gcp.system_characteristics import USER_FACING_NAME_TO_SYSTEM_CHARACTERISTICS
from axlearn.cloud.gcp.tpu import get_default_env, infer_tpu_workers
from axlearn.cloud.gcp.utils import validate_jobset_name
from axlearn.common.compiler_options import infer_tpu_type
Expand Down Expand Up @@ -451,15 +447,17 @@ def _build_container(self) -> Nested[Any]:
if cfg.enable_tpu_ici_resiliency is not None:
env_vars["ENABLE_ICI_RESILIENCY"] = str(cfg.enable_tpu_ici_resiliency).lower()

env_vars["TPU_LIBRARY_PATH"] = "/root/libtpu.so"

resources = {"limits": {"google.com/tpu": system.chips_per_vm}}
# Set request memory by host machine type.
machine_memory_gi = GCE_MACHINE_TYPE_TO_MEMORY_CHARACTERISTICS.get(
system.gce_machine_type, None
)
if machine_memory_gi is not None:
request_memory_gi = machine_memory_gi * _MEMORY_REQUEST_PERCENTAGE
resources["limits"]["memory"] = f"{machine_memory_gi}Gi"
resources["requests"] = {"memory": f"{math.floor(request_memory_gi)}Gi"}
# # Set request memory by host machine type.
# machine_memory_gi = GCE_MACHINE_TYPE_TO_MEMORY_CHARACTERISTICS.get(
# system.gce_machine_type, None
# )
# if machine_memory_gi is not None:
# request_memory_gi = machine_memory_gi * _MEMORY_REQUEST_PERCENTAGE
# resources["limits"]["memory"] = f"{machine_memory_gi}Gi"
# resources["requests"] = {"memory": f"{math.floor(request_memory_gi)}Gi"}

k8s_env_vars = [dict(name=k, value=str(v)) for k, v in env_vars.items()]
k8s_env_vars.append(
Expand Down Expand Up @@ -509,8 +507,8 @@ def _build_uploader_container(
interval_s = 60
sync_command = f"while true; do gsutil -m rsync -r {src} {dst}; sleep {interval_s}; done"
resources = {
"requests": {"cpu": "100m", "memory": "128Mi"},
"limits": {"cpu": "500m", "memory": "256Mi"},
# "requests": {"cpu": "100m", "memory": "128Mi"},
# "limits": {"cpu": "500m", "memory": "256Mi"},
}
return dict(
name="output-uploader",
Expand Down
202 changes: 130 additions & 72 deletions axlearn/cloud/gcp/measurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

"""Measurement utils for GCP.

For detailed documentation and advanced usage, please refer to:
axlearn/docs/05-Goodput-Monitoring.md

Example:

# Enable Goodput when launching an AXLearn training job
Expand All @@ -13,10 +16,14 @@
--recorder_spec=name=my-run-with-goodput \
--recorder_spec=upload_dir=my-output-directory/summaries \
--recorder_spec=upload_interval=30 \
--recorder_spec=step_deviation_interval_seconds=30
--recorder_spec=rolling_window_size=86400,604800

"""

import contextlib
import os
from typing import Optional, Sequence

import jax
from absl import flags, logging
from ml_goodput_measurement import goodput
Expand All @@ -38,13 +45,19 @@ class Config(measurement.Recorder.Config):
Attributes:
upload_dir: Directory to store metrics for the monitor.
upload_interval: Time interval (seconds) for monitoring uploads.
step_deviation_interval_seconds: Time interval (seconds) for step deviation metrics
uploads. -1 to disable step deviation uploads.
See "How to Monitor Cumulative Goodput Metrics" in
docs/05-Goodput-Monitoring.md for more details.
rolling_window_size: A sequence of integers defining the rolling window sizes in
seconds.
See "How to Monitor Rolling Window Goodput Metrics" in
docs/05-Goodput-Monitoring.md for more details.
jax_backend: Jax backend type to infer Pathways environment.
"""

upload_dir: Required[str] = REQUIRED
upload_interval: Required[int] = REQUIRED
step_deviation_interval_seconds: int = 30 # Default to 30 seconds
rolling_window_size: Sequence[int] = []
jax_backend: Optional[str] = None

@classmethod
def from_flags(cls, fv: flags.FlagValues) -> "GoodputRecorder":
Expand All @@ -53,68 +66,78 @@ def from_flags(cls, fv: flags.FlagValues) -> "GoodputRecorder":
`fv.recorder_spec` will be interpreted as a list of `key=value` pairs; config names
corresponding to keys will be set to the corresponding values. A GoodputRecorder can
additionally take in following Tensorboard configs in the recorder_spec:
- upload_dir: The directory to write Tensorboard data to.
- upload_interval: The time interval in seconds at which to query and upload data
to Tensorboard.
- step_deviation_interval_seconds: Time interval (seconds) for step deviation metrics
uploads. Set to less than or equal to 0 to disable step deviation uploads.
- upload_dir: The directory to write Tensorboard data to.
- upload_interval: The time interval in seconds at which to query and upload data
to Tensorboard.
- rolling_window_size: Comma-separated list of integers representing rolling window
sizes in seconds.
- jax_backend: The type of jax backend.
"""
cfg: measurement.Recorder.Config = cls.default_config()
cfg = maybe_set_config(cfg, **parse_kv_flags(fv.recorder_spec, delimiter="="))
return cfg.instantiate()
parsed_flags = parse_kv_flags(fv.recorder_spec, delimiter="=")
if "upload_interval" in parsed_flags:
parsed_flags["upload_interval"] = int(parsed_flags["upload_interval"])
if "rolling_window_size" in parsed_flags and isinstance(
parsed_flags["rolling_window_size"], str
):
parsed_flags["rolling_window_size"] = [
int(x) for x in parsed_flags["rolling_window_size"].split(",")
]
return maybe_set_config(cfg, **parsed_flags).instantiate()

def __init__(self, cfg):
super().__init__(cfg)
cfg: GoodputRecorder.Config = self.config
self._recorder = None
self._monitor = None

def record(self, event: measurement.Event, *args, **kwargs):
# Lazily instantiate the recorder. This avoids invoking jax before setup is complete.
self._recorder: Optional[goodput.GoodputRecorder] = None
self._monitor: Optional[goodput_monitoring.GoodputMonitor] = None
self._rolling_window_monitor: Optional[goodput_monitoring.GoodputMonitor] = None
self._job_name = cfg.name
self._logger_name = f"goodput_logger_{cfg.name}"

@contextlib.contextmanager
def record_event(self, event: measurement.Event, *args, **kwargs):
"""Records a goodput event using a context manager."""
# Lazily instantiate the recorder if it hasn't been already.
if self._recorder is None:
cfg: GoodputRecorder.Config = self.config
if jax.process_index() == 0:
logging.info("Lazily instantiating goodput recorder.")
self._recorder = goodput.GoodputRecorder(
job_name=cfg.name,
logger_name=f"goodput_logger_{cfg.name}",
job_name=self._job_name,
logger_name=self._logger_name,
logging_enabled=(jax.process_index() == 0),
)

if event == measurement.Event.START_JOB:
self._recorder.record_job_start_time(*args, **kwargs)
elif event == measurement.Event.END_JOB:
self._recorder.record_job_end_time(*args, **kwargs)
elif event == measurement.Event.START_STEP:
self._recorder.record_step_start_time(*args, **kwargs)
elif event == measurement.Event.START_ACCELERATOR_INIT:
self._recorder.record_tpu_init_start_time(*args, **kwargs)
elif event == measurement.Event.END_ACCELERATOR_INIT:
self._recorder.record_tpu_init_end_time(*args, **kwargs)
elif event == measurement.Event.START_TRAINING_PREPARATION:
self._recorder.record_training_preparation_start_time(*args, **kwargs)
elif event == measurement.Event.END_TRAINING_PREPARATION:
self._recorder.record_training_preparation_end_time(*args, **kwargs)
elif event == measurement.Event.START_DATA_LOADING:
self._recorder.record_data_loading_start_time(*args, **kwargs)
elif event == measurement.Event.END_DATA_LOADING:
self._recorder.record_data_loading_end_time(*args, **kwargs)
elif event == measurement.Event.START_CUSTOM_BADPUT_EVENT:
self._recorder.record_custom_badput_event_start_time(*args, **kwargs)
elif event == measurement.Event.END_CUSTOM_BADPUT_EVENT:
self._recorder.record_custom_badput_event_end_time(*args, **kwargs)
else:
logging.log_first_n(
logging.WARNING,
"Ignoring unknown event %s",
1,
event,
start_method_name = f"record_{event.value}_start_time"
end_method_name = f"record_{event.value}_end_time"

record_event_start = getattr(self._recorder, start_method_name, None)
record_event_end = getattr(self._recorder, end_method_name, None)

try:
if record_event_start:
record_event_start(*args, **kwargs)
except RuntimeError as e:
logging.warning(
"Failed to record start of event %s. Error: %s", event.value, e, exc_info=True
)

def start_monitoring(self, *args, **kwargs):
"""Starts Monitoring of Goodput.
try:
yield
finally:
try:
if record_event_end:
record_event_end(*args, **kwargs)
except RuntimeError as e:
logging.warning(
"Failed to record end of event %s. Error: %s", event.value, e, exc_info=True
)

@contextlib.contextmanager
def _maybe_monitor_goodput(self, *args, **kwargs):
"""Monitor cumulative goodput if enabled.

Instantiate ml-goodput-measurement's GoodputMonitor to asynchronously calculate
Goodput and Badput at the upload_interval and upload to the specified TensorBoard
directory.
Goodput, Badput, Step & Disruption Information at the upload_interval to the
specified TensorBoard directory and Google Cloud Monitoring.
Note: This function requires initialization of distributed JAX before it is called.
If there are internal GCP errors from querying and uploading data, these will be
logged without affecting the workload. GoodputMonitor logs will provide further
Expand All @@ -123,33 +146,68 @@ def start_monitoring(self, *args, **kwargs):
Default behavior is to push metrics to Google Cloud Monitoring.
This behavior can be overridden by configuring `goodput_monitoring.GCPOptions`
"""
cfg: GoodputRecorder.Config = self.config
include_step_deviation = True
if jax.process_index() == 0:
if jax.process_index() != 0:
yield
return
try:
if self._monitor is None:
if int(cfg.step_deviation_interval_seconds) <= 0:
include_step_deviation = False

gcp_options = goodput_monitoring.GCPOptions(
enable_gcp_goodput_metrics=True,
enable_gcp_step_deviation_metrics=include_step_deviation,
)
self._monitor = goodput_monitoring.GoodputMonitor(
job_name=cfg.name,
logger_name=f"goodput_logger_{cfg.name}",
tensorboard_dir=cfg.upload_dir,
upload_interval=int(cfg.upload_interval),
job_name=self._job_name,
logger_name=self._logger_name,
tensorboard_dir=self.config.upload_dir,
upload_interval=self.config.upload_interval,
monitoring_enabled=True,
pathway_enabled=self.config.jax_backend == "proxy",
include_badput_breakdown=True,
include_step_deviation=include_step_deviation,
step_deviation_interval_seconds=int(cfg.step_deviation_interval_seconds),
gcp_options=gcp_options,
)

self._monitor.start_goodput_uploader(*args, **kwargs)
logging.info("Started Goodput upload to Tensorboard & GCM in the background!")
if include_step_deviation:
self._monitor.start_step_deviation_uploader(*args, **kwargs)
yield
finally:
if self._monitor:
self._monitor.stop_goodput_uploader()
logging.info("Flushed final metrics and safe exited from Goodput monitoring.")

@contextlib.contextmanager
def _maybe_monitor_rolling_window_goodput(self):
"""Monitor rolling window goodput if enabled."""
if not self.config.rolling_window_size or jax.process_index() != 0:
yield
return
try:
if self._rolling_window_monitor is None:
rolling_window_tensorboard_dir = os.path.join(
self.config.upload_dir, f"rolling_window_{self.config.name}"
)
self._rolling_window_monitor = goodput_monitoring.GoodputMonitor(
job_name=self._job_name,
logger_name=self._logger_name,
tensorboard_dir=rolling_window_tensorboard_dir,
upload_interval=self.config.upload_interval,
monitoring_enabled=True,
pathway_enabled=self.config.jax_backend == "proxy",
include_badput_breakdown=True,
)
self._rolling_window_monitor.start_rolling_window_goodput_uploader(
self.config.rolling_window_size
)
logging.info("Started Rolling Window Goodput monitoring in the background!")
yield
finally:
if self._rolling_window_monitor:
self._rolling_window_monitor.stop_rolling_window_goodput_uploader()
logging.info(
"Started Step Deviation upload to Tensorboard & GCM in the background!"
"Flushed final metrics and safe exited from Rolling Window Goodput monitoring."
)

def maybe_monitor_all_goodput(self):
    """Returns a context manager combining cumulative and rolling-window monitoring.

    Both underlying context managers are created immediately, but neither is
    entered until the returned context manager is entered. On entry the
    cumulative-goodput monitor is started first, then the rolling-window
    monitor; on exit they are torn down in reverse order (rolling-window
    first), matching the semantics of a nested `with a, b:` statement.
    """
    cumulative_cm = self._maybe_monitor_goodput()
    rolling_window_cm = self._maybe_monitor_rolling_window_goodput()

    @contextlib.contextmanager
    def _monitor_all():
        # ExitStack enters the managers in order and unwinds them LIFO,
        # which is exactly what `with cumulative_cm, rolling_window_cm:` does.
        with contextlib.ExitStack() as stack:
            stack.enter_context(cumulative_cm)
            stack.enter_context(rolling_window_cm)
            yield

    return _monitor_all()
Loading