diff --git a/earth2studio/utils/distributed.py b/earth2studio/utils/distributed.py new file mode 100644 index 000000000..c958c689c --- /dev/null +++ b/earth2studio/utils/distributed.py @@ -0,0 +1,233 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable, Generator +from concurrent.futures import ThreadPoolExecutor +from queue import Queue +from typing import Any, Literal + +import torch +from physicsnemo.distributed import DistributedManager +from torch.distributed import rpc + + +class DistributedInference: + """Inference a model on remote GPUs. + + DistributedInference can be used to inference a model on one or more remote GPUs + (i.e. GPUs on other ranks of the distributed environment). The user can pass data to the + remote models by calling the DistributedInference object. The input is automatically + queued and passed to the next available remote GPU. The calls are asynchronous and the + results can be obtained by iterating over the `results` method. + + Parameters + ---------- + model : Type[Callable] + The model to initialize on remote GPUs. + + This must be implemented as a callable object that has, at a minimum, a `forward` + method that takes a tensor of input data and returns a tensor of output data. + + It can also have an __init__ constructor; this is called on each remote process + when the DistributedInference is instantiated. The constructor can be used + to load the model on the remote GPU and for other initialization. + + The model can also have other methods that can be called remotely using the + `call_func` method of DistributedInference. This can be used e.g. to get information + from the remote models to the main process. + *args : + Positional arguments to pass to the model constructor. + remote_ranks : list[int] | None, optional + The ranks of the remote GPUs to initialize the model on. If not provided, the model + will be initialized on all other ranks found in the distributed environment. + **kwargs : + Keyword arguments to pass to the model constructor. + """ + + @staticmethod + def initialize() -> None: + """Initialize the DistributedInference interface. + + This function must be called before instantiating any DistributedInference objects, + typically at the beginning of an inference script. + """ + DistributedManager.initialize() + dist = DistributedManager() + + options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=128) + + # build device map + local_device = str(dist.device) + (local_device_type, local_device_id) = local_device.split(":") + if local_device_type != "cuda": + raise ValueError("Only CUDA devices are supported.") + local_device_num = int(local_device_id) + device_num_list = [ + torch.empty(1, dtype=torch.int64, device=dist.device) + for _ in range(dist.world_size) + ] + # gather device numbers from each worker + local_device_num = torch.tensor( + [local_device_num], dtype=torch.int64, device=dist.device + ) + torch.distributed.all_gather(device_num_list, local_device_num) + + for rank in range(dist.world_size): + if rank == dist.rank: + continue + remote_device_num = int(device_num_list[rank][0]) + remote_device = f"cuda:{remote_device_num}" + options.set_device_map(f"worker{rank}", {local_device: remote_device}) + + rpc.init_rpc( + f"worker{dist.rank}", + rank=dist.rank, + world_size=dist.world_size, + rpc_backend_options=options, + ) + + @staticmethod + def finalize() -> None: + """Shut down the DistributedInference interface. + + This function must be called, typically at the end of an inference script, + to ensure that the ranks hosting remote models do not shut down prematurely. + """ + rpc.shutdown() + + def __init__( + self, + model: type, + *args: Any, + remote_ranks: list[int] | None = None, + **kwargs: Any, + ): + self.dist = DistributedManager() + if remote_ranks is None: # select all other ranks + remote_ranks = list(range(self.dist.world_size)) + del remote_ranks[self.dist.rank] + self.remote_ranks = remote_ranks + self.available_remotes: Queue[int] = Queue(len(remote_ranks)) + self.out_queue: Queue[Any] = Queue(len(remote_ranks)) + + # initialize remote models + self.remote_models = { + rank: rpc.remote(f"worker{rank}", model, args=args, kwargs=kwargs) + for rank in remote_ranks + } + # initialize queue of available remotes + for rank in remote_ranks: + self.available_remotes.put(rank) + + def call_func( + self, func: str, *args: Any, rank: int | Literal["all"] = "all", **kwargs: Any + ) -> Any: + """Call a member function of the remote model. + + This can be used e.g. to get information from the model or to set parameters. + + Parameters + ---------- + func : str + The name of the member function to call. + *args : + Additional positional arguments to pass to the function. + rank : int | Literal["all"], optional + The rank of the remote GPU to call the function on. If "all", the function + will be called on all remote GPUs. + **kwargs : + Additional keyword arguments to pass to the function. + + Returns + ------- + The result of the function call. If `rank` is "all", a list of results from all + remote GPUs is returned. + """ + if rank == "all": + result = [ + self.call_func(func, *args, rank=rank, **kwargs) + for rank in self.remote_ranks + ] + return result + + rm = self.remote_models[rank] + remote_func = getattr(rm.rpc_sync(), func) + return remote_func(*args, **kwargs) + + def __call__(self, *args: Any, **kwargs: Any) -> None: + """Inference the remote model asynchronously. + + This will block until a remote model is available to accept the inputs. + + Parameters + ---------- + *args : + Positional arguments to pass to the model `forward` method. + **kwargs : + Keyword arguments to pass to the model `forward` method. + """ + + # get a remote model from the queue (will block until one is available) + rank = self.available_remotes.get() + rm = self.remote_models[rank] + torch.cuda.synchronize(device=self.dist.device) + task = rm.rpc_async().forward(*args, **kwargs) + + def callback( + completed_task: torch.futures.Future, + ) -> None: # called when the inference is finished + result = completed_task.value() + torch.cuda.synchronize( + device=self.dist.device + ) # necessary to ensure result is usable + self.out_queue.put(result) + self.available_remotes.put(rank) + + task.then(callback) + + def wait(self) -> None: + """Wait for all inference tasks to finish.""" + + for _ in range(len(self.remote_ranks)): + self.available_remotes.get() + self.out_queue.put(None) # signal that the inference is finished + for rank in self.remote_ranks: + self.available_remotes.put(rank) + + def results(self) -> Generator[Any, None, None]: + """Get the results of the inference tasks. + + This method will yield results until all inference tasks have finished. The results + may arrive out of order with respect to the inference calls. + """ + while (result := self.out_queue.get()) is not None: + yield result + + +def local_concurrent_pipeline(tasks: list[Callable]) -> None: + """Run a list of tasks concurrently on the local machine. + + This can be used to set up different stages of a distributed inference pipeline. + It will block until all tasks have finished. + + Parameters + ---------- + tasks : list[Callable] + A list of tasks to run concurrently. + """ + with ThreadPoolExecutor(max_workers=len(tasks)) as executor: + for task in tasks: + executor.submit(task) diff --git a/recipes/distributed/.gitignore b/recipes/distributed/.gitignore new file mode 100644 index 000000000..869c63791 --- /dev/null +++ b/recipes/distributed/.gitignore @@ -0,0 +1,2 @@ +uv.lock +outputs* diff --git a/recipes/distributed/README.md b/recipes/distributed/README.md new file mode 100644 index 000000000..a34a7aa9a --- /dev/null +++ b/recipes/distributed/README.md @@ -0,0 +1,104 @@ +# Earth2Studio Distributed Inference Recipe + +This recipe shows how to use the `DistributedInference` interface to distribute inference workloads +to multiple GPUs in a distributed computing environment (e.g. `torchrun` or MPI). + +## Prerequisites + +### Software + +Installing Earth2Studio and [Hydra](https://hydra.cc/docs/intro/) is sufficient for running the +recipe. The commands below in Quick Start will install a tested environment. + +### Hardware + +- GPUs: Any type with >= 20 GB memory, at least 2 GPUs required to run the recipe +- Storage: A few GB to store inference results and model checkpoints. + +## Quick Start + +### Installation + +Installing Earth2Studio is generally a sufficient prerequisite to use this recipe. The support +for models used by the recipe must be included in the installation. For the diagnostic model +example, this means installing Earth2Studio with + +```bash +pip install earth2studio[fcn,precip-afno] +``` + +To install a full tested environment, you can use pip: + +```bash +pip install -r requirements.txt +``` + +or set up a uv virtual environment: + +```bash +uv sync +``` + +### Test distributed inference + +Start an environment with at least 2 GPUs available. The run the distributed diagnostic model +example, substituting `` with the number of GPUs you have: + +```bash +# if you installed a uv environment +uv run torchrun --standalone --nnodes=1 --nproc-per-node= main.py --config-name=diagnostic.yaml + +# using default python +torchrun --standalone --nnodes=1 --nproc-per-node= main.py --config-name=diagnostic.yaml +``` + +## Documentation + +### Using the recipes + +Specify the recipe you want to run using the `--config-name` command line argument to `main.py`. +This is used to select the relevant function in `main.py`. Currently, only `diagnostic.yaml` is +provided; more recipes will be added later. + +The configuration of the recipes is managed with Hydra using YAML config files located in the +`cfg` directory. You can override default options by editing the config file, or from the command +line using the Hydra syntax: for example, to save the diagnostic model recipe output to +`output_file.zarr`: + +```bash +torchrun --standalone --nnodes=1 --nproc-per-node= main.py\ + --config-name=diagnostic.yaml ++parameters.output_path=output_file.zarr +``` + +### Supported distribution methods + +In a single-node environment, we recommend using `torchrun`. + +`DistributedInference` should also work with any distribution method supported by the +[`DistributedManager`](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/api/physicsnemo.distributed.html) +of PhysicsNeMo. The startup commands will need to be modified to the distribution. For instance, +an MPI job using 2 GPUs on a single node can be started with Slurm using a script: + +```bash +cd /recipes/distributed/ +mpirun --allow-run-as-root python main.py --config-name=diagnostic.yaml +``` + +which can then be launched with +`srun --nodes=1 --ntasks-per-node=2 --gpus-per-node=2 `, +replacing `` with the path where Earth2Studio is located and `` +with the startup script path. + +### Creating custom applications + +To create custom applications using `DistributedInference`, you can use the provided recipes as a +starting point. + +## Testing + +See the [testing `README`](test/README.md). + +## References + +- [PyTorch TensorPipe CUDA RPC](https://docs.pytorch.org/tutorials/recipes/cuda_rpc.html), the +PyTorch feature used to implement `DistributedInference`. diff --git a/recipes/distributed/cfg/diagnostic.yaml b/recipes/distributed/cfg/diagnostic.yaml new file mode 100644 index 000000000..504e5cdd5 --- /dev/null +++ b/recipes/distributed/cfg/diagnostic.yaml @@ -0,0 +1,6 @@ +recipe: "diagnostic" + +parameters: + time: "2023-06-01T00:00:00" + nsteps: 12 + output_path: diagnostic_distributed.zarr diff --git a/recipes/distributed/main.py b/recipes/distributed/main.py new file mode 100644 index 000000000..480d4642c --- /dev/null +++ b/recipes/distributed/main.py @@ -0,0 +1,66 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hydra +from omegaconf import DictConfig +from src.diagnostic_distributed import PrecipDiagnostic, diagnostic_distributed + +from earth2studio.data import GFS +from earth2studio.io import ZarrBackend +from earth2studio.models.px import FCN, PrognosticModel +from earth2studio.utils.distributed import DistributedInference, DistributedManager + + +def diagnostic( + time: str = "2023-06-01T00:00:00", + nsteps: int = 12, + output_path: str = "diagnostic_distributed.zarr", + model_cls: type[PrognosticModel] = FCN, +) -> None: + """Distributed diagnostic model recipe.""" + dist = DistributedManager() + if dist.world_size < 2: + raise ValueError("This recipe requires at least 2 processes") + + if dist.rank == 0: # rank 0 will run the prognostic model and handle IO + model = model_cls.load_model(model_cls.load_default_package()) + + # create diagnostic models on the other available ranks + dist_diagnostic = DistributedInference(PrecipDiagnostic) + + # initialize data source and IO backend + data = GFS() + io = ZarrBackend(output_path) + + # run the inference + diagnostic_distributed([time], nsteps, model, dist_diagnostic, data, io) + + +recipes = { + "diagnostic": diagnostic, +} + + +@hydra.main(version_base=None, config_path="cfg") +def main(cfg: DictConfig) -> None: + """Initialize DistributedInference, choose the recipe and run it.""" + DistributedInference.initialize() + recipes[cfg.recipe](**cfg.parameters) + DistributedInference.finalize() + + +if __name__ == "__main__": + main() diff --git a/recipes/distributed/pyproject.toml b/recipes/distributed/pyproject.toml new file mode 100644 index 000000000..268e945a3 --- /dev/null +++ b/recipes/distributed/pyproject.toml @@ -0,0 +1,27 @@ +[project] +name = "earth2studio.recipe.distributed" +version = "0.1.0" +description = "Distributed Inference Recipe" +readme = "README.md" +requires-python = ">=3.10" +authors = [ + { name="NVIDIA Earth-2 Team" }, +] +dependencies = [ + "earth2studio[fcn,precip-afno]", + "hydra-core>=1.3.0", +] + +[project.urls] +Homepage = "https://github.com/NVIDIA/earth2studio/recipes/distributed" +Documentation = "https://nvidia.github.io/earth2studio" +Issues = "https://github.com/NVIDIA/earth2studio/issues" +Changelog = "https://github.com/NVIDIA/earth2studio/blob/main/CHANGELOG.md" + +[tool.uv.sources] +omegaconf = { git = "https://github.com/omry/omegaconf.git" } +earth2studio = { path = "../../", editable = true } + +[tool.hatch.build.targets.sdist] +include = ["src/**/*.py"] +exclude = [] diff --git a/recipes/distributed/requirements.txt b/recipes/distributed/requirements.txt new file mode 100644 index 000000000..a7e52c411 --- /dev/null +++ b/recipes/distributed/requirements.txt @@ -0,0 +1,510 @@ +# This file was autogenerated by uv via the following command: +# uv export --format requirements-txt --no-hashes +-e ../../ + # via earth2studio-recipe-distributed +absl-py==2.2.2 + # via dm-tree +aiobotocore==2.22.0 + # via s3fs +aiofiles==24.1.0 + # via ngcsdk +aiohappyeyeballs==2.6.1 + # via aiohttp +aiohttp==3.11.18 + # via + # aiobotocore + # gcsfs + # ngcsdk + # s3fs +aioitertools==0.12.0 + # via aiobotocore +aiosignal==1.3.2 + # via aiohttp +antlr4-python3-runtime==4.9.3 + # via hydra-core +asciitree==0.3.3 ; python_full_version < '3.11' + # via zarr +astunparse==1.6.3 + # via nvidia-dali-cuda120 +async-timeout==5.0.1 ; python_full_version < '3.11' + # via aiohttp +attrs==25.3.0 + # via + # aiohttp + # cfgrib + # dm-tree + # eccodes +bokeh==3.7.3 + # via dask +boto3==1.37.3 + # via ngcsdk +botocore==1.37.3 + # via + # aiobotocore + # boto3 + # ngcsdk + # s3transfer +cachetools==5.5.2 + # via google-auth +certifi==2025.4.26 + # via + # netcdf4 + # ngcsdk + # nvidia-physicsnemo + # requests +cffi==1.17.1 + # via + # cryptography + # eccodes +cfgrib==0.9.15.0 + # via earth2studio +cftime==1.6.4.post1 + # via + # earth2studio + # netcdf4 +charset-normalizer==3.4.2 + # via requests +click==8.2.0 + # via + # cfgrib + # dask + # distributed +cloudpickle==3.1.1 + # via + # dask + # distributed +colorama==0.4.6 ; sys_platform == 'win32' + # via + # click + # loguru + # tqdm +contourpy==1.3.2 + # via bokeh +crc32c==2.7.1 ; python_full_version >= '3.11' + # via numcodecs +cryptography==45.0.2 + # via ngcsdk +dask==2025.5.0 + # via + # distributed + # xarray +decorator==5.2.1 + # via gcsfs +distributed==2025.5.0 + # via dask +dm-tree==0.1.9 + # via nvidia-dali-cuda120 +docker==7.1.0 + # via ngcsdk +donfig==0.8.1.post1 ; python_full_version >= '3.11' + # via zarr +eccodes==2.41.0 + # via cfgrib +fasteners==0.19 ; python_full_version < '3.11' and sys_platform != 'emscripten' + # via zarr +filelock==3.18.0 + # via + # huggingface-hub + # torch +findlibs==0.1.1 + # via eccodes +frozenlist==1.6.0 + # via + # aiohttp + # aiosignal +fsspec==2025.3.2 + # via + # dask + # earth2studio + # gcsfs + # huggingface-hub + # nvidia-physicsnemo + # s3fs + # torch +gast==0.6.0 + # via nvidia-dali-cuda120 +gcsfs==2025.3.2 + # via earth2studio +google-api-core==2.24.2 + # via + # google-cloud-core + # google-cloud-storage +google-auth==2.40.1 + # via + # gcsfs + # google-api-core + # google-auth-oauthlib + # google-cloud-core + # google-cloud-storage +google-auth-oauthlib==1.2.2 + # via gcsfs +google-cloud-core==2.4.3 + # via google-cloud-storage +google-cloud-storage==3.1.0 + # via gcsfs +google-crc32c==1.7.1 + # via + # google-cloud-storage + # google-resumable-media +google-resumable-media==2.7.2 + # via google-cloud-storage +googleapis-common-protos==1.70.0 + # via google-api-core +h5netcdf==1.6.1 + # via earth2studio +h5py==3.13.0 + # via + # earth2studio + # h5netcdf +huggingface-hub==0.31.4 + # via + # earth2studio + # timm +hydra-core==1.3.0 + # via earth2studio-recipe-distributed +idna==3.10 + # via + # requests + # yarl +importlib-metadata==8.7.0 ; python_full_version < '3.12' + # via dask +isodate==0.7.2 + # via ngcsdk +jinja2==3.1.6 + # via + # bokeh + # dask + # distributed + # torch +jmespath==1.0.1 + # via + # aiobotocore + # boto3 + # botocore +locket==1.0.0 + # via + # distributed + # partd +loguru==0.7.3 + # via earth2studio +lz4==4.4.4 + # via dask +markdown-it-py==3.0.0 + # via rich +markupsafe==3.0.2 + # via jinja2 +mdurl==0.1.2 + # via markdown-it-py +mpmath==1.3.0 + # via sympy +msgpack==1.1.0 + # via distributed +multidict==6.4.4 + # via + # aiobotocore + # aiohttp + # yarl +narwhals==1.40.0 + # via bokeh +nest-asyncio==1.6.0 + # via earth2studio +netcdf4==1.7.2 + # via earth2studio +networkx==3.4.2 + # via torch +ngcsdk==3.148.1 + # via earth2studio +numcodecs==0.13.1 ; python_full_version < '3.11' + # via + # earth2studio + # zarr +numcodecs==0.14.1 ; python_full_version >= '3.11' + # via + # earth2studio + # zarr +numpy==2.2.6 + # via + # bokeh + # cfgrib + # cftime + # contourpy + # dask + # dm-tree + # eccodes + # h5py + # netcdf4 + # numcodecs + # nvidia-physicsnemo + # onnx + # onnx-weekly + # pandas + # torchvision + # xarray + # zarr +nvidia-cublas-cu12==12.8.3.14 ; platform_machine == 'x86_64' and sys_platform == 'linux' + # via + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 + # torch +nvidia-cuda-cupti-cu12==12.8.57 ; platform_machine == 'x86_64' and sys_platform == 'linux' + # via torch +nvidia-cuda-nvrtc-cu12==12.8.61 ; platform_machine == 'x86_64' and sys_platform == 'linux' + # via torch +nvidia-cuda-runtime-cu12==12.8.57 ; platform_machine == 'x86_64' and sys_platform == 'linux' + # via torch +nvidia-cudnn-cu12==9.7.1.26 ; platform_machine == 'x86_64' and sys_platform == 'linux' + # via torch +nvidia-cufft-cu12==11.3.3.41 ; platform_machine == 'x86_64' and sys_platform == 'linux' + # via torch +nvidia-cufile-cu12==1.13.0.11 ; platform_machine == 'x86_64' and sys_platform == 'linux' + # via torch +nvidia-curand-cu12==10.3.9.55 ; platform_machine == 'x86_64' and sys_platform == 'linux' + # via torch +nvidia-cusolver-cu12==11.7.2.55 ; platform_machine == 'x86_64' and sys_platform == 'linux' + # via torch +nvidia-cusparse-cu12==12.5.7.53 ; platform_machine == 'x86_64' and sys_platform == 'linux' + # via + # nvidia-cusolver-cu12 + # torch +nvidia-cusparselt-cu12==0.6.3 ; platform_machine == 'x86_64' and sys_platform == 'linux' + # via torch +nvidia-dali-cuda120==1.43.0 + # via nvidia-physicsnemo +nvidia-nccl-cu12==2.26.2 ; platform_machine == 'x86_64' and sys_platform == 'linux' + # via torch +nvidia-nvimgcodec-cu12==0.5.0.13 + # via nvidia-dali-cuda120 +nvidia-nvjitlink-cu12==12.8.61 ; platform_machine == 'x86_64' and sys_platform == 'linux' + # via + # nvidia-cufft-cu12 + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 + # torch +nvidia-nvtx-cu12==12.8.55 ; platform_machine == 'x86_64' and sys_platform == 'linux' + # via torch +nvidia-physicsnemo==1.0.1 + # via earth2studio +oauthlib==3.2.2 + # via requests-oauthlib +omegaconf @ git+https://github.com/omry/omegaconf.git@117f7de07285e4d1324b9229eaf873de15279457 + # via + # earth2studio-recipe-distributed + # hydra-core +onnx==1.18.0 + # via nvidia-physicsnemo +onnx-weekly==1.19.0.dev20250519 ; python_full_version >= '3.13' + # via earth2studio +packaging==25.0 + # via + # bokeh + # dask + # distributed + # h5netcdf + # huggingface-hub + # hydra-core + # ngcsdk + # xarray + # zarr +pandas==2.2.3 + # via + # bokeh + # dask + # xarray +partd==1.4.2 + # via dask +pillow==11.2.1 + # via + # bokeh + # torchvision +polling2==0.5.0 + # via ngcsdk +prettytable==3.16.0 + # via ngcsdk +propcache==0.3.1 + # via + # aiohttp + # yarl +proto-plus==1.26.1 + # via google-api-core +protobuf==6.31.0 + # via + # google-api-core + # googleapis-common-protos + # onnx + # onnx-weekly + # proto-plus +psutil==7.0.0 + # via + # distributed + # ngcsdk +pyarrow==20.0.0 + # via dask +pyasn1==0.6.1 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.4.2 + # via google-auth +pycparser==2.22 + # via cffi +pygments==2.19.1 + # via rich +python-dateutil==2.9.0.post0 + # via + # aiobotocore + # botocore + # ngcsdk + # pandas +python-dotenv==1.1.0 + # via earth2studio +pytz==2025.2 + # via pandas +pywin32==310 ; sys_platform == 'win32' + # via docker +pyyaml==6.0.2 + # via + # bokeh + # dask + # distributed + # donfig + # huggingface-hub + # omegaconf + # timm +requests==2.32.3 + # via + # docker + # gcsfs + # google-api-core + # google-cloud-storage + # huggingface-hub + # ngcsdk + # requests-oauthlib + # requests-toolbelt +requests-oauthlib==2.0.0 + # via google-auth-oauthlib +requests-toolbelt==1.0.0 + # via ngcsdk +rich==14.0.0 + # via ngcsdk +rsa==4.9.1 + # via google-auth +s3fs==2025.3.2 + # via + # earth2studio + # nvidia-physicsnemo +s3transfer==0.11.3 + # via boto3 +safetensors==0.5.3 + # via timm +semver==3.0.4 + # via ngcsdk +setuptools==80.7.1 + # via + # ngcsdk + # nvidia-physicsnemo + # torch + # triton +shortuuid==1.0.13 + # via ngcsdk +six==1.17.0 + # via + # astunparse + # nvidia-dali-cuda120 + # python-dateutil + # treelib +sortedcontainers==2.4.0 + # via distributed +sympy==1.14.0 + # via torch +tblib==3.1.0 + # via distributed +timm==1.0.15 + # via nvidia-physicsnemo +toolz==1.0.0 + # via + # dask + # distributed + # partd +torch==2.7.0 ; sys_platform == 'darwin' + # via + # earth2studio + # nvidia-physicsnemo + # timm + # torchvision +torch==2.7.0+cpu ; sys_platform != 'darwin' and sys_platform != 'linux' + # via + # earth2studio + # nvidia-physicsnemo + # timm + # torchvision +torch==2.7.0+cu128 ; sys_platform == 'linux' + # via + # earth2studio + # nvidia-physicsnemo + # timm + # torchvision +torchvision==0.22.0 + # via timm +tornado==6.5 + # via + # bokeh + # distributed +tqdm==4.67.1 + # via + # earth2studio + # huggingface-hub + # nvidia-physicsnemo +treelib==1.7.1 + # via nvidia-physicsnemo +triton==3.3.0 ; sys_platform == 'linux' + # via torch +typing-extensions==4.13.2 + # via + # huggingface-hub + # multidict + # onnx + # onnx-weekly + # rich + # torch + # zarr +tzdata==2025.2 + # via pandas +urllib3==2.4.0 + # via + # botocore + # distributed + # docker + # ngcsdk + # requests +validators==0.35.0 + # via ngcsdk +wcwidth==0.2.13 + # via prettytable +wheel==0.45.1 + # via astunparse +win32-setctime==1.2.0 ; sys_platform == 'win32' + # via loguru +wrapt==1.17.2 + # via + # aiobotocore + # dm-tree +xarray==2025.4.0 + # via + # earth2studio + # nvidia-physicsnemo +xyzservices==2025.4.0 + # via bokeh +yarl==1.20.0 + # via aiohttp +zarr==2.18.3 ; python_full_version < '3.11' + # via + # earth2studio + # nvidia-physicsnemo +zarr==3.0.8 ; python_full_version >= '3.11' + # via + # earth2studio + # nvidia-physicsnemo +zict==3.0.0 + # via distributed +zipp==3.21.0 ; python_full_version < '3.12' + # via importlib-metadata diff --git a/recipes/distributed/src/diagnostic_distributed.py b/recipes/distributed/src/diagnostic_distributed.py new file mode 100644 index 000000000..170a5d5af --- /dev/null +++ b/recipes/distributed/src/diagnostic_distributed.py @@ -0,0 +1,210 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict +from datetime import datetime + +import numpy as np +import torch +from loguru import logger +from physicsnemo.distributed import DistributedManager +from tqdm import tqdm + +from earth2studio.data import DataSource, fetch_data +from earth2studio.io import IOBackend +from earth2studio.models.dx import PrecipitationAFNO +from earth2studio.models.px import PrognosticModel +from earth2studio.utils.coords import CoordSystem, map_coords, split_coords +from earth2studio.utils.distributed import ( + DistributedInference, + local_concurrent_pipeline, +) +from earth2studio.utils.time import to_time_array + +logger.remove() +logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True) + + +class PrecipDiagnostic: + """Wrapper for a precipitation diagnostic model distributed using DistributedInference. + + This provides an example of a class that can be used as the `dist_diagnostic` argument + to `diagnostic_distributed`. + """ + + def __init__(self, output_coords: CoordSystem = OrderedDict({})): + dist = DistributedManager() + self.diagnostic = PrecipitationAFNO.load_model( + PrecipitationAFNO.load_default_package() + ) + self.diagnostic.to(dist.device) + self.diagnostic_ic = self.diagnostic.input_coords() + self.diagnostic_oc = self.diagnostic.output_coords(self.diagnostic_ic) + self.output_coords = output_coords + + def get_coords(self) -> tuple[CoordSystem, CoordSystem]: + """Get the input and output coordinates of the diagnostic model.""" + return (self.diagnostic_ic, self.diagnostic_oc) + + @torch.inference_mode() + def forward( + self, x: torch.Tensor, coords: CoordSystem + ) -> tuple[torch.Tensor, CoordSystem]: + """Forward pass of the diagnostic model. + + Maps the input coordinates to the diagnostic model coordinates, runs the diagnostic, + and maps the diagnostic model result to the output coordinates. + """ + x, coords = map_coords(x, coords, self.diagnostic_ic) + x, coords = self.diagnostic(x, coords) + # Subselect domain/variables as indicated in output_coords + (x, coords) = map_coords(x, coords, self.output_coords) + return (x, coords) + + +# sphinx - diagnostic start +def diagnostic_distributed( + time: list[str] | list[datetime] | list[np.datetime64], + nsteps: int, + prognostic: PrognosticModel, + dist_diagnostic: DistributedInference, + data: DataSource, + io: IOBackend, + output_coords: CoordSystem = OrderedDict({}), +) -> IOBackend: + """Distributed diagnostic workflow. + This workflow creates a distributed inference pipeline that couples a prognostic + model on the local rank with a diagnostic model on remote rank(s). + + Parameters + ---------- + time : list[str] | list[datetime] | list[np.datetime64] + List of string, datetimes or np.datetime64 + nsteps : int + Number of forecast steps + prognostic : PrognosticModel + Prognostic model + dist_diagnostic: DistributedInference + Wrapper for a diagnostic model distributed using DistributedInference, + must be on same coordinate axis as prognostic. Must implement a `forward` + method that wraps the call to the diagnostic model, and a `get_coords` + method that returns a 2-tuple of (input_coords, output_coords). + data : DataSource + Data source + io : IOBackend + IO object + output_coords: CoordSystem, optional + IO output coordinate system override, by default OrderedDict({}) + device : torch.device, optional + Device to run inference on, by default None + + Returns + ------- + IOBackend + Output IO object + """ + # sphinx - diagnostic end + logger.info("Running diagnostic workflow!") + + dist = DistributedManager() + device = dist.device + + # Get information about the prognostic model + logger.info(f"Prognostic rank: {dist.rank}") + logger.info(f"Prognostic device: {device}") + prognostic = prognostic.to(device) + # Fetch data from data source and load onto device + prognostic_ic = prognostic.input_coords() + time = to_time_array(time) + + # Fetch initial conditions from data source and load onto device + x, coords = fetch_data( + source=data, + time=time, + variable=prognostic_ic["variable"], + lead_time=prognostic_ic["lead_time"], + device=device, + ) + logger.success(f"Fetched data from {data.__class__.__name__}") + + # Get the input and output coordinates of the remote diagnostic model + logger.info(f"Diagnostic ranks: {dist_diagnostic.remote_ranks}") + (diagnostic_ic, diagnostic_oc) = dist_diagnostic.call_func("get_coords")[0] + + # Set up IO backend and create output variables + _setup_io(io, time, nsteps, prognostic, diagnostic_oc, output_coords=output_coords) + + # Map lat and lon if needed + x, coords = map_coords(x, coords, prognostic_ic) + # Create prognostic iterator + model = prognostic.create_iterator(x, coords) + + def prognostic_loop() -> None: + """Pull outputs from prognostic model and pass to diagnostic models asynchronously.""" + for step, (x, coords) in enumerate(model): + dist_diagnostic(x.clone(), coords) + if step == nsteps: + break + dist_diagnostic.wait() + + def io_loop() -> None: + """Receive outputs from diagnostic models and write to IO backend.""" + with tqdm(total=nsteps + 1, desc="Waiting for diagnostic model data") as pbar: + for x, coords in dist_diagnostic.results(): + io.write(*split_coords(x, coords)) + pbar.update(1) + + logger.info("Inference starting!") + # launch the functions making up the inference pipeline in their own threads + # and wait for them to finish + local_concurrent_pipeline([prognostic_loop, io_loop]) + logger.success("Inference complete") + + return io + + +def _setup_io( + io: IOBackend, + time: list[str] | list[datetime] | list[np.datetime64], + nsteps: int, + prognostic: PrognosticModel, + diagnostic_oc: CoordSystem, + output_coords: CoordSystem = OrderedDict({}), +) -> None: + """Set up IO backend and create output variables.""" + + total_coords = prognostic.output_coords(prognostic.input_coords()) + for key, value in prognostic.output_coords( + prognostic.input_coords() + ).items(): # Scrub batch dims + if key in diagnostic_oc: + total_coords[key] = diagnostic_oc[key] + if value.shape == (0,): + del total_coords[key] + total_coords["time"] = time + total_coords["lead_time"] = np.asarray( + [ + prognostic.output_coords(prognostic.input_coords())["lead_time"] * i + for i in range(nsteps + 1) + ] + ).flatten() + total_coords.move_to_end("lead_time", last=False) + total_coords.move_to_end("time", last=False) + + for key, value in total_coords.items(): + total_coords[key] = output_coords.get(key, value) + var_names = total_coords.pop("variable") + io.add_array(total_coords, var_names) diff --git a/recipes/distributed/test/.gitignore b/recipes/distributed/test/.gitignore new file mode 100644 index 000000000..bbf2e23a1 --- /dev/null +++ b/recipes/distributed/test/.gitignore @@ -0,0 +1 @@ +test_figures/tp_*.png diff --git a/recipes/distributed/test/README.md b/recipes/distributed/test/README.md new file mode 100644 index 000000000..29585d287 --- /dev/null +++ b/recipes/distributed/test/README.md @@ -0,0 +1,32 @@ +# Tests + +## Test 1: Check diagnostic model outputs + +Run the distributed diagnostic model example in the parent directory, as indicated in the main +`README`: + +```bash +torchrun --standalone --nnodes=1 --nproc-per-node= main.py \ + --config-name=diagnostic.yaml +``` + +Check that the run finishes without errors. Then run the `check_diagnostic_outputs.py` script: + +```bash +python check_diagnostic_outputs.py +``` + +### Expected Result + +You should see an output similar to this: + +```bash +Minimum tp: 0.0 +Maximum tp: 0.06553447246551514 +Mean tp: 0.0005333806620910764 +``` + +There should also be a figure as a PNG file for each time step in the prediction in the +`test_figures` directory. Check that the outputs look reasonable for all time steps. Below +is an example: +![Sample of diagnostic model output](test_figures/diagnostic_sample.png) diff --git a/recipes/distributed/test/check_diagnostic_outputs.py b/recipes/distributed/test/check_diagnostic_outputs.py new file mode 100644 index 000000000..7f0237c2f --- /dev/null +++ b/recipes/distributed/test/check_diagnostic_outputs.py @@ -0,0 +1,46 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import xarray as xr +from matplotlib import pyplot as plt + + +def check_diagnostic_outputs(fn: str = "../diagnostic_distributed.zarr") -> None: + """Check the diagnostic outputs.""" + + with xr.open_dataset(fn, engine="zarr") as ds: + tp = ds["tp"].values + + min_tp = tp.min() + max_tp = tp.max() + mean_tp = tp.mean() + + print(f"Minimum tp: {min_tp}") + print(f"Maximum tp: {max_tp}") + print(f"Mean tp: {mean_tp}") + + os.makedirs("test_figures", exist_ok=True) + for i in range(tp.shape[1]): + fig, ax = plt.subplots(figsize=(10, 10)) + im = ax.imshow(tp[0, i, :, :]) + fig.colorbar(im, ax=ax, orientation="vertical") + fig.savefig(f"test_figures/tp_{i:02d}.png", bbox_inches="tight") + + +if __name__ == "__main__": + check_diagnostic_outputs() diff --git a/recipes/distributed/test/test_figures/diagnostic_sample.png b/recipes/distributed/test/test_figures/diagnostic_sample.png new file mode 100644 index 000000000..4ef7531bf Binary files /dev/null and b/recipes/distributed/test/test_figures/diagnostic_sample.png differ