Skip to content

Commit 4ae29b1

Browse files
authored
Adding support_full_params functionality for FSDP2 (#3907)
1 parent 81d4f8e commit 4ae29b1

File tree

8 files changed

+609
-16
lines changed

8 files changed

+609
-16
lines changed

composer/distributed/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
prepare_tp_module,
1212
)
1313
from composer.distributed.prepare_distributed import parallelize_composer_model
14+
from composer.distributed.shared_utils import (
15+
get_summon_params_fn,
16+
)
1417

1518
__all__ = [
1619
'DDPSyncStrategy',
@@ -19,4 +22,5 @@
1922
'prepare_fsdp_module',
2023
'prepare_tp_module',
2124
'parallelize_composer_model',
25+
'get_summon_params_fn',
2226
]

composer/distributed/fsdp2_utils.py

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99

1010
import torch
1111
import torch.nn as nn
12+
from torch.distributed.fsdp import FSDPModule
1213
from torch.distributed.fsdp.wrap import CustomPolicy
14+
from torch.distributed.tensor import DTensor, distribute_tensor
1315
from torchmetrics import Metric, MetricCollection
1416

1517
from composer.models import ComposerModel
@@ -378,3 +380,191 @@ def lambda_fn(current_module: nn.Module) -> bool | dict[str, Any]:
378380
return cached_submodules_to_wrap.get(current_module, False)
379381

380382
return CustomPolicy(lambda_fn)
383+
384+
385+
# TODO: We want to eventually use model.named_parameters(recurse=False) to get all the params
386+
# associated with that module specifically. We need to do the following since we're following FSDP1
387+
# conventions (and also since it's easy to support tied weights) but we want to move away from this
388+
# approach in the future.
389+
def _get_params_to_summon_fsdp2(module: torch.nn.Module, recurse: bool = True):
390+
"""Gets the DTensors to materialize for an FSDP2 model based on recurse.
391+
392+
If recurse=False, we can encounter the following state:
393+
FSDPModule_1
394+
|- weight (DTensor) <-- handled
395+
|- FSDPModule_2
396+
| |- weight (DTensor)
397+
|- RegularModule_1
398+
| |- weight (DTensor) <-- handled
399+
| |- FSDPModule_3
400+
| | |- weight (DTensor)
401+
Where summon_full_params(FSDPModule_1) should materialize RegularModule_1.weight
402+
alongside the original FSDPModule_1.weight. Therefore, we use a dfs traversal
403+
to get all DTensors not owned by downstream FSDPModules.
404+
"""
405+
dtensor_params = {}
406+
407+
def _dfs(module: torch.nn.Module, prefix: str = ''):
408+
# Add all DTensors within this (FSDP)module
409+
for name, param in module.named_parameters(
410+
recurse=False,
411+
remove_duplicate=False,
412+
):
413+
if isinstance(param, DTensor):
414+
full_name = f'{prefix}.{name}' if prefix else name
415+
dtensor_params[full_name] = param
416+
for child_name, child in module.named_children():
417+
if isinstance(child, FSDPModule) and not recurse:
418+
continue
419+
full_name = f'{prefix}.{child_name}' if prefix else child_name
420+
_dfs(child, full_name)
421+
422+
_dfs(module, '')
423+
return dtensor_params
424+
425+
426+
# TODO: This function only works when model is a FSDPModule and doesn't work with other parallelisms
427+
# (like TP) which can make DTensors that are not FSDP specific. We want to support summon_full_params
428+
# for other kinds of parallelisms so this approach might need a generalized rework in the future.
429+
# Especially when we are able to deprecate FSDP1.
430+
#
431+
# Since supporting tied weights is the biggest concern, a potential approach is supporting
432+
# taking in a dict of the model params at the start, figuring out which params are tied,
433+
# and handling those tied params correctly for 3D parallelism and beyond.
434+
@contextlib.contextmanager
435+
def summon_full_params_fsdp2(
436+
model: torch.nn.Module,
437+
writeback: bool = True,
438+
recurse: bool = True,
439+
rank0_only: bool = False,
440+
offload_to_cpu: bool = False,
441+
with_grads: bool = False,
442+
):
443+
"""Context manager to get full params for FSDP2 models with DTensor APIs.
444+
445+
Note: Although FSDP1 uses `unshard` and `reshard` for summoning full params, we use DTensor APIs
446+
to materialize the full parameters as that is the preferred approach for FSDP2. Additionally,
447+
`unshard` and `reshard` with writeback functionality is not supported for FSDP2 models.
448+
449+
Writeback limitation: Only in-place modifications to parameter data are supported. The context
450+
manager cannot write back structural changes such as replacing a parameter with a different
451+
object type or setting it to None. If this occurs, the context manager will just use the original
452+
DTensor.
453+
454+
We currently don't support rank0_only, offload_to_cpu, and with_grads.
455+
"""
456+
# TODO: We want to support these arguments in the future.
457+
if any([rank0_only, offload_to_cpu, with_grads]):
458+
raise ValueError(
459+
'rank0_only, offload_to_cpu, and with_grads are not supported for FSDP2 models. '
460+
'The defaults supported are: rank0_only=False, offload_to_cpu=False, with_grads=False.',
461+
)
462+
463+
dtensor_params = _get_params_to_summon_fsdp2(model, recurse=recurse)
464+
465+
if not dtensor_params:
466+
yield
467+
return
468+
469+
model_dtensors = {}
470+
metadata = {}
471+
tied_params = {}
472+
473+
# We want to get the module and attr of the param, so we can assign
474+
# module.attr = param.full_tensor() before we yield and
475+
# module.attr = distributed (potentially updated) tensor after we yield.
476+
def _get_module_and_attr(module: torch.nn.Module, param_name: str):
477+
module_path, local_param_name = param_name.rsplit('.', 1)
478+
submodule = module.get_submodule(module_path)
479+
return submodule, local_param_name
480+
481+
# Group parameters by their underlying tensor to handle tied parameters
482+
tensor_to_names = {}
483+
for name, dtensor_param in dtensor_params.items():
484+
if dtensor_param not in tensor_to_names:
485+
tensor_to_names[dtensor_param] = []
486+
tensor_to_names[dtensor_param].append(name)
487+
488+
# Process parameters, handling tied parameters correctly
489+
# since there are cases where two regular modules share the same
490+
# weight within an FSDPModule (e.g. weight tied embedding layers
491+
# in a GPT architecture).
492+
processed_tensors = set()
493+
for name, dtensor_param in dtensor_params.items():
494+
metadata[name] = {
495+
'device_mesh': dtensor_param.device_mesh, # type: ignore
496+
'placements': dtensor_param.placements, # type: ignore
497+
'requires_grad': dtensor_param.requires_grad, # type: ignore
498+
}
499+
model_dtensors[name] = dtensor_param
500+
501+
# Only materialize the full tensor once per unique tensor
502+
if dtensor_param not in processed_tensors:
503+
full_tensor = dtensor_param.full_tensor()
504+
new_param = torch.nn.Parameter(full_tensor.detach().clone())
505+
506+
# Set the same parameter instance for all tied parameters
507+
for tied_name in tensor_to_names[dtensor_param]:
508+
parent_module, attr_name = _get_module_and_attr(model, tied_name)
509+
setattr(parent_module, attr_name, new_param)
510+
tied_params[tied_name] = new_param
511+
512+
processed_tensors.add(dtensor_param)
513+
514+
try:
515+
yield
516+
finally:
517+
# Process tied parameters to ensure writeback works correctly
518+
processed_tensors = set()
519+
tensor_to_updated_dtensor = {}
520+
521+
for name, dtensor_param in dtensor_params.items():
522+
parent_module, attr_name = _get_module_and_attr(model, name)
523+
524+
if writeback and dtensor_param not in processed_tensors:
525+
# We update model_dtensors[name] to use the updated param
526+
# after the model changes. For tied parameters, we only need
527+
# to do this once per unique tensor.
528+
current_param = getattr(parent_module, attr_name)
529+
if hasattr(
530+
current_param,
531+
'data',
532+
) and current_param.data is not None:
533+
meta = metadata[name]
534+
sharded = distribute_tensor(
535+
current_param.data,
536+
meta['device_mesh'],
537+
meta['placements'],
538+
)
539+
new_param = torch.nn.Parameter(sharded)
540+
new_param.requires_grad = meta['requires_grad']
541+
tensor_to_updated_dtensor[dtensor_param] = new_param
542+
processed_tensors.add(dtensor_param)
543+
else:
544+
warnings.warn(
545+
f'Parameter {name} cannot be written back because it has no .data attribute '
546+
f'or .data is None. The original DTensor will be restored instead as structural '
547+
f'changes are not supported.',
548+
)
549+
550+
# Restore the appropriate DTensor for this parameter
551+
if writeback and dtensor_param in tensor_to_updated_dtensor:
552+
setattr(parent_module, attr_name, tensor_to_updated_dtensor[dtensor_param])
553+
else:
554+
setattr(parent_module, attr_name, model_dtensors[name])
555+
556+
557+
def validate_all_dtensors_are_fsdp_based(model: torch.nn.Module):
558+
"""Validates that all DTensors in the model are made by a call to `fully_shard`."""
559+
all_params = {param for param in model.parameters() if isinstance(param, DTensor)}
560+
fsdp_params = set()
561+
for module in model.modules():
562+
if isinstance(module, FSDPModule):
563+
for param in module.parameters():
564+
if isinstance(param, DTensor):
565+
fsdp_params.add(param)
566+
if all_params != fsdp_params:
567+
raise ValueError(
568+
'All DTensors in the model must be made by a call to `fully_shard`. '
569+
f'Found {len(all_params - fsdp_params)} DTensors that were not made by `fully_shard`.',
570+
)

composer/distributed/shared_utils.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
"""Shared utilities for distributed training."""
55

66
import functools
7+
import warnings
8+
from contextlib import nullcontext
79
from typing import Callable, Optional
810

911
import torch
@@ -13,8 +15,10 @@
1315
from torchmetrics import Metric, MetricCollection
1416

1517
from composer.devices import Device
18+
from composer.distributed.fsdp2_utils import summon_full_params_fsdp2, validate_all_dtensors_are_fsdp_based
1619
from composer.models import ComposerModel
1720
from composer.utils import dist, get_device
21+
from composer.utils.misc import is_model_fsdp, is_model_fsdp2
1822
from composer.utils.parallelism import FSDP2Config, FSDPConfig
1923

2024

@@ -199,3 +203,30 @@ def update_sync_module_states_if_needed(model: nn.Module, fsdp_config: FSDP2Conf
199203
f'When doing mixed initialization, Rank 0 should have parameters on non-meta device, '
200204
f'and all other ranks should have parameters on meta device.',
201205
)
206+
207+
208+
def get_summon_params_fn(model: torch.nn.Module) -> Callable:
209+
"""Returns a contextmanager for summoning the full parameters of a model or any of its submodules.
210+
211+
We are using the full model state to figure out whether we should use an FSDP1-based or FSDP2-based
212+
version of the `summon_full_params` function. Once the `summon_full_params` function has been output,
213+
it can be used on any FSDP wrapped module within the model. Both `summon_full_params` functions
214+
have the same function signature, but the FSDP2 variant has some limitations (no support for
215+
`with_grads` and `with_grads_and_buffers`, all of which default to False).
216+
217+
Args:
218+
model (torch.nn.Module): The model to get the summon_full_params function for.
219+
220+
Returns:
221+
Callable: A contextmanager for summoning the full parameters of a model or any of its submodules.
222+
"""
223+
if is_model_fsdp2(model):
224+
validate_all_dtensors_are_fsdp_based(model)
225+
return summon_full_params_fsdp2
226+
elif is_model_fsdp(model):
227+
return FullyShardedDataParallel.summon_full_params
228+
else:
229+
warnings.warn(
230+
'No FSDP(1/2) Modules detected in the model, summon_full_params will be a nullcontext.',
231+
)
232+
return nullcontext

composer/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
create_interval_scheduler,
5757
get_free_tcp_port,
5858
is_model_fsdp,
59+
is_model_fsdp2,
5960
is_notebook,
6061
model_eval_mode,
6162
partial_format,
@@ -103,6 +104,7 @@
103104
'get_save_filename',
104105
'import_object',
105106
'is_model_fsdp',
107+
'is_model_fsdp2',
106108
'is_notebook',
107109
'StringEnum',
108110
'load_checkpoint',

composer/utils/misc.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from typing import TYPE_CHECKING, Callable, Optional, Union
1212

1313
import torch
14+
from torch.distributed.fsdp import FSDPModule
15+
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
1416
from torch.nn.parallel import DistributedDataParallel
1517
from torchvision import transforms
1618
from torchvision.datasets import VisionDataset
@@ -190,20 +192,23 @@ def is_model_ddp(model: torch.nn.Module) -> bool:
190192

191193

192194
def is_model_fsdp(model: torch.nn.Module) -> bool:
193-
"""Whether ``model`` is an instance of a :class:`.FullyShardedDataParallel`."""
194-
try:
195-
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
196-
197-
if isinstance(model, FSDP):
195+
"""Whether ``model`` or any of its submodules are instances of :class:`.FullyShardedDataParallel` or :class:`.FSDPModule`."""
196+
if isinstance(model, (FSDP, FSDPModule)):
197+
return True
198+
for _, obj in model.named_children():
199+
if isinstance(obj, (FSDP, FSDPModule)):
198200
return True
201+
return False
199202

200-
# Check if model is wrapped with FSDP
201-
for _, obj in model.named_children():
202-
if isinstance(obj, FSDP):
203-
return True
204-
return False
205-
except ImportError:
206-
return False
203+
204+
def is_model_fsdp2(model: torch.nn.Module) -> bool:
205+
"""Whether ``model`` or any of its submodules are instances of specifically :class:`.FSDPModule`."""
206+
if isinstance(model, FSDPModule):
207+
return True
208+
for _, obj in model.named_children():
209+
if isinstance(obj, FSDPModule):
210+
return True
211+
return False
207212

208213

209214
def is_notebook():

tests/common/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
EmbeddedWeightTiedModel,
2525
EmptyModel,
2626
EvenSimplerMLP,
27+
NestedFSDPModel,
2728
PartialWeightTiedModel,
2829
SimpleComposerMLP,
2930
SimpleConvModel,
@@ -62,6 +63,7 @@ def get_module_subclasses(module: types.ModuleType, cls: type) -> list[type]:
6263
'SimpleTransformerMaskedLM',
6364
'EmbeddedWeightTiedModel',
6465
'PartialWeightTiedModel',
66+
'NestedFSDPModel',
6567
'SimpleWeightTiedModel',
6668
'EventCounterCallback',
6769
'deep_compare',

0 commit comments

Comments
 (0)