Fixing task ID replacement for MNP jobs on AWS Batch #2574
base: master
Changes from all commits: 2a5c211, ce15127, 86c3b84, f2ee285, 21a62ac, 26fa49c, cc5b44e, 96259ac
```diff
@@ -4,6 +4,7 @@
 import random
 import time
 import hashlib
+import os

 try:
     unicode
```
```diff
@@ -19,7 +20,34 @@ class BatchClient(object):
     def __init__(self):
         from ..aws_client import get_aws_client

-        self._client = get_aws_client("batch")
+        # Prefer the task role by default when running inside AWS Batch containers
+        # by temporarily removing higher-precedence env credentials for this process.
+        # This avoids AMI-injected AWS_* env vars from overriding the task role.
+        # Outside of Batch, we leave env vars untouched unless explicitly opted-in.
+        if "AWS_BATCH_JOB_ID" in os.environ:
+            _aws_env_keys = [
+                "AWS_ACCESS_KEY_ID",
+                "AWS_SECRET_ACCESS_KEY",
+                "AWS_SESSION_TOKEN",
+                "AWS_PROFILE",
+                "AWS_DEFAULT_PROFILE",
+            ]
+            _present = [k for k in _aws_env_keys if k in os.environ]
+            print(
+                "[Metaflow] AWS credential-related env vars present before Batch client init:",
+                _present,
+            )
+            _saved_env = {
+                k: os.environ.pop(k) for k in _aws_env_keys if k in os.environ
+            }
+            try:
+                self._client = get_aws_client("batch")
+            finally:
+                # Restore prior env for the rest of the process
+                for k, v in _saved_env.items():
+                    os.environ[k] = v
+        else:
+            self._client = get_aws_client("batch")

     def active_job_queues(self):
         paginator = self._client.get_paginator("describe_job_queues")
```
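The hunk above inlines the mask-and-restore logic in `__init__`. Purely as an illustration of that pattern (not code from this PR; `scrub_aws_env_credentials` is a hypothetical name, not a Metaflow API), the same behavior can be expressed as a small context manager:

```python
import os
from contextlib import contextmanager

# Hypothetical helper, shown only to illustrate the mask-and-restore pattern above.
_AWS_ENV_KEYS = (
    "AWS_ACCESS_KEY_ID",
    "AWS_SECRET_ACCESS_KEY",
    "AWS_SESSION_TOKEN",
    "AWS_PROFILE",
    "AWS_DEFAULT_PROFILE",
)

@contextmanager
def scrub_aws_env_credentials():
    """Temporarily drop env-based AWS credentials so boto3 falls back to the
    container/instance role, then restore them for the rest of the process."""
    saved = {k: os.environ.pop(k) for k in _AWS_ENV_KEYS if k in os.environ}
    try:
        yield saved
    finally:
        os.environ.update(saved)

# Usage sketch:
# with scrub_aws_env_credentials():
#     client = get_aws_client("batch")  # credentials resolve via the task role
```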
```diff
@@ -96,25 +124,21 @@ def execute(self):
         commands = self.payload["containerOverrides"]["command"][-1]
         # add split-index as this worker is also an ubf_task
         commands = commands.replace("[multinode-args]", "--split-index 0")
+        # For main node, remove the placeholder since it keeps the original task ID
+        commands = commands.replace("[NODE-INDEX]", "")
         main_task_override["command"][-1] = commands

         # secondary tasks
         secondary_task_container_override = copy.deepcopy(
             self.payload["containerOverrides"]
         )
         secondary_commands = self.payload["containerOverrides"]["command"][-1]
-        # other tasks do not have control- prefix, and have the split id appended to the task -id
-        secondary_commands = secondary_commands.replace(
-            self._task_id,
-            self._task_id.replace("control-", "")
-            + "-node-$AWS_BATCH_JOB_NODE_INDEX",
-        )
-        secondary_commands = secondary_commands.replace(
-            "ubf_control",
-            "ubf_task",
-        )
-        secondary_commands = secondary_commands.replace(
-            "[multinode-args]", "--split-index $AWS_BATCH_JOB_NODE_INDEX"
+        # For secondary nodes: remove "control-" prefix and replace placeholders
+        secondary_commands = (
+            secondary_commands.replace("control-", "")
+            .replace("[NODE-INDEX]", "-node-$AWS_BATCH_JOB_NODE_INDEX")
+            .replace("ubf_control", "ubf_task")
+            .replace("[multinode-args]", "--split-index $AWS_BATCH_JOB_NODE_INDEX")
         )

         secondary_task_container_override["command"][-1] = secondary_commands
```
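To make the placeholder handling concrete, here is a sketch with a made-up command template (the flags and task ID are illustrative, not the exact command Metaflow builds) showing what the main and secondary commands end up looking like:

```python
# Illustrative only: a simplified stand-in for the generated step command.
template = (
    "python flow.py step train --task-id control-123[NODE-INDEX] "
    "--ubf-context ubf_control [multinode-args]"
)

# Main node: keeps the control- task ID, drops the placeholder, runs split 0.
main_cmd = template.replace("[multinode-args]", "--split-index 0").replace(
    "[NODE-INDEX]", ""
)
# -> "python flow.py step train --task-id control-123 --ubf-context ubf_control --split-index 0"

# Secondary nodes: strip the control- prefix, append the node index to the task ID,
# and switch to ubf_task with the node index as the split index.
secondary_cmd = (
    template.replace("control-", "")
    .replace("[NODE-INDEX]", "-node-$AWS_BATCH_JOB_NODE_INDEX")
    .replace("ubf_control", "ubf_task")
    .replace("[multinode-args]", "--split-index $AWS_BATCH_JOB_NODE_INDEX")
)
# -> "python flow.py step train --task-id 123-node-$AWS_BATCH_JOB_NODE_INDEX "
#    "--ubf-context ubf_task --split-index $AWS_BATCH_JOB_NODE_INDEX"
```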
```diff
@@ -408,6 +432,14 @@ def _register_job_definition(

         self.num_parallel = num_parallel or 0
         if self.num_parallel >= 1:
+            # Set the ulimit of number of open files to 65536. This is because we cannot set it easily once worker processes start on Batch.
+            # job_definition["containerProperties"]["linuxParameters"]["ulimits"] = [
+            #     {
+            #         "name": "nofile",
+            #         "softLimit": 65536,
+            #         "hardLimit": 65536,
+            #     }
+            # ]
```
Comment on lines +435 to +442: can this be cleaned up?
```diff
             job_definition["type"] = "multinode"
             job_definition["nodeProperties"] = {
                 "numNodes": self.num_parallel,
```
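For context on the `multinode` branch above, the AWS Batch `RegisterJobDefinition` API takes `type` and a `nodeProperties` block roughly as sketched below with boto3. The container values and job definition name are placeholders, not what Metaflow actually submits:

```python
import boto3

batch = boto3.client("batch")

# Placeholder container spec; Metaflow builds its own containerProperties.
container = {
    "image": "python:3.11",
    "vcpus": 4,
    "memory": 8192,
    "command": ["echo", "hello"],
}

batch.register_job_definition(
    jobDefinitionName="example-multinode-jobdef",  # hypothetical name
    type="multinode",
    nodeProperties={
        "numNodes": 4,  # total nodes, i.e. num_parallel
        "mainNode": 0,  # node index that acts as the control/main node
        "nodeRangeProperties": [
            {
                "targetNodes": "0:",  # apply this container spec to all nodes
                "container": container,
            }
        ],
    },
)
```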
```diff
@@ -421,6 +421,89 @@ def _wait_for_mapper_tasks(self, flow, step_name):
         TIMEOUT = 600
         last_completion_timeout = time.time() + TIMEOUT
         print("Waiting for batch secondary tasks to finish")
+
+        # Prefer Batch API when metadata is local (nodes can't share local metadata files).
+        # If metadata isn't bound yet but we are on Batch, also prefer Batch API.
+        md = getattr(self, "metadata", None)
+        if md is not None and md.TYPE == "local":
+            return self._wait_for_mapper_tasks_batch_api(
+                flow, step_name, last_completion_timeout
+            )
+        if md is None and "AWS_BATCH_JOB_ID" in os.environ:
+            return self._wait_for_mapper_tasks_batch_api(
+                flow, step_name, last_completion_timeout
+            )
+        return self._wait_for_mapper_tasks_metadata(
+            flow, step_name, last_completion_timeout
+        )
+
+    def _wait_for_mapper_tasks_batch_api(
+        self, flow, step_name, last_completion_timeout
+    ):
+        """
+        Poll the shared datastore (S3) for DONE markers for each mapper task.
+        This avoids relying on a metadata service or local metadata files.
+        """
+        from metaflow.datastore.task_datastore import TaskDataStore
+
+        pathspecs = getattr(flow, "_control_mapper_tasks", [])
+        total = len(pathspecs)
+        if total == 0:
+            print("No mapper tasks discovered for datastore wait; returning")
+            return True
+
+        print("Waiting for mapper DONE markers in datastore for %d tasks" % total)
+        poll_sleep = 3.0
+        while last_completion_timeout > time.time():
+            time.sleep(poll_sleep)
+            completed = 0
+            for ps in pathspecs:
+                try:
+                    parts = ps.split("/")
+                    if len(parts) == 3:
+                        run_id, step, task_id = parts
+                    else:
+                        # Fallback in case of unexpected format
+                        run_id, step, task_id = self.run_id, step_name, parts[-1]
+                    tds = TaskDataStore(
+                        self.flow_datastore,
+                        run_id,
+                        step,
+                        task_id,
+                        mode="r",
+                        allow_not_done=True,
+                    )
+                    if tds.has_metadata(TaskDataStore.METADATA_DONE_SUFFIX):
+                        completed += 1
+                except Exception as e:
+                    if os.environ.get("METAFLOW_DEBUG_BATCH_POLL") in (
+                        "1",
+                        "true",
+                        "True",
+                    ):
```
Comment on lines +479 to +483: is this necessary to have as a debug flag? Were there a lot of datastore errors encountered to necessitate this? I'd opt to either print the error always or never, depending on how relevant it is for the user.

Follow-up: alternatively we can revisit this separately and consider adding a proper debug flag for
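If a debug flag is kept, one conventional way to read such opt-in environment flags, shown only as an illustrative sketch (`_env_flag` is not an existing Metaflow helper), is to normalize the value once:

```python
import os

def _env_flag(name, default=False):
    """Illustrative helper: interpret an environment variable as a boolean flag."""
    value = os.environ.get(name)
    if value is None:
        return default
    return value.strip().lower() in ("1", "true", "yes", "on")

# Usage sketch for the flag discussed above (name taken from the diff):
DEBUG_BATCH_POLL = _env_flag("METAFLOW_DEBUG_BATCH_POLL")
```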
```diff
+                        print("Datastore wait: error checking %s: %s" % (ps, e))
+                    continue
+            if completed == total:
+                print("All mapper tasks have written DONE markers to datastore")
+                return True
+            print(
+                "Waiting for mapper DONE markers. Finished: %d/%d" % (completed, total)
+            )
+            poll_sleep = min(poll_sleep * 1.25, 10.0)
+
+        raise Exception(
+            "Batch secondary workers did not finish in %s seconds (datastore wait)"
+            % (time.time() - (last_completion_timeout - 600))
+        )
+
+    def _wait_for_mapper_tasks_metadata(self, flow, step_name, last_completion_timeout):
+        """
+        Polls Metaflow metadata (Step client) for task completion.
+        Works with service-backed metadata providers but can fail with local metadata
+        in multi-node setups due to isolated per-node filesystems.
+        """
+        from metaflow import Step
+
         while last_completion_timeout > time.time():
             time.sleep(2)
             try:
```
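The datastore wait above polls with capped multiplicative backoff (the sleep grows by a factor of 1.25 up to 10 seconds). As a standalone sketch of that generic pattern, not code from the PR and assuming a caller-supplied `check` callable:

```python
import time

def poll_until(check, deadline, initial_sleep=3.0, factor=1.25, max_sleep=10.0):
    """Illustrative helper: call check() until it returns True or `deadline`
    (a time.time() timestamp) passes, sleeping with capped multiplicative backoff."""
    sleep = initial_sleep
    while time.time() < deadline:
        time.sleep(sleep)
        if check():
            return True
        sleep = min(sleep * factor, max_sleep)
    return False

# Usage sketch:
# done = poll_until(lambda: count_done_markers() == total, time.time() + 600)
```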
```diff
@@ -441,7 +524,8 @@ def _wait_for_mapper_tasks(self, flow, step_name):
             except Exception:
                 pass
         raise Exception(
-            "Batch secondary workers did not finish in %s seconds" % TIMEOUT
+            "Batch secondary workers did not finish in %s seconds"
+            % (time.time() - (last_completion_timeout - 600))
         )

     @classmethod
```
Review comment: is this change relevant to the batch parallel issue, or something different? the PR seems to work fine without this part as well
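A note on the reworked message: because `last_completion_timeout` was set to `time.time() + TIMEOUT`, the expression `time.time() - (last_completion_timeout - 600)` is simply the elapsed wall-clock time, with `600` hardcoded as a stand-in for `TIMEOUT`. An equivalent formulation that avoids repeating the constant might look like this sketch (illustrative only, not part of the PR):

```python
import time

TIMEOUT = 600
start = time.time()
deadline = start + TIMEOUT

# ... polling loop would run here until `deadline` passes ...

elapsed = time.time() - start  # same quantity as time.time() - (deadline - TIMEOUT)
raise Exception("Batch secondary workers did not finish in %s seconds" % elapsed)
```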