Draft
Changes from all commits (51 commits)
331bcd2
add base benchmark scripts
samos123 Sep 11, 2025
2c01b11
remove pathways cpu and memory limit
samos123 Sep 11, 2025
2876881
remove batching
samos123 Sep 11, 2025
4a9f1f8
fix for jax 0.5.3
samos123 Sep 11, 2025
1262bf7
limit amount of max bytes being restored
samos123 Sep 11, 2025
c057ecc
WIP: send all jax device puts in parallel with pathways
samos123 Sep 12, 2025
ca6dd86
also prevent byte limiter
samos123 Sep 12, 2025
6489cf8
Merge branch 'benchmark-deserialize-main' into pathways-array-seriali…
samos123 Sep 12, 2025
2ae8a54
fuji pdbs=1
samos123 Sep 12, 2025
6d84ab0
add logging of device puts
samos123 Sep 12, 2025
cf282c6
add profiling of checkpoint load
samos123 Sep 12, 2025
36351ed
Update fuji 70b mesh for v5e
samos123 Sep 12, 2025
793fb65
bring back the limiters
samos123 Sep 12, 2025
0ab519f
add log to print total restore time
samos123 Sep 12, 2025
eea2db5
try deleting from HBM after device_put
samos123 Sep 12, 2025
0894e2d
save every 100 steps fuji 7b and remove delete
samos123 Sep 12, 2025
ddb6b03
don't save any remats fuji 7b
samos123 Sep 12, 2025
1f87664
7b fsdp=32
samos123 Sep 12, 2025
68bf122
exit after deserialize
samos123 Sep 12, 2025
e30c7ce
concurrent restore 128GB
samos123 Sep 12, 2025
213fa24
print total time before stopping trace
samos123 Sep 12, 2025
d0e45e6
add logging of concurrent_gb
samos123 Sep 12, 2025
4fc79a1
force to 128gb for real this time
samos123 Sep 12, 2025
d71ce4a
add improvements to GCS perf
samos123 Sep 12, 2025
22b8faf
non blocking device put, only block when all device puts are done
samos123 Sep 12, 2025
a129ef3
time the download from GCS
samos123 Sep 12, 2025
73654f2
add scripts to launch fuji within interactive pathways cluster
samos123 Sep 12, 2025
2033d65
add pathways premap buffer 17gb
samos123 Sep 12, 2025
67127b3
Generate unique subdir for profile
samos123 Sep 12, 2025
976dcd3
pathways bump async computations
samos123 Sep 12, 2025
0c5f507
comment out flag that didn't work
samos123 Sep 12, 2025
0c0cd77
proper pathways flag prefix
samos123 Sep 12, 2025
a98c2a2
test pathways flags
samos123 Sep 12, 2025
f8adbcc
make the proxy bench more similar to axlearn
samos123 Sep 13, 2025
d1c38bc
rerun with premap buffer set
samos123 Sep 13, 2025
4cd54d5
mess around with pathways flags
samos123 Sep 13, 2025
9c245b4
pathways head on TPU VM
samos123 Sep 14, 2025
8dc5205
stick to default concurrent restore of 32gb
samos123 Sep 14, 2025
6bae63d
switch to standard blocking device put
samos123 Sep 14, 2025
e38d34f
concurrent restore 64GB
samos123 Sep 14, 2025
032dd7b
disable force to 64
samos123 Sep 14, 2025
81756e5
re-enable premap buffer
samos123 Sep 14, 2025
a4f2ba6
set cpu nodepool selector to c4
samos123 Sep 15, 2025
0471d3c
Revert "set cpu nodepool selector to c4"
samos123 Sep 15, 2025
09c5bff
fix training script
samos123 Sep 15, 2025
47e2e74
fix pathways_head_on_tpu=false
samos123 Sep 15, 2025
d617723
unique id for each xprof
samos123 Sep 15, 2025
08718bd
use privileged to get rid of zero copy warning
samos123 Sep 15, 2025
cf24047
wip: unix socket pathways proxy
samos123 Sep 16, 2025
6792b25
train fuji 7b on v5p / restore
samos123 Sep 16, 2025
22b1eb7
7b remove nodeSelector
samos123 Sep 16, 2025
140 changes: 113 additions & 27 deletions axlearn/cloud/gcp/pathways_utils.py
@@ -48,14 +48,19 @@
# There is no guarantee that this image will work with newer Jax releases.
# This image version extends GRPC timeout for long context models, based on jax-0.5.3-patch060625
# This image extends GRPC timeout for long context models.
_PATHWAYS_IMAGE_TAG = "disable_settings_20250701"
# _PATHWAYS_IMAGE_TAG = "disable_settings_20250701"
_PATHWAYS_IMAGE_TAG = "uds"
# The docker image used by pathways proxy container.
_PATHWAYS_PROXY_IMAGE = (
f"us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:{_PATHWAYS_IMAGE_TAG}"
# f"us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:{_PATHWAYS_IMAGE_TAG}"
"us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/gke/shauryag/"
f"unsanitized_proxy_server:{_PATHWAYS_IMAGE_TAG}"
)
# The docker image used by pathways resource manager container and worker container.
_PATHWAYS_SERVER_IMAGE = (
f"us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:{_PATHWAYS_IMAGE_TAG}"
# f"us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:{_PATHWAYS_IMAGE_TAG}"
"us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/gke/shauryag/"
f"unsanitized_server:{_PATHWAYS_IMAGE_TAG}"
)
# The container name of pathways resourcemanager.
_PATHWAYS_RESOURCE_MANAGER_CONTAINER_NAME = "pathways-rm"
@@ -107,7 +112,7 @@ def get_pathways_tpu_version(gke_machine_type: str) -> str:


def get_megascale_options(
xla_options: dict[str, Union[str, bool, int]]
xla_options: dict[str, Union[str, bool, int]],
) -> dict[str, Union[str, bool, int]]:
"""Filters XLA options for those pertaining to Megascale.

@@ -122,7 +127,7 @@


def get_xla_options(
xla_options: dict[str, Union[str, bool, int]]
xla_options: dict[str, Union[str, bool, int]],
) -> dict[str, Union[str, bool, int]]:
"""Filters XLA options for those starting with 'xla_'.

@@ -146,12 +151,14 @@ class Config(BaseReplicatedJob.Config):
inner: The wrapped TPUReplicatedJob configuration.
pathways_head_cpu: CPU request for pathways-head container.
pathways_head_mem: Memory request for pathways-head container.
pathways_head_on_tpu: Whether to run pathways head on TPU VM.
"""

inner: Required[TPUReplicatedJob.Config] = REQUIRED
pathways_xla_flags: list[str] = []
pathways_head_cpu: Optional[str] = None
pathways_head_mem: Optional[str] = None
pathways_head_on_tpu: bool = False

@classmethod
def define_flags(cls, fv):
@@ -180,6 +187,12 @@ def define_flags(cls, fv):
"Memory request for pathways-head container in GiB. Default is 16GiB",
**common_kwargs,
)
flags.DEFINE_boolean(
"pathways_head_on_tpu",
False,
"If True, run pathways head on TPU VM.",
**common_kwargs,
)

@classmethod
def set_defaults(cls, fv):
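The new pathways_head_on_tpu option is a plain absl boolean flag, so it can be toggled per launch. A minimal, standalone sketch of how such a flag parses (the DEFINE_boolean call mirrors this diff; the parsing harness around it is illustrative only):

from absl import flags

FLAGS = flags.FLAGS
flags.DEFINE_boolean("pathways_head_on_tpu", False, "If True, run pathways head on TPU VM.")

# absl accepts --pathways_head_on_tpu / --nopathways_head_on_tpu or an explicit value.
FLAGS(["prog", "--pathways_head_on_tpu=true"])
assert FLAGS.pathways_head_on_tpu is True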
@@ -261,10 +274,16 @@ def _build_pathways_head_container(self) -> dict:
head_container = copy.deepcopy(container)

env_list = head_container.get("env", [])
# self._update_env_list(
# env_list,
# "JAX_BACKEND_TARGET",
# f"grpc://localhost:{_PATHWAYS_PROXY_PORT}",
# )
# Unix domain socket
self._update_env_list(
env_list,
"JAX_BACKEND_TARGET",
f"grpc://localhost:{_PATHWAYS_PROXY_PORT}",
"grpc:///tmp/ifrt_proxy.sock",
)
self._update_env_list(env_list, "XCLOUD_ENVIRONMENT", "GCP")
self._update_env_list(env_list, "JAX_PLATFORMS", "proxy")
@@ -315,10 +334,14 @@ def _build_pathways_head_container(self) -> dict:
mem_req = f"{self.config.pathways_head_mem}Gi"
resources = {
"requests": {"cpu": cpu_req, "memory": mem_req},
"limits": {"cpu": cpu_req, "memory": mem_req},
# "limits": {"cpu": cpu_req, "memory": mem_req},
}
head_container["resources"] = resources

volume_mounts = head_container.get("volumeMounts", [])
volume_mounts.append(dict(name="shared-memory", mountPath="/tmp/"))
head_container["volumeMounts"] = volume_mounts

return head_container

def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]:
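For /tmp/ifrt_proxy.sock to be visible to both sides, the head container and the proxy sidecar mount the same emptyDir volume at /tmp/, and the volume is declared with medium="Memory" (tmpfs) further down in _build_pathways_head_pod. Roughly, the relevant pod-spec fragments assembled from this diff:

# Sketch: the pieces that make the Unix socket shared between containers.
shared_memory_volume = dict(name="shared-memory", emptyDir=dict(medium="Memory"))  # tmpfs-backed

head_container_mounts = [dict(name="shared-memory", mountPath="/tmp/")]
proxy_container_mounts = [
    dict(name="shared-output", mountPath="/output"),
    dict(name="shared-memory", mountPath="/tmp/"),
]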
@@ -342,6 +365,7 @@ def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]:

cmd_args = [
f"--resource_manager_address=localhost:{_PATHWAYS_RESOURCE_MANAGER_PORT}",
# using unix socket but port needs to be set anyway
f"--server_port={_PATHWAYS_PROXY_PORT}",
f"--gcs_scratch_location={staging_location}",
]
@@ -354,6 +378,7 @@ def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]:
dict(
name=_PATHWAYS_PROXY_CONTAINER_NAME,
image=_PATHWAYS_PROXY_IMAGE,
securityContext={"privileged": True},
# https://kubernetes.io/docs/concepts/workloads/pods/sidecar-containers/#pod-sidecar-containers
# SideCar container is an init container with restartPolicy as "Always".
restartPolicy="Always",
@@ -365,7 +390,10 @@ def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]:
{"name": "XLA_FLAGS", "value": f"--xla_dump_to=/output/{cfg.name}/xla"},
],
ports=[dict(containerPort=_PATHWAYS_PROXY_PORT)],
volumeMounts=[dict(name="shared-output", mountPath="/output")],
volumeMounts=[
dict(name="shared-output", mountPath="/output"),
dict(name="shared-memory", mountPath="/tmp/"),
],
),
dict(
name=_PATHWAYS_RESOURCE_MANAGER_CONTAINER_NAME,
@@ -403,6 +431,7 @@ def _build_pathways_head_pod(self) -> Nested[Any]:
labels.update({BASTION_JOB_VERSION_LABEL: os.environ.get(BASTION_JOB_VERSION_ENV_VAR)})

volumes.append(dict(name="shared-output", emptyDir={}))
volumes.append(dict(name="shared-memory", emptyDir=dict(medium="Memory")))

if cfg.gcsfuse_mount:
annotations.update(
@@ -414,9 +443,15 @@
}
)

node_selector = {
_PATHWAYS_HEAD_NODE_POOL_SELECTOR_KEY: _PATHWAYS_HEAD_NODE_POOL_SELECTOR_VALUE,
}
if self.config.pathways_head_on_tpu:
# pylint: disable-next=protected-access
pod = self._inner._build_pod()
node_selector = {}
tolerations = pod["spec"]["tolerations"]
else:
node_selector = {
_PATHWAYS_HEAD_NODE_POOL_SELECTOR_KEY: _PATHWAYS_HEAD_NODE_POOL_SELECTOR_VALUE,
}

head_container = self._build_pathways_head_container()
init_containers = [
@@ -444,6 +479,32 @@
"hostNetwork": True,
"dnsPolicy": "ClusterFirstWithHostNet",
}
if self.config.pathways_head_on_tpu:
head_pod_spec["affinity"] = {
"podAffinity": {
"requiredDuringSchedulingIgnoredDuringExecution": [
{
"labelSelector": {
"matchExpressions": [
{
"key": "batch.kubernetes.io/job-name",
"operator": "In",
"values": [
f"{cfg.name}-{_PATHWAYS_WORKER_REPLICATED_JOB_NAME}-0"
],
}
]
},
"topologyKey": "kubernetes.io/hostname",
}
]
}
}

# Remove host ports to avoid scheduling conflicts on the same node.
# The pod runs on host network anyway, so the ports are still accessible.
if "ports" in head_pod_spec["containers"][0]:
del head_pod_spec["containers"][0]["ports"]

if cfg.priority_class:
head_pod_spec["priorityClassName"] = cfg.priority_class
@@ -537,6 +598,17 @@ def _build_pathways_worker_container(
f"--resource_manager_address={pathways_head_address}:"
+ f"{_PATHWAYS_RESOURCE_MANAGER_PORT}",
f"--gcs_scratch_location={cfg.output_dir}/pathways-staging",
# Set premap buffer to 17GB, needed for faster jax.device_put h2d
# "--pathways_tpu_premapped_buffer_size=17179869184" doesn't work in cloud
# Below flags did not help on 7b restore time
# Recycle vs on-demand seems to give a slight perf boost
"--tpu_pinned_host_allocation_recycle=true",
# pylint: disable=line-too-long
"--temporary_flags_for_debugging=temporary_flag_for_debugging_tpu_premapped_buffer_size=68719476736",
# "--temporary_flags_for_debugging=temporary_flag_for_debuggings_max_num_threads_for_xla_compilation=1000"
# "--temporary_flags_for_debugging=temporary_flag_for_debugging_xla_max_inflight_async_computations=1000",
# "--temporary_flags_for_debugging=temporary_flag_for_debugging_xla_tpu_allow_async_allocations=true",
# "--temporary_flags_for_debugging=temporary_flag_for_debugging_tpu_num_premapped_partitions=65536",
]
mega_scale_args = xla_flags_from_options(self._mxla_options).split()
worker_container["args"].extend(mega_scale_args)
@@ -634,18 +706,23 @@ def _build_pathways_worker_job(
def __call__(self) -> Sequence[Nested[Any]]:
cfg: TPUReplicatedJob.Config = self._inner.config

replicated_jobs = [
dict(
name=_PATHWAYS_HEAD_REPLICATED_JOB_NAME,
replicas=1,
template=self._build_pathways_head_job(),
),
dict(
name=_PATHWAYS_WORKER_REPLICATED_JOB_NAME,
replicas=cfg.accelerator.num_replicas,
template=self._build_pathways_worker_job(),
),
]
worker_job = dict(
name=_PATHWAYS_WORKER_REPLICATED_JOB_NAME,
replicas=cfg.accelerator.num_replicas,
template=self._build_pathways_worker_job(),
)
head_job = dict(
name=_PATHWAYS_HEAD_REPLICATED_JOB_NAME,
replicas=1,
template=self._build_pathways_head_job(),
)
if self.config.pathways_head_on_tpu:
head_job["dependsOn"] = [
dict(name=_PATHWAYS_WORKER_REPLICATED_JOB_NAME, status="Ready")
]
replicated_jobs = [worker_job, head_job]
else:
replicated_jobs = [head_job, worker_job]

return replicated_jobs
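With the head on a TPU node, the worker job is listed first and the head job declares dependsOn with status Ready, so the JobSet only starts the head once the workers are up; otherwise the original head-then-worker order is kept. A sketch of the two orderings (job names and templates shortened for illustration):

worker_job = dict(name="pathways-worker", replicas=2, template={})  # template elided
head_job = dict(name="pathways-head", replicas=1, template={})      # template elided

pathways_head_on_tpu = True
if pathways_head_on_tpu:
    head_job["dependsOn"] = [dict(name="pathways-worker", status="Ready")]
    replicated_jobs = [worker_job, head_job]   # workers first; head waits for Ready
else:
    replicated_jobs = [head_job, worker_job]   # original ordering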

@@ -865,6 +942,7 @@ def _build_pathways_proxy_container(self) -> dict:
return dict(
name=_PATHWAYS_PROXY_CONTAINER_NAME,
image=_PATHWAYS_PROXY_IMAGE,
securityContext={"privileged": True},
args=[
f"--resource_manager_address=localhost:{_PATHWAYS_RESOURCE_MANAGER_PORT}",
f"--server_port={_PATHWAYS_PROXY_PORT}",
@@ -900,6 +978,14 @@ def _build_pathways_rm_container(self) -> dict:
"--instance_count=1",
f"--instance_type={pathways_tpu_version}:{system.topology}",
f"--gcs_scratch_location={staging_location}",
# Troubleshooting perf
"--tpu_pinned_host_allocation_recycle=true",
# pylint: disable=line-too-long
"--temporary_flags_for_debugging=temporary_flag_for_debugging_tpu_premapped_buffer_size=68719476736",
# "--temporary_flags_for_debugging=temporary_flag_for_debuggings_max_num_threads_for_xla_compilation=1000"
# "--temporary_flags_for_debugging=temporary_flag_for_debugging_xla_max_inflight_async_computations=1000",
# "--temporary_flags_for_debugging=temporary_flag_for_debugging_xla_tpu_allow_async_allocations=true",
# "--temporary_flags_for_debugging=temporary_flag_for_debugging_tpu_num_premapped_partitions=65536",
],
ports=[dict(containerPort=_PATHWAYS_RESOURCE_MANAGER_PORT)],
)
@@ -910,7 +996,7 @@ def _build_head_container(self) -> dict:
mem_req = f"{self.config.pathways_head_mem}Gi"
resources = {
"requests": {"cpu": cpu_req, "memory": mem_req},
"limits": {"cpu": cpu_req, "memory": mem_req},
# "limits": {"cpu": cpu_req, "memory": mem_req},
}
return dict(
name=cfg.name,
Expand All @@ -936,9 +1022,9 @@ def _build_head_container(self) -> dict:
],
imagePullPolicy="Always",
resources=resources,
ports=[dict(containerPort=self.config.target_port)]
if self.config.enable_service
else [],
ports=(
[dict(containerPort=self.config.target_port)] if self.config.enable_service else []
),
)

def build_leader_pod(self) -> Nested[Any]: