|
9 | 9 |
|
10 | 10 | import torch |
11 | 11 | import torch.nn as nn |
| 12 | +from torch.distributed.fsdp import FSDPModule |
12 | 13 | from torch.distributed.fsdp.wrap import CustomPolicy |
| 14 | +from torch.distributed.tensor import DTensor, distribute_tensor |
13 | 15 | from torchmetrics import Metric, MetricCollection |
14 | 16 |
|
15 | 17 | from composer.models import ComposerModel |
@@ -378,3 +380,191 @@ def lambda_fn(current_module: nn.Module) -> bool | dict[str, Any]: |
378 | 380 | return cached_submodules_to_wrap.get(current_module, False) |
379 | 381 |
|
380 | 382 | return CustomPolicy(lambda_fn) |
| 383 | + |
| 384 | + |
| 385 | +# TODO: We want to eventually use model.named_parameters(recurse=False) to get all the params |
| 386 | +# associated with that module specifically. We need to do the following since we're following FSDP1 |
| 387 | +# conventions (and also since it's easy to support tied weights) but we want to move away from this |
| 388 | +# approach in the future. |
| 389 | +def _get_params_to_summon_fsdp2(module: torch.nn.Module, recurse: bool = True): |
| 390 | + """Gets the DTensors to materialize for an FSDP2 model based on recurse. |
| 391 | +
|
| 392 | + If recurse=False, we can encounter the following state: |
| 393 | + FSDPModule_1 |
| 394 | + |- weight (DTensor) <-- handled |
| 395 | + |- FSDPModule_2 |
| 396 | + | |- weight (DTensor) |
| 397 | + |- RegularModule_1 |
| 398 | + | |- weight (DTensor) <-- handled |
| 399 | + | |- FSDPModule_3 |
| 400 | + | | |- weight (DTensor) |
| 401 | + Where summon_full_params(FSDPModule_1) should materialize RegularModule_1.weight |
| 402 | + alongside the original FSDPModule_1.weight. Therefore, we use a dfs traversal |
| 403 | + to get all DTensors not owned by downstream FSDPModules. |
| 404 | + """ |
| 405 | + dtensor_params = {} |
| 406 | + |
| 407 | + def _dfs(module: torch.nn.Module, prefix: str = ''): |
| 408 | + # Add all DTensors within this (FSDP)module |
| 409 | + for name, param in module.named_parameters( |
| 410 | + recurse=False, |
| 411 | + remove_duplicate=False, |
| 412 | + ): |
| 413 | + if isinstance(param, DTensor): |
| 414 | + full_name = f'{prefix}.{name}' if prefix else name |
| 415 | + dtensor_params[full_name] = param |
| 416 | + for child_name, child in module.named_children(): |
| 417 | + if isinstance(child, FSDPModule) and not recurse: |
| 418 | + continue |
| 419 | + full_name = f'{prefix}.{child_name}' if prefix else child_name |
| 420 | + _dfs(child, full_name) |
| 421 | + |
| 422 | + _dfs(module, '') |
| 423 | + return dtensor_params |
| 424 | + |
| 425 | + |
| 426 | +# TODO: This function only works when model is a FSDPModule and doesn't work with other parallelisms |
| 427 | +# (like TP) which can make DTensors that are not FSDP specific. We want to support summon_full_params |
| 428 | +# for other kinds of parallelisms so this approach might need a generalized rework in the future. |
| 429 | +# Especially when we are able to deprecate FSDP1. |
| 430 | +# |
| 431 | +# Since supporting tied weights is the biggest concern, a potential approach is supporting |
| 432 | +# taking in a dict of the model params at the start, figuring out which params are tied, |
| 433 | +# and handling those tied params correctly for 3D parallelism and beyond. |
@contextlib.contextmanager
def summon_full_params_fsdp2(
    model: torch.nn.Module,
    writeback: bool = True,
    recurse: bool = True,
    rank0_only: bool = False,
    offload_to_cpu: bool = False,
    with_grads: bool = False,
):
    """Context manager to get full params for FSDP2 models with DTensor APIs.

    Note: Although FSDP1 uses `unshard` and `reshard` for summoning full params, we use DTensor APIs
    to materialize the full parameters as that is the preferred approach for FSDP2. Additionally,
    `unshard` and `reshard` with writeback functionality is not supported for FSDP2 models.

    Writeback limitation: Only in-place modifications to parameter data are supported. The context
    manager cannot write back structural changes such as replacing a parameter with a different
    object type or setting it to None. If this occurs, the context manager will just use the original
    DTensor.

    We currently don't support rank0_only, offload_to_cpu, and with_grads.

    Args:
        model: The FSDP2-wrapped module whose DTensor parameters are materialized.
        writeback: If True, modifications made inside the context are re-sharded
            back into the original DTensor layout on exit.
        recurse: If False, DTensors owned by nested FSDPModules are left sharded.
        rank0_only: Unsupported; must be False.
        offload_to_cpu: Unsupported; must be False.
        with_grads: Unsupported; must be False.

    Raises:
        ValueError: If any unsupported flag is set.
    """
    # TODO: We want to support these arguments in the future.
    if any([rank0_only, offload_to_cpu, with_grads]):
        raise ValueError(
            'rank0_only, offload_to_cpu, and with_grads are not supported for FSDP2 models. '
            'The defaults supported are: rank0_only=False, offload_to_cpu=False, with_grads=False.',
        )

    dtensor_params = _get_params_to_summon_fsdp2(model, recurse=recurse)

    # Nothing to materialize (e.g. model has no DTensor params) — no-op context.
    if not dtensor_params:
        yield
        return

    model_dtensors = {}  # original DTensor per qualified name, restored on exit
    metadata = {}  # mesh/placements/requires_grad per name, used to re-shard on writeback

    # We want to get the module and attr of the param, so we can assign
    # module.attr = param.full_tensor() before we yield and
    # module.attr = distributed (potentially updated) tensor after we yield.
    def _get_module_and_attr(module: torch.nn.Module, param_name: str):
        # A parameter registered directly on the root has no dot in its name;
        # rsplit would raise on unpacking, so short-circuit to the root module.
        if '.' not in param_name:
            return module, param_name
        module_path, local_param_name = param_name.rsplit('.', 1)
        return module.get_submodule(module_path), local_param_name

    # Group parameters by their underlying tensor to handle tied parameters
    tensor_to_names = {}
    for name, dtensor_param in dtensor_params.items():
        tensor_to_names.setdefault(dtensor_param, []).append(name)

    # Process parameters, handling tied parameters correctly
    # since there are cases where two regular modules share the same
    # weight within an FSDPModule (e.g. weight tied embedding layers
    # in a GPT architecture).
    processed_tensors = set()
    for name, dtensor_param in dtensor_params.items():
        metadata[name] = {
            'device_mesh': dtensor_param.device_mesh,  # type: ignore
            'placements': dtensor_param.placements,  # type: ignore
            'requires_grad': dtensor_param.requires_grad,  # type: ignore
        }
        model_dtensors[name] = dtensor_param

        # Only materialize the full tensor once per unique tensor
        if dtensor_param not in processed_tensors:
            full_tensor = dtensor_param.full_tensor()
            new_param = torch.nn.Parameter(full_tensor.detach().clone())

            # Set the same parameter instance for all tied parameters so
            # weight tying is preserved while the params are materialized.
            for tied_name in tensor_to_names[dtensor_param]:
                parent_module, attr_name = _get_module_and_attr(model, tied_name)
                setattr(parent_module, attr_name, new_param)

            processed_tensors.add(dtensor_param)

    try:
        yield
    finally:
        # Process tied parameters to ensure writeback works correctly
        processed_tensors = set()
        tensor_to_updated_dtensor = {}

        for name, dtensor_param in dtensor_params.items():
            parent_module, attr_name = _get_module_and_attr(model, name)

            if writeback and dtensor_param not in processed_tensors:
                # We re-shard the (potentially updated) materialized tensor back
                # into the original DTensor layout. For tied parameters, we only
                # need to do this once per unique tensor.
                current_param = getattr(parent_module, attr_name)
                if hasattr(
                    current_param,
                    'data',
                ) and current_param.data is not None:
                    meta = metadata[name]
                    sharded = distribute_tensor(
                        current_param.data,
                        meta['device_mesh'],
                        meta['placements'],
                    )
                    new_param = torch.nn.Parameter(sharded)
                    new_param.requires_grad = meta['requires_grad']
                    tensor_to_updated_dtensor[dtensor_param] = new_param
                    processed_tensors.add(dtensor_param)
                else:
                    # Structural change (param replaced/removed) — fall back to
                    # restoring the original DTensor below.
                    warnings.warn(
                        f'Parameter {name} cannot be written back because it has no .data attribute '
                        f'or .data is None. The original DTensor will be restored instead as structural '
                        f'changes are not supported.',
                    )

            # Restore the appropriate DTensor for this parameter
            if writeback and dtensor_param in tensor_to_updated_dtensor:
                setattr(parent_module, attr_name, tensor_to_updated_dtensor[dtensor_param])
            else:
                setattr(parent_module, attr_name, model_dtensors[name])
| 555 | + |
| 556 | + |
| 557 | +def validate_all_dtensors_are_fsdp_based(model: torch.nn.Module): |
| 558 | + """Validates that all DTensors in the model are made by a call to `fully_shard`.""" |
| 559 | + all_params = {param for param in model.parameters() if isinstance(param, DTensor)} |
| 560 | + fsdp_params = set() |
| 561 | + for module in model.modules(): |
| 562 | + if isinstance(module, FSDPModule): |
| 563 | + for param in module.parameters(): |
| 564 | + if isinstance(param, DTensor): |
| 565 | + fsdp_params.add(param) |
| 566 | + if all_params != fsdp_params: |
| 567 | + raise ValueError( |
| 568 | + 'All DTensors in the model must be made by a call to `fully_shard`. ' |
| 569 | + f'Found {len(all_params - fsdp_params)} DTensors that were not made by `fully_shard`.', |
| 570 | + ) |
0 commit comments