diff --git a/Dockerfile b/Dockerfile index d1cd4dbe6..18c8ebf9f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -2,6 +2,7 @@ ARG TARGET=base ARG BASE_IMAGE=ubuntu:22.04 +ARG BASE_IMAGE_COLOCATED=us-docker.pkg.dev/cloud-tpu-v2-images/pathways-colocated-python/sidecar:2025_10_29-python_3.10-jax_0.6.2 FROM ${BASE_IMAGE} AS base @@ -102,11 +103,39 @@ COPY pyproject.toml README.md /root/ RUN uv pip install -qq --prerelease=allow .[core,tpu] && uv cache clean RUN if [ -n "$EXTRAS" ]; then uv pip install -qq .[$EXTRAS] && uv cache clean; fi RUN if [ "$INSTALL_PATHWAYS_JAXLIB" = "true" ]; then \ - uv pip install --prerelease=allow "jaxlib==0.5.3.dev20250918" \ + uv pip install --prerelease=allow "jaxlib==0.6.2.dev20251021" \ --find-links https://storage.googleapis.com/axlearn-wheels/wheels.html; \ fi COPY . . +################################################################################ +# Colocated Python container spec. # +################################################################################ + +FROM ${BASE_IMAGE_COLOCATED} AS colocated-python + +WORKDIR /app +COPY . . + +# Install the additional user-provided dependencies, strictly enforcing the rules +# from the base image's constraints file. +RUN \ + # 1. Install user-provided dependencies with modified constraints + grep -v "^numpy" /opt/venv/server_constraints.txt | grep -v "^scipy" > /tmp/modified_constraints.txt && \ + echo "--> Installing user-provided dependencies..." && \ + uv pip install ".[core,gcp]" -c /tmp/modified_constraints.txt && \ + \ + # 2. Override numpy and scipy with specific versions + uv pip install numpy==2.1.1 scipy==1.15.3 && \ + \ + # 3. Verify that the colocated_python_cpu_client is present. + echo "--> Verifying JAX patch integrity..." && \ + python -c "from jax._src.lib import _jax; _jax.colocated_python_cpu_client" && \ + echo "--> JAX patch verification successful." && \ + \ + # 4. Clean the cache to keep the image slim. + uv cache clean + ################################################################################ # GPU container spec. # ################################################################################ diff --git a/axlearn/cloud/gcp/bundler.py b/axlearn/cloud/gcp/bundler.py index 8070f3257..c2e2dd7aa 100644 --- a/axlearn/cloud/gcp/bundler.py +++ b/axlearn/cloud/gcp/bundler.py @@ -60,6 +60,7 @@ from axlearn.cloud.common.utils import canonicalize_to_list, to_bool from axlearn.cloud.gcp.cloud_build import wait_for_cloud_build from axlearn.cloud.gcp.config import gcp_settings +from axlearn.cloud.gcp.pathways_utils import _COLOCATED_PYTHON_SIDECAR_NAME from axlearn.cloud.gcp.utils import common_flags from axlearn.common.config import REQUIRED, Required, config_class, maybe_set_config @@ -98,20 +99,46 @@ class ArtifactRegistryBundler(DockerBundler): TYPE = "artifactregistry" + @config_class + class Config(DockerBundler.Config): + """Configures ArtifactRegistryBundler. + + Attributes: + enable_colocated_python: Applicable only to Pathways jobs. Whether to build a Colocated + Python sidecar image alongside the main image. The sidecar image name will be + "{main_image_name}-colocated-sidecar". 
+ """ + + enable_colocated_python: bool = False + @classmethod def from_spec(cls, spec: list[str], *, fv: Optional[flags.FlagValues]) -> DockerBundler.Config: - cfg = super().from_spec(spec, fv=fv) + cfg: ArtifactRegistryBundler.Config = super().from_spec(spec, fv=fv) cfg.repo = cfg.repo or gcp_settings("docker_repo", required=False, fv=fv) cfg.dockerfile = cfg.dockerfile or gcp_settings("default_dockerfile", required=False, fv=fv) + cfg.enable_colocated_python = cfg.enable_colocated_python or gcp_settings( + "enable_colocated_python", required=False, fv=fv + ) return cfg - def _build_and_push(self, *args, **kwargs): + def _build_and_push(self, *args, image: str, **kwargs): cfg = self.config subprocess.run( ["gcloud", "auth", "configure-docker", registry_from_repo(cfg.repo)], check=True, ) - return super()._build_and_push(*args, **kwargs) + + if cfg.enable_colocated_python: + # Build Colocated Python sidecar image + _, tag = image.rsplit(":", maxsplit=1) + colocated_bundler = cfg.set( + image=_COLOCATED_PYTHON_SIDECAR_NAME, + target="colocated-python", + enable_colocated_python=False, + ).instantiate() + colocated_bundler.bundle(tag=tag) + + return super()._build_and_push(*args, image=image, **kwargs) @register_bundler @@ -129,6 +156,9 @@ class Config(BaseDockerBundler.Config): from flags. is_async: Whether to build asynchronously. If True, callers should invoke `wait_until_finished()` to wait for bundling to complete. + enable_colocated_python: Applicable only to Pathways jobs. Whether to build a Colocated + Python sidecar image alongside the main image. The sidecar image name will be + "{main_image_name}-colocated-sidecar". """ # GCP project. @@ -138,6 +168,7 @@ class Config(BaseDockerBundler.Config): # If provided, should be the identifier of a private worker pool. # See: https://cloud.google.com/build/docs/private-pools/private-pools-overview private_worker_pool: Optional[str] = None + enable_colocated_python: bool = False @classmethod def from_spec( @@ -148,6 +179,9 @@ def from_spec( cfg.repo = cfg.repo or gcp_settings("docker_repo", required=False, fv=fv) cfg.dockerfile = cfg.dockerfile or gcp_settings("default_dockerfile", required=False, fv=fv) cfg.is_async = to_bool(cfg.is_async) + cfg.enable_colocated_python = cfg.enable_colocated_python or gcp_settings( + "enable_colocated_python", required=False, fv=fv + ) return cfg # pylint: disable-next=no-self-use,unused-argument @@ -175,9 +209,14 @@ def _build_and_push( ) image_path, image_tag = image.rsplit(":", maxsplit=1) latest_tag = f"{image_path}:latest" - cloudbuild_yaml = f""" -steps: -- name: "gcr.io/cloud-builders/docker" + + # Build steps - start with main image + build_steps = [] + images_list = [f'"{image}"', f'"{latest_tag}"'] + + # Main image build step + build_steps.append( + f"""- name: "gcr.io/cloud-builders/docker" args: [ "build", "-f", "{os.path.relpath(dockerfile, context)}", @@ -193,11 +232,43 @@ def _build_and_push( "." 
] env: - - "DOCKER_BUILDKIT=1" + - "DOCKER_BUILDKIT=1\"""" + ) + + # Add colocated image build step if required + if cfg.enable_colocated_python: + colocated_image_path = f"{cfg.repo}/{_COLOCATED_PYTHON_SIDECAR_NAME}" + colocated_image = f"{colocated_image_path}:{image_tag}" + colocated_latest_image = f"{colocated_image_path}:latest" + + build_steps.append( + f"""- name: "gcr.io/cloud-builders/docker" + args: [ + "build", + "-f", "{os.path.relpath(dockerfile, context)}", + "-t", "{colocated_image}", + "-t", "{colocated_latest_image}", + "--target", "colocated-python", + "--cache-from", "{colocated_image}", + "--cache-from", "{colocated_latest_image}", + {cache_from} + {build_platform} + {build_args} + {labels} + "." + ] + env: + - "DOCKER_BUILDKIT=1\"""" + ) + + images_list.extend([f'"{colocated_image}"', f'"{colocated_latest_image}"']) + + cloudbuild_yaml = f""" +steps: +{chr(10).join(build_steps)} timeout: 3600s images: -- "{image}" -- "{latest_tag}" +{chr(10).join([f"- {img}" for img in images_list])} tags: [{image_tag}] options: logging: CLOUD_LOGGING_ONLY diff --git a/axlearn/cloud/gcp/examples/colocated_python_benchmark.py b/axlearn/cloud/gcp/examples/colocated_python_benchmark.py new file mode 100644 index 000000000..f7bd6840c --- /dev/null +++ b/axlearn/cloud/gcp/examples/colocated_python_benchmark.py @@ -0,0 +1,365 @@ +#!/usr/bin/env python3 +""" +Standalone script to preload a model from GCS using Colocated Python. + +This script reads the checkpoint index to determine the model structure and creates +appropriate TensorSpec objects for preloading. + +Usage: + python colocated_python_benchmark.py --ckpt_path gs://your-bucket/path/to/checkpoint +""" + +import argparse +import asyncio +import functools +import os +import sys +import time +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Dict, Sequence + +import jax +import jax.numpy as jnp +import pathwaysutils +from jax._src.mesh import thread_resources +from jax.experimental import colocated_python, mesh_utils +from jax.experimental.array_serialization import serialization as array_serialization +from jax.experimental.array_serialization import tensorstore_impl + +from axlearn.common import utils +from axlearn.common.array_serialization import _async_deserialize +from axlearn.common.checkpointer import parse_step_from_dir, read_index_file +from axlearn.common.utils import TensorSpec, infer_mesh_shape +import logging + + +def _colocated_deserialize( + shardings: Sequence[jax.sharding.NamedSharding], + tensorstore_specs: Sequence[Dict[str, Any]], + global_shapes: Sequence[tuple], + dtypes: Sequence[jnp.dtype], +): + # concurrent_bytes = 1099511627776 + concurrent_bytes = 34359738368 * 6 # multiple of 32GB + cpu_devices = colocated_python.colocated_cpu_devices(jax.devices()) + print(f"{cpu_devices=}") + + if len(cpu_devices) > 1: + print(f"TPU Mesh: {thread_resources.env.physical_mesh}") + cpu_mesh = colocated_python.colocated_cpu_devices(thread_resources.env.physical_mesh) + print(f"CPU Mesh: {cpu_mesh}") + cpu_shardings = [ + jax.sharding.NamedSharding(cpu_mesh, sharding.spec) for sharding in shardings + ] + else: + cpu_shardings = [ + jax.sharding.SingleDeviceSharding(cpu_devices[0]) for sharding in shardings + ] + + def output_spec_fn(): + return [ + jax.ShapeDtypeStruct(shape=shape, dtype=dtype, sharding=sharding) + for shape, dtype, sharding in zip(global_shapes, dtypes, cpu_shardings) + ] + + @colocated_python.colocated_python + def run_deserializer(): + # Object should be created once per process. 
+        # pylint: disable=protected-access
+        # print("Print statement inside colocated")
+        logging.info("Logging statement inside colocated")
+        sys.stderr.write("Stderr statement in colocated\n")
+        start_colocated_time = time.perf_counter()
+        byte_limiter = tensorstore_impl._LimitInFlightBytes(concurrent_bytes)
+        h2d_limiter = tensorstore_impl._LimitInFlightBytes(concurrent_bytes)
+        thread_pool = ThreadPoolExecutor(1)
+        multi_thread_pool = ThreadPoolExecutor(2)
+
+        future_arrays = jax.tree.map(
+            functools.partial(
+                _async_deserialize,
+                byte_limiter=byte_limiter,
+                h2d_limiter=h2d_limiter,
+                single_thread_pool=thread_pool,
+                multi_thread_pool=multi_thread_pool,
+            ),
+            cpu_shardings,
+            tensorstore_specs,
+            global_shapes,
+            dtypes,
+        )
+
+        async def gather_func():
+            return await asyncio.gather(*future_arrays)
+
+        result = asyncio.run(gather_func())
+        logging.info(f"Deserialize took {time.perf_counter() - start_colocated_time:.2f} seconds")
+        return result
+
+    run_deserializer = run_deserializer.specialize(
+        devices=cpu_devices,
+        out_specs_fn=output_spec_fn,
+    )
+
+    # Invoking the specialized function runs `run_deserializer` on the colocated CPU devices.
+    result = run_deserializer()
+    return result
+
+
+def create_mesh(mesh_shape=(1, 1, 1, 1, 1, 1, -1)):
+    """Create a JAX mesh for distributed computation."""
+    inferred_mesh_shape = infer_mesh_shape(mesh_shape)
+    print(f"Using mesh shape {inferred_mesh_shape} for {len(jax.devices())} devices")
+    devices = mesh_utils.create_device_mesh(inferred_mesh_shape)
+    return jax.sharding.Mesh(devices, ("pipeline", "data", "expert", "fsdp", "seq", "track", "model"))
+
+
+def create_state_spec_from_checkpoint(ckpt_path: str):
+    """Create a NestedTensorSpec from checkpoint index information."""
+    index = read_index_file(ckpt_path)
+    print(f"Read checkpoint index with {len(index)} entries")
+
+    state_spec = {}
+
+    for path, value in index:
+        if path == "step":
+            continue
+
+        # Filter out learner state
+        if is_learner_path(path):
+            continue
+
+        if isinstance(value, dict) and "shape" in value and "dtype" in value:
+            # pylint: disable=eval-used
+            shape = eval(value["shape"]) if isinstance(value["shape"], str) else value["shape"]
+            dtype_str = value["dtype"]
+
+            # Convert dtype string to jax dtype
+            dtype = getattr(jnp, dtype_str, jnp.float32)
+            if dtype == jnp.float32:
+                dtype = jnp.bfloat16
+
+            # Create nested dict structure from path
+            keys = path.split("/")
+            current = state_spec
+            for key in keys[:-1]:
+                if key not in current:
+                    current[key] = {}
+                current = current[key]
+
+            current[keys[-1]] = TensorSpec(shape=shape, dtype=dtype)
+
+    return state_spec
+
+
+def is_learner_path(path: str) -> bool:
+    """Check if a path is part of the learner state."""
+    # Exclude all learner paths (optimizer state, ema, etc.)
+ return path.startswith("learner/") + + +def get_inference_partition_spec(path: str, shape: tuple) -> jax.sharding.PartitionSpec: + """Get inference-friendly partition spec based on tensor path and shape.""" + if "track" in path: + return jax.sharding.PartitionSpec(None, "track", "model") + + return jax.sharding.PartitionSpec() + + + +def create_checkpoint_spec_from_state(ckpt_dir: str, state_spec: dict): + """Create checkpoint spec following the pattern from TensorStoreStateStorage._get_spec.""" + + tensorstore_specs = [] + shapes = [] + dtypes = [] + shardings = [] + + # Get current mesh for creating shardings + mesh = thread_resources.env.physical_mesh + if not mesh.shape: + raise RuntimeError("Checkpoint restoration must take place within the context of a Mesh") + + # Process each tensor in the state spec + for path, value in utils.flatten_items(state_spec, separator="/"): + if isinstance(value, TensorSpec): + # Get dtype + dtype = getattr(value.dtype, "dtype", value.dtype) + + # Create storage path and tensorstore spec + gda_path = os.path.join(ckpt_dir, "gda", path) + tensorstore_spec = array_serialization.get_tensorstore_spec(gda_path) + + # Get inference-friendly partition spec based on tensor path and shape + partition_spec = get_inference_partition_spec(path, value.shape) + + # Create sharding with the appropriate partition spec + sharding = jax.sharding.NamedSharding(mesh, partition_spec) + + tensorstore_specs.append(tensorstore_spec) + shapes.append(value.shape) + dtypes.append(dtype) + shardings.append(sharding) + + return tensorstore_specs, shardings, shapes, dtypes + + +def _default_deserialize( + shardings: Sequence[jax.sharding.NamedSharding], + tensorstore_specs: Sequence[Dict[str, Any]], + global_shapes: Sequence[tuple], + dtypes: Sequence[jnp.dtype], +): + # concurrent_bytes = 1099511627776 + concurrent_bytes = 34359738368 * 6 # multiple of 32GB + # Object should be created once per process. 
+ # pylint: disable=protected-access + byte_limiter = tensorstore_impl._LimitInFlightBytes(concurrent_bytes) + h2d_limiter = tensorstore_impl._LimitInFlightBytes(34359738368) + thread_pool = ThreadPoolExecutor(1) + multi_thread_pool = ThreadPoolExecutor(2) + + future_arrays = jax.tree.map( + functools.partial( + _async_deserialize, + byte_limiter=byte_limiter, + h2d_limiter=h2d_limiter, + single_thread_pool=thread_pool, + multi_thread_pool=multi_thread_pool, + ), + shardings, + tensorstore_specs, + global_shapes, + dtypes, + ) + + async def gather_func(): + return await asyncio.gather(*future_arrays) + result = asyncio.run(gather_func()) + return result + + +def load_model_default(ckpt_path: str): + """Main function to preload a model from GCS checkpoint.""" + step = parse_step_from_dir(ckpt_path) + print(f"Starting model preload from: {ckpt_path} (step {step})") + + if not ckpt_path.startswith("gs://"): + raise ValueError(f"Only GCS paths (gs://) are supported, got: {ckpt_path}") + + with create_mesh(): + print("Reading checkpoint structure...") + state_spec = create_state_spec_from_checkpoint(ckpt_path) + + print(f"Found {len(jax.tree_util.tree_leaves(state_spec))} tensors in checkpoint") + + tensorstore_specs, shardings, shapes, dtypes = create_checkpoint_spec_from_state( + ckpt_path, state_spec + ) + + print("Preloading checkpoint to TPU memory...") + start_time = time.perf_counter() + + restored_values = _default_deserialize( + shardings=shardings, + tensorstore_specs=tensorstore_specs, + global_shapes=shapes, + dtypes=dtypes, + ) + + preload_time = time.perf_counter() - start_time + print(f"Preload completed in {preload_time:.2f} seconds") + print(f"Preloaded {len(restored_values)} arrays") + + return restored_values + + +def load_model_colocated(ckpt_path: str): + """Main function to preload a model from GCS checkpoint.""" + step = parse_step_from_dir(ckpt_path) + print(f"Starting model preload from: {ckpt_path} (step {step})") + + if not ckpt_path.startswith("gs://"): + raise ValueError(f"Only GCS paths (gs://) are supported, got: {ckpt_path}") + + with create_mesh(): + print("Reading checkpoint structure...") + state_spec = create_state_spec_from_checkpoint(ckpt_path) + + print(f"Found {len(jax.tree_util.tree_leaves(state_spec))} tensors in checkpoint") + + tensorstore_specs, shardings, shapes, dtypes = create_checkpoint_spec_from_state( + ckpt_path, state_spec + ) + + print("Preloading checkpoint to CPU memory...") + start_time = time.perf_counter() + + preloaded_values = _colocated_deserialize( + shardings=shardings, + tensorstore_specs=tensorstore_specs, + global_shapes=shapes, + dtypes=dtypes, + ) + # for x in preloaded_values: + # x.block_until_ready() + + preload_time = time.perf_counter() - start_time + print(f"Preload completed in {preload_time:.2f} seconds") + print(f"Preloaded {len(preloaded_values)} arrays") + + print("Transferring arrays to TPU...") + start_time = time.perf_counter() + + restored_values = [jax.device_put(x, s) for x, s in zip(preloaded_values, shardings)] + for x in restored_values: + x.block_until_ready() + + transfer_time = time.perf_counter() - start_time + print(f"Transfer completed in {transfer_time:.2f} seconds") + + return restored_values + + +def main(): + parser = argparse.ArgumentParser(description="Preload model from GCS checkpoint") + parser.add_argument( + "--ckpt_path", + required=True, + help="GCS path to checkpoint directory (e.g., gs://bucket/path/to/checkpoint)", + ) + args = parser.parse_args() + + if os.getenv("JAX_PLATFORMS") == 
"proxy": + pathwaysutils.initialize() + else: + jax.distributed.initialize() + + print(f"JAX devices: {jax.devices()}") + + print("--- Running colocated benchmark ---") + # Extract profile dir from ckpt_path. The profile dir should be gs://bucket/profiles/ + hostname = os.uname().nodename + profile_dir = f"{args.ckpt_path.split('/checkpoints')[0]}/profiles/{hostname}/colocated-test/" + jax.profiler.start_trace(log_dir=profile_dir) + start_colocated_time = time.perf_counter() + loaded_values_colocated = load_model_colocated(ckpt_path=args.ckpt_path) + print(f"✅ Successfully loaded model from {args.ckpt_path}") + print(f"Deserialize took {time.perf_counter() - start_colocated_time:.2f} seconds") + print(f" Total parameters: {sum(x.size for x in loaded_values_colocated):,}") + jax.profiler.stop_trace() + + # Exit early if on pathways + if os.getenv("JAX_PLATFORMS") == "proxy": + sys.exit(0) + + print("\n--- Running default benchmark ---") + start_default_time = time.perf_counter() + loaded_values_default = load_model_default(ckpt_path=args.ckpt_path) + print(f"✅ Successfully loaded model from {args.ckpt_path}") + print(f"Deserialize took {time.perf_counter() - start_default_time:.2f} seconds") + print(f" Total parameters: {sum(x.size for x in loaded_values_default):,}") + + +if __name__ == "__main__": + main() diff --git a/axlearn/cloud/gcp/pathways_utils.py b/axlearn/cloud/gcp/pathways_utils.py index f21afebfe..d311367bd 100644 --- a/axlearn/cloud/gcp/pathways_utils.py +++ b/axlearn/cloud/gcp/pathways_utils.py @@ -17,6 +17,7 @@ _METADATA_GOOGLE_INTERNAL_IP, BASTION_JOB_VERSION_LABEL, BaseReplicatedJob, + FlagConfigurable, TPUReplicatedJob, _LoadBalancer, ) @@ -45,6 +46,7 @@ # The port used by pathways worker server. # The specific value is not important, as long as clients and servers use the same port. _PATHWAYS_WORKER_PORT = 29001 +_COLOCATED_CONTAINER_PORT = 50051 # Pin to specific pathways image version for stable release. # There is no guarantee that this image will work with newer Jax releases. # This image version extends GRPC timeout for long context models, based on jax-0.5.3-patch060625 @@ -58,6 +60,20 @@ _PATHWAYS_SERVER_IMAGE = ( f"us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:{_PATHWAYS_IMAGE_TAG}" ) + +# For now, we use different Pathways images if Colocated Python is enabled. +_PATHWAYS_COLOCATED_IMAGE_TAG = "2025-10-29" +# The docker image used by pathways proxy container. +_PATHWAYS_COLOCATED_PROXY_IMAGE = ( + "us-docker.pkg.dev/cloud-tpu-v2-images/pathways-colocated-python/proxy_server:" + f"{_PATHWAYS_COLOCATED_IMAGE_TAG}-increased-grpc-timeout" +) +# The docker image used by pathways resource manager container and worker container. +_PATHWAYS_COLOCATED_SERVER_IMAGE = ( + "us-docker.pkg.dev/cloud-tpu-v2-images/pathways-colocated-python/server:" + f"{_PATHWAYS_COLOCATED_IMAGE_TAG}" +) + # The container name of pathways resourcemanager. _PATHWAYS_RESOURCE_MANAGER_CONTAINER_NAME = "pathways-rm" # The container name of pathways proxy. @@ -67,6 +83,8 @@ # The k8s replicatedJob name for pathways-worker pods. _PATHWAYS_WORKER_REPLICATED_JOB_NAME = "pwwk" +_COLOCATED_PYTHON_SIDECAR_NAME = "colocated-python-sidecar" + # Add node-selector for cpu workload to avoid sharing nodes with system services. 
_PATHWAYS_HEAD_NODE_POOL_SELECTOR_KEY = "axlearn/nodepool_type" _PATHWAYS_HEAD_NODE_POOL_SELECTOR_VALUE = "workload" @@ -76,6 +94,12 @@ _PATHWAYS_BACK_OFF_LIMIT = 32 +def get_colocated_python_image(image_id: str) -> str: + path, tag = image_id.rsplit(":", maxsplit=1) + repo, _ = path.rsplit("/", maxsplit=1) + return f"{repo}/{_COLOCATED_PYTHON_SIDECAR_NAME}:{tag}" + + def parse_xla_flag_value(value: str) -> Union[int, bool, str]: """Attempts to convert an XLA flag string value to int. @@ -158,6 +182,80 @@ def round_up_to_power_of_2(n): return 1 << (n - 1).bit_length() +class PathwaysColocatedPythonPlugin(FlagConfigurable): + """Functionality for Pathways jobs with Colocated Python support.""" + + @config_class + class Config(FlagConfigurable.Config): + """Configures PathwaysColocatedPythonPlugin. + + Attributes: + pathways_proxy_image: The Pathways proxy image. + pathways_server_image: The Pathways server image. + """ + + pathways_proxy_image: Optional[str] = None + pathways_server_image: Optional[str] = None + + @classmethod + def define_flags(cls, fv): + super().define_flags(fv) + common_kwargs = dict(flag_values=fv, allow_override=True) + flags.DEFINE_string( + "pathways_proxy_image", + None, + "Allows a custom Pathways proxy image to be provided.", + **common_kwargs, + ) + flags.DEFINE_string( + "pathways_server_image", + None, + "Allows a custom Pathways server image to be provided.", + **common_kwargs, + ) + + def __init__(self, cfg: Config, *, bundler: Bundler): + super().__init__(cfg) + self._enable_colocated_python = getattr(bundler.config, "enable_colocated_python", False) + + # pylint: disable-next=no-self-use + def build_colocated_python_container(self, image: str): + """Builds the Colocated Python sidecar container.""" + return dict( + name=_COLOCATED_PYTHON_SIDECAR_NAME, + image=get_colocated_python_image(image), + restartPolicy="Always", + env=[ + { + "name": "GRPC_SERVER_ADDRESS", + "value": f"0.0.0.0:{_COLOCATED_CONTAINER_PORT}", + }, + ], + imagePullPolicy="Always", + ports=[dict(containerPort=_COLOCATED_CONTAINER_PORT)], + ) + + @property + def pathways_proxy_image(self) -> str: + if (custom_proxy_image := self.config.pathways_proxy_image) is not None: + return custom_proxy_image + elif self.is_colocated_python_enabled: + return _PATHWAYS_COLOCATED_PROXY_IMAGE + return _PATHWAYS_PROXY_IMAGE + + @property + def pathways_server_image(self) -> str: + if (custom_server_image := self.config.pathways_server_image) is not None: + return custom_server_image + elif self.is_colocated_python_enabled: + return _PATHWAYS_COLOCATED_SERVER_IMAGE + return _PATHWAYS_SERVER_IMAGE + + @property + def is_colocated_python_enabled(self) -> bool: + return self._enable_colocated_python + + class PathwaysReplicatedJob(BaseReplicatedJob): """Builds a replicated jobspec for Pathways on TPU, to be used with JobSet API.""" @@ -169,6 +267,7 @@ class Config(BaseReplicatedJob.Config): inner: The wrapped TPUReplicatedJob configuration. pathways_head_cpu: CPU request for pathways-head container. pathways_head_mem: Memory request for pathways-head container. + colocated_python: Configuration for Colocated Python. 
""" inner: Required[TPUReplicatedJob.Config] = REQUIRED @@ -176,6 +275,8 @@ class Config(BaseReplicatedJob.Config): pathways_head_cpu: Optional[str] = None pathways_head_mem: Optional[str] = None + colocated_python: Required[PathwaysColocatedPythonPlugin.Config] = REQUIRED + @classmethod def define_flags(cls, fv): super().define_flags(fv) @@ -213,13 +314,15 @@ def set_defaults(cls, fv): @classmethod def default_config(cls): cfg = super().default_config() - return cfg.set(inner=TPUReplicatedJob.default_config()) + return cfg.set( + inner=TPUReplicatedJob.default_config(), + colocated_python=PathwaysColocatedPythonPlugin.default_config(), + ) - def __init__(self, cfg: BaseReplicatedJob.Config, *, bundler: Bundler): + def __init__(self, cfg: Config, *, bundler: Bundler): super().__init__(cfg, bundler=bundler) self._bundler = bundler self._inner: TPUReplicatedJob = cfg.inner.instantiate(bundler=self._bundler) - pathways_cfg: PathwaysReplicatedJob.Config = self.config self._tpu_type = infer_tpu_type(cfg.inner.accelerator.instance_type) if self._tpu_type not in USER_FACING_NAME_TO_SYSTEM_CHARACTERISTICS: raise NotImplementedError(f"Missing system characteristics for {self._tpu_type}") @@ -231,7 +334,7 @@ def __init__(self, cfg: BaseReplicatedJob.Config, *, bundler: Bundler): num_slices=cfg.inner.accelerator.num_replicas, backend="tpu", ) - pathways_xla_flags = parse_kv_flags(pathways_cfg.pathways_xla_flags, delimiter="=") + pathways_xla_flags = parse_kv_flags(cfg.pathways_xla_flags, delimiter="=") for k, v in pathways_xla_flags.items(): k = k.lstrip("--") v = parse_xla_flag_value(v) @@ -257,6 +360,8 @@ def __init__(self, cfg: BaseReplicatedJob.Config, *, bundler: Bundler): job_name=_PATHWAYS_WORKER_REPLICATED_JOB_NAME, ) + self._colocated_python = cfg.colocated_python.instantiate(bundler=bundler) + def _update_env_list(self, env_list: list[dict], name: str, value: str): for env in env_list: if env.get("name") == name: @@ -352,6 +457,8 @@ def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]: f"--server_port={_PATHWAYS_PROXY_PORT}", f"--gcs_scratch_location={staging_location}", ] + if self._colocated_python.is_colocated_python_enabled: + cmd_args.append("--sidecar_name=external") cmd_args.extend(xla_flags_from_options(self._xla_options).split()) instance_type = f"{pathways_tpu_version}:{system.topology}" @@ -360,7 +467,7 @@ def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]: return [ dict( name=_PATHWAYS_PROXY_CONTAINER_NAME, - image=_PATHWAYS_PROXY_IMAGE, + image=self._colocated_python.pathways_proxy_image, # https://kubernetes.io/docs/concepts/workloads/pods/sidecar-containers/#pod-sidecar-containers # SideCar container is an init container with restartPolicy as "Always". restartPolicy="Always", @@ -383,7 +490,7 @@ def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]: ), dict( name=_PATHWAYS_RESOURCE_MANAGER_CONTAINER_NAME, - image=_PATHWAYS_SERVER_IMAGE, + image=self._colocated_python.pathways_server_image, # https://kubernetes.io/docs/concepts/workloads/pods/sidecar-containers/#pod-sidecar-containers # SideCar container is an init container with restartPolicy as "Always". 
restartPolicy="Always", @@ -559,7 +666,7 @@ def _build_pathways_worker_container( mega_scale_args = xla_flags_from_options(self._mxla_options).split() worker_container["args"].extend(mega_scale_args) - worker_container["image"] = _PATHWAYS_SERVER_IMAGE + worker_container["image"] = self._colocated_python.pathways_server_image ports = worker_container.get("ports", []) ports.append({"containerPort": _PATHWAYS_WORKER_PORT}) @@ -592,6 +699,11 @@ def _build_pathways_worker_pod( pod_spec["containers"] = [ self._build_pathways_worker_container(pathways_worker_replicated_job_index) ] + if self._colocated_python.is_colocated_python_enabled: + image = cfg.image_id or self._bundler.id(cfg.name) + pod_spec["initContainers"].append( + self._colocated_python.build_colocated_python_container(image) + ) worker_pod["spec"] = pod_spec # Service account for nodes. @@ -756,6 +868,7 @@ class Config(BaseLeaderWorkerTemplate.Config): inner: The wrapped TPUReplicatedJob configuration. pathways_head_cpu: CPU request for pathways-head container. pathways_head_mem: Memory request for pathways-head container. + colocated_python: Configuration for Colocated Python. """ inner: Required[TPULeaderWorkerTemplate.Config] = REQUIRED @@ -766,6 +879,8 @@ class Config(BaseLeaderWorkerTemplate.Config): target_port: Optional[int] = None enable_service: bool = None + colocated_python: Required[PathwaysColocatedPythonPlugin.Config] = REQUIRED + @classmethod def define_flags(cls, fv): super().define_flags(fv) @@ -817,9 +932,12 @@ def set_defaults(cls, fv): @classmethod def default_config(cls): cfg = super().default_config() - return cfg.set(inner=TPULeaderWorkerTemplate.default_config()) + return cfg.set( + inner=TPULeaderWorkerTemplate.default_config(), + colocated_python=PathwaysColocatedPythonPlugin.default_config(), + ) - def __init__(self, cfg: BaseLeaderWorkerTemplate.Config, *, bundler): + def __init__(self, cfg: Config, *, bundler): super().__init__(cfg, bundler=bundler) cfg: PathwaysLeaderWorkerTemplate.Config = self.config @@ -829,6 +947,8 @@ def __init__(self, cfg: BaseLeaderWorkerTemplate.Config, *, bundler): if self._tpu_type not in USER_FACING_NAME_TO_SYSTEM_CHARACTERISTICS: raise NotImplementedError(f"Missing system characteristics for {self._tpu_type}") + self._colocated_python = cfg.colocated_python.instantiate(bundler=bundler) + def _build_pathways_worker_container(self) -> dict: cfg: TPULeaderWorkerTemplate.Config = self.config # pylint: disable-next=protected-access @@ -845,7 +965,7 @@ def _build_pathways_worker_container(self) -> dict: ports = worker_container.get("ports", []) ports.append({"containerPort": _PATHWAYS_WORKER_PORT}) worker_container["ports"] = ports - worker_container["image"] = _PATHWAYS_SERVER_IMAGE + worker_container["image"] = self._colocated_python.pathways_server_image worker_container.pop("command") return worker_container @@ -862,6 +982,11 @@ def build_worker_pod(self) -> dict: pod_spec["HostNetwork"] = True pod_spec["dnsPolicy"] = "ClusterFirstWithHostNet" pod_spec["containers"] = [self._build_pathways_worker_container()] + if self._colocated_python.is_colocated_python_enabled: + image = cfg.image_id or self._bundler.id(cfg.name) + pod_spec["initContainers"].append( + self._colocated_python.build_colocated_python_container(image) + ) worker_pod["spec"] = pod_spec # Service account for nodes. 
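# A minimal usage sketch of `get_colocated_python_image` (defined earlier in this diff), which the
# sidecar init container above relies on: the helper keeps the repository and tag of the main bundle
# image and swaps the final path component for "colocated-python-sidecar". The repo and tag below
# are hypothetical placeholders, not values from this change.
from axlearn.cloud.gcp.pathways_utils import get_colocated_python_image

main_image = "us-docker.pkg.dev/my-project/my-repo/axlearn:abc123"  # Hypothetical bundle image.
assert (
    get_colocated_python_image(main_image)
    == "us-docker.pkg.dev/my-project/my-repo/colocated-python-sidecar:abc123"
)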
@@ -881,15 +1006,18 @@ def build_worker_pod(self) -> dict: def _build_pathways_proxy_container(self) -> dict: cfg: TPULeaderWorkerTemplate.Config = self._inner.config staging_location = f"{cfg.output_dir}/pathways-staging" + cmd_args = [ + f"--resource_manager_address=localhost:{_PATHWAYS_RESOURCE_MANAGER_PORT}", + f"--server_port={_PATHWAYS_PROXY_PORT}", + f"--gcs_scratch_location={staging_location}", + ] + if self._colocated_python.is_colocated_python_enabled: + cmd_args.append("--sidecar_name=external") return dict( name=_PATHWAYS_PROXY_CONTAINER_NAME, - image=_PATHWAYS_PROXY_IMAGE, - args=[ - f"--resource_manager_address=localhost:{_PATHWAYS_RESOURCE_MANAGER_PORT}", - f"--server_port={_PATHWAYS_PROXY_PORT}", - f"--gcs_scratch_location={staging_location}", - ], + image=self._colocated_python.pathways_proxy_image, + args=cmd_args, env=[{"name": "IFRT_PROXY_USE_INSECURE_GRPC_CREDENTIALS", "value": "true"}], ports=[dict(containerPort=_PATHWAYS_PROXY_PORT)], ) @@ -903,7 +1031,7 @@ def _build_pathways_rm_container(self) -> dict: return dict( name=_PATHWAYS_RESOURCE_MANAGER_CONTAINER_NAME, - image=_PATHWAYS_SERVER_IMAGE, + image=self._colocated_python.pathways_server_image, env=[ { "name": "TPU_SKIP_MDS_QUERY", diff --git a/axlearn/cloud/gcp/pathways_utils_test.py b/axlearn/cloud/gcp/pathways_utils_test.py index 41c11d59f..b5e358bd7 100644 --- a/axlearn/cloud/gcp/pathways_utils_test.py +++ b/axlearn/cloud/gcp/pathways_utils_test.py @@ -11,6 +11,7 @@ from axlearn.cloud.gcp import bundler, jobset_utils, lws_utils, pathways_utils from axlearn.cloud.gcp.bundler import CloudBuildBundler from axlearn.cloud.gcp.pathways_utils import ( + _PATHWAYS_COLOCATED_SERVER_IMAGE, _PATHWAYS_HEAD_NODE_POOL_SELECTOR_KEY, _PATHWAYS_HEAD_NODE_POOL_SELECTOR_VALUE, _PATHWAYS_PROXY_CONTAINER_NAME, @@ -77,8 +78,11 @@ def _job_config(self, bundler_cls: type[Bundler], instance_type: str = "tpu-v5p- bundler_cfg = bundler_cls.from_spec([], fv=fv).set(image="test-image") yield cfg, bundler_cfg - @parameterized.parameters(dict(instance_type="tpu-v5p-16"), dict(instance_type="tpu-v5p-256")) - def test_build_pathways_head_pod(self, instance_type): + @parameterized.product( + instance_type=["tpu-v5p-16", "tpu-v5p-256"], + enable_colocated_python=[False, True], + ) + def test_build_pathways_head_pod(self, instance_type, enable_colocated_python): with ( self._job_config( CloudBuildBundler, @@ -90,7 +94,11 @@ def test_build_pathways_head_pod(self, instance_type): name="test", command="test_command", output_dir="FAKE", - ).instantiate(bundler=bundler_cfg.instantiate()) + ).instantiate( + bundler=bundler_cfg.set( + enable_colocated_python=enable_colocated_python + ).instantiate() + ) builder = cfg.instantiate(bundler=bundler_cfg.instantiate()) # pylint: disable-next=protected-access @@ -149,6 +157,8 @@ def test_build_pathways_head_pod(self, instance_type): for container in pod_spec["initContainers"]: if container["name"] == _PATHWAYS_PROXY_CONTAINER_NAME: proxy_container = container + if enable_colocated_python: + self.assertIn("--sidecar_name=external", container["args"]) if container["name"] == _PATHWAYS_RESOURCE_MANAGER_CONTAINER_NAME: rm_container = container self.assertIsNotNone(proxy_container, "Pathways proxy container not found.") @@ -168,7 +178,10 @@ def test_build_pathways_head_pod(self, instance_type): self.assertIn("--instance_count=1", rm_container["args"]) self.assertIn("--instance_type=tpuv5:4x4x8_untwisted", rm_container["args"]) - def test_build_pathways_worker_pod(self): + @parameterized.product( + 
enable_colocated_python=[False, True], + ) + def test_build_pathways_worker_pod(self, enable_colocated_python): with ( self._job_config( CloudBuildBundler, @@ -180,7 +193,11 @@ def test_build_pathways_worker_pod(self): command="test_command", output_dir="FAKE", service_account="test-service-account", - ).instantiate(bundler=bundler_cfg.instantiate()) + ).instantiate( + bundler=bundler_cfg.set( + enable_colocated_python=enable_colocated_python + ).instantiate() + ) builder = cfg.instantiate(bundler=bundler_cfg.instantiate()) # pylint: disable-next=protected-access @@ -192,7 +209,12 @@ def test_build_pathways_worker_pod(self): self.assertEqual(pod_spec.get("hostNetwork"), True) self.assertEqual(pod_spec.get("dnsPolicy"), "ClusterFirstWithHostNet") worker_container = pod_spec.get("containers")[0] - self.assertEqual(worker_container["image"], _PATHWAYS_SERVER_IMAGE) + server_image = ( + _PATHWAYS_COLOCATED_SERVER_IMAGE + if enable_colocated_python + else _PATHWAYS_SERVER_IMAGE + ) + self.assertEqual(worker_container["image"], server_image) annotations = pod["metadata"]["annotations"] self.assertEqual( "test-service-account@test-project.iam.gserviceaccount.com", @@ -202,6 +224,9 @@ def test_build_pathways_worker_pod(self): # 128GiB self.assertIn("--tpu_premapped_buffer_size=137438953472", worker_container["args"]) + expected_num_init_containers = 2 if enable_colocated_python else 1 + self.assertEqual(expected_num_init_containers, len(pod_spec.get("initContainers", []))) + # Check worker container args for Megascale (MXLA) flags. # pylint: disable-next=protected-access mxla_arg_flags = xla_flags_from_options(builder._mxla_options).split() @@ -477,7 +502,10 @@ def _job_config(self, bundler_cls: type[Bundler], **kwargs): print("debug: cfg: ", type(cfg)) yield cfg, bundler_cfg - def test_build_leader_pod(self): + @parameterized.product( + enable_colocated_python=[False, True], + ) + def test_build_leader_pod(self, enable_colocated_python): with ( self._job_config( CloudBuildBundler, @@ -488,7 +516,11 @@ def test_build_leader_pod(self): name="a" * 36, command="test_command", output_dir="FAKE", - ).instantiate(bundler=bundler_cfg.instantiate()) + ).instantiate( + bundler=bundler_cfg.set( + enable_colocated_python=enable_colocated_python + ).instantiate() + ) builder = cfg.instantiate(bundler=bundler_cfg.instantiate()) pod = builder.build_leader_pod() @@ -496,7 +528,22 @@ def test_build_leader_pod(self): self.assertEqual(len(pod_spec["containers"]), 3) - def test_build_worker_pod(self): + proxy_container = None + rm_container = None + for container in pod_spec["containers"]: + if container["name"] == _PATHWAYS_PROXY_CONTAINER_NAME: + proxy_container = container + if enable_colocated_python: + self.assertIn("--sidecar_name=external", container["args"]) + if container["name"] == _PATHWAYS_RESOURCE_MANAGER_CONTAINER_NAME: + rm_container = container + self.assertIsNotNone(proxy_container, "Pathways proxy container not found.") + self.assertIsNotNone(rm_container, "Pathways rm container not found.") + + @parameterized.product( + enable_colocated_python=[False, True], + ) + def test_build_worker_pod(self, enable_colocated_python): with ( self._job_config( CloudBuildBundler, @@ -507,15 +554,27 @@ def test_build_worker_pod(self): name="a" * 36, command="test_command", output_dir="FAKE", - ).instantiate(bundler=bundler_cfg.instantiate()) + ).instantiate( + bundler=bundler_cfg.set( + enable_colocated_python=enable_colocated_python + ).instantiate() + ) builder = 
cfg.instantiate(bundler=bundler_cfg.instantiate()) pod = builder.build_worker_pod() pod_spec = pod["spec"] container = pod_spec.get("containers")[0] - self.assertEqual(container["image"], _PATHWAYS_SERVER_IMAGE) + server_image = ( + _PATHWAYS_COLOCATED_SERVER_IMAGE + if enable_colocated_python + else _PATHWAYS_SERVER_IMAGE + ) + self.assertEqual(container["image"], server_image) self.assertEqual(len(container["args"]), 3) + expected_num_init_containers = 2 if enable_colocated_python else 1 + self.assertEqual(expected_num_init_containers, len(pod_spec.get("initContainers", []))) + def test_leader_worker_template(self): with ( self._job_config( diff --git a/axlearn/common/array_serialization.py b/axlearn/common/array_serialization.py index 369926e2e..d9988862f 100644 --- a/axlearn/common/array_serialization.py +++ b/axlearn/common/array_serialization.py @@ -307,6 +307,7 @@ async def _async_serialize( ) # pylint: disable=protected-access spec_has_metadata = { + "0.6.2.dev0+selfbuilt": lambda: serialization.ts_impl._spec_has_metadata, "0.6.2": lambda: serialization.ts_impl._spec_has_metadata, "0.5.3": lambda: serialization._spec_has_metadata, }[jax.__version__]() @@ -487,6 +488,7 @@ async def cb(index: array.Index, device: jax.Device): requested_domain = ts.IndexTransform(input_shape=shape)[index].domain restricted_domain = t.domain.intersect(requested_domain) estimate_read_memory_footprint = { + "0.6.2.dev0+selfbuilt": lambda: serialization.ts_impl.estimate_read_memory_footprint, "0.6.2": lambda: serialization.ts_impl.estimate_read_memory_footprint, "0.5.3": lambda: serialization.estimate_read_memory_footprint, }[jax.__version__]() @@ -568,6 +570,7 @@ async def cb(index: array.Index, device: jax.Device): # pylint: disable=protected-access create_async_array_from_callback = { + "0.6.2.dev0+selfbuilt": lambda: serialization.ts_impl._create_async_array_from_callback, "0.6.2": lambda: serialization.ts_impl._create_async_array_from_callback, "0.5.3": lambda: serialization.create_async_array_from_callback, }[jax.__version__]() @@ -653,6 +656,7 @@ def serialize( commit_futures = [[] for _ in range(len(tensorstore_specs))] async_serialize = { + "0.6.2.dev0+selfbuilt": lambda: serialization.ts_impl.async_serialize, "0.6.2": lambda: serialization.ts_impl.async_serialize, "0.5.3": lambda: serialization.async_serialize, }[jax.__version__]()
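# A minimal standalone sketch of the version-dispatch pattern extended above (names here are
# illustrative, not part of the change): helpers are looked up by exact `jax.__version__` string,
# which is why the self-built "0.6.2.dev0+selfbuilt" wheel needs its own entry alongside "0.6.2"
# and "0.5.3"; any other version raises KeyError. The lambdas defer attribute access so that a
# module path that only exists for one jax version does not break the lookup table itself.
# pylint: disable=protected-access
import jax
from jax.experimental.array_serialization import serialization

_SPEC_HAS_METADATA_BY_VERSION = {
    "0.6.2.dev0+selfbuilt": lambda: serialization.ts_impl._spec_has_metadata,
    "0.6.2": lambda: serialization.ts_impl._spec_has_metadata,
    "0.5.3": lambda: serialization._spec_has_metadata,
}


def resolve_spec_has_metadata():
    """Returns the `_spec_has_metadata` helper matching the installed jax version."""
    return _SPEC_HAS_METADATA_BY_VERSION[jax.__version__]()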