diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 237bbfab2473..9b1bbbef0492 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -6,7 +6,7 @@ import math import os import types -from typing import Callable, Iterable, Union +from typing import Callable, Iterable from enum import Enum import functools import itertools @@ -702,8 +702,12 @@ def wait(self, handle_dependency=True) -> None: self.__original_dtype).to(self.__param.device) elif self.__quantization: instrument_w_nvtx(self.__quantization.quant_handle.wait)() - self.__param.data = self.__quantization.backend.dequantize( - self.__quantization.quantized_param, self.__quantization.scale_buffer).to(self.__param.device) + dequantized = self.__quantization.backend.dequantize(self.__quantization.quantized_param, + self.__quantization.scale_buffer) + # Fix for issue #7775: convert dequantized tensor back to original dtype (e.g., bf16) + if self.__original_dtype is not None: + dequantized = dequantized.to(self.__original_dtype) + self.__param.data = dequantized.to(self.__param.device) self.__param.ds_status = ZeroParamStatus.AVAILABLE @@ -719,6 +723,7 @@ def __init__( world_size: int, use_secondary_tensor=False, quantization=None, + original_dtype=None, ) -> None: self.allgather_handle = allgather_handle self.params = params @@ -727,6 +732,7 @@ def __init__( self.use_secondary_tensor = use_secondary_tensor self.complete = False self.quantization = quantization + self.original_dtype = original_dtype for param in self.params: if param.ds_status != ZeroParamStatus.INFLIGHT: @@ -741,8 +747,13 @@ def wait(self, handle_dependency=True) -> None: if self.quantization: instrument_w_nvtx(self.quantization.quant_handle.wait)() - flat_tensor = self.quantization.backend.dequantize( - self.quantization.quantized_param, self.quantization.scale_buffer).to(self.params[0].device) + # Fix for issue #7775: convert dequantized tensor back to original dtype (e.g., bf16) + # to prevent dtype mismatch when zero_quantized_weights is used with bf16 + dequantized = self.quantization.backend.dequantize(self.quantization.quantized_param, + self.quantization.scale_buffer) + if self.original_dtype is not None: + dequantized = dequantized.to(self.original_dtype) + flat_tensor = dequantized.to(self.params[0].device) self.partitions: List[Parameter] = [] for i in range(self.world_size): @@ -784,7 +795,7 @@ def free_buffer(): class MultipleAllGatherHandles: - def __init__(self, handles: List[Union[AllGatherHandle, AllGatherCoalescedHandle]]): + def __init__(self, handles: List[AllGatherCoalescedHandle]): self.handles = handles def wait(self, handle_dependency=True) -> None: @@ -1107,10 +1118,8 @@ def __init__(self, self.use_all_reduce_for_fetch_params = get_config_default(DeepSpeedZeroConfig, "use_all_reduce_for_fetch_params") - self.allgather_sequential = get_config_default(DeepSpeedZeroConfig, "allgather_sequential") if _ds_config is not None: self.use_all_reduce_for_fetch_params = _ds_config.zero_config.use_all_reduce_for_fetch_params - self.allgather_sequential = _ds_config.zero_config.allgather_sequential def _update_persist_config(self, ds_config): Init.apply_param_persistence = True @@ -1272,9 +1281,56 @@ def _all_gather_dtype(params, world_size, rank_in_group, ds_process_group, allga use_secondary_tensor=use_secondary_tensor, ) - def _all_gather_sequential(params, world_size, use_secondary_tensor, ds_process_group, quantize): - handles = [] + @instrument_w_nvtx + def all_gather_coalesced(params: Iterable[Parameter], + safe_mode: bool = False, + quantize: bool = False) -> AllGatherCoalescedHandle: + + # fetches from nvme if the partition is not available and in nvme + self._ensure_availability_of_partitioned_params(params) + + if self.num_partitions == 1: + return _no_gather_coalesced(params) + for param in params: + if param.ds_status != ZeroParamStatus.NOT_AVAILABLE: + raise RuntimeError(param.ds_summary()) + param.ds_status = ZeroParamStatus.INFLIGHT + + #use appropriate all gather process group + ds_process_group = self.ds_process_group + rank_in_group = self.rank + world_size = self.dp_world_size + use_secondary_tensor = params[0].ds_secondary_tensor is not None + if self.zero_param_process_group and use_secondary_tensor: + ds_process_group = self.zero_param_process_group #intragroup + rank_in_group = self.rank_in_group + world_size = self.num_ranks_in_param_group + + #pprint(dir(ds_process_group)) + # ensure that each rank has params in same order. the allgather + # is done by flattening the parameter list into a single tensor that + # can be allgathered in a single call - this means that if each rank + # gives a list of the same parameters in a different order we will + # silently get incorrect parameter values, and have very difficult + # to debug correctness issues. + params = sorted(params, key=lambda p: p.ds_id) + + if logger.isEnabledFor(logging.DEBUG): + debug_rank0(f"-allgather_coalesced: {[p.ds_id for p in params]}") + + if safe_mode: + # ensure that same list (with same ordering) of parameters are + # being allgathered across all ranks, otherwise could mix + # data between tensors. + assert_ints_same_as_other_ranks([p.ds_id for p in params]) + # ensure that tensors from each rank agree on the same ds_numel + # otherwise could mix data between tensors. + assert_ints_same_as_other_ranks([p.ds_tensor.ds_numel for p in params]) + + if len(params) == 1: + # have an opportunity to avoid some intermediate memory allocations + param = params[0] buffer_size = math.ceil(param.ds_numel / world_size) * world_size if use_secondary_tensor: buffer_size = param.ds_secondary_tensor.shape[0] * world_size #make sure out is appropriately sized @@ -1294,7 +1350,7 @@ def _all_gather_sequential(params, world_size, use_secondary_tensor, ds_process_ requires_grad=False, ) if not quantize: - handle = _dist_allgather_fn( + handles = _dist_allgather_fn( param_ds_tensor.to(get_accelerator().current_device_name()).to(allgather_dtype), param_buffer, ds_process_group, @@ -1302,7 +1358,7 @@ def _all_gather_sequential(params, world_size, use_secondary_tensor, ds_process_ if original_dtype == allgather_dtype: param.data = param_buffer.narrow(0, 0, param.ds_numel).view(param.ds_shape).to(param.device) - handles.append(AllGatherHandle(handle, param)) + return AllGatherHandle(handles, param) else: # This case is complicated: # We use `register_post_accumulate_grad_hook` to set allgather hooks. Normally, the hook is @@ -1317,8 +1373,10 @@ def _all_gather_sequential(params, world_size, use_secondary_tensor, ds_process_ # In theory, this path could be consolidated with the case where # (original_dtype == allgather_dtype), but because it changes the # state transition of DeepSpeed parameters, we keep it separate for safety. - handles.append( - AllGatherHandle(handle, param, param_buffer=param_buffer, original_dtype=original_dtype)) + return AllGatherHandle(handles, + param, + param_buffer=param_buffer, + original_dtype=original_dtype) else: if hasattr(param_ds_tensor, "ds_quant_scale"): scales = param_ds_tensor.ds_quant_scale @@ -1342,154 +1400,109 @@ def _all_gather_sequential(params, world_size, use_secondary_tensor, ds_process_ quant_info.backend = self.quantizer_module quant_info.quant_handle = quant_handle quant_info.scale_buffer = quant_scale_buffer - handles.append(AllGatherHandle(handle, param, quantization=quant_info)) - return MultipleAllGatherHandles(handles) - - def _all_gather_coalesced(params, world_size, rank_in_group, use_secondary_tensor, ds_process_group, quantize): - if self.use_all_reduce_for_fetch_params and not quantize and not use_secondary_tensor: + # Pass original_dtype for proper dtype restoration after dequantization + return AllGatherHandle(handle, param, quantization=quant_info, original_dtype=original_dtype) - # Use all_reduce instead of all_gather to fetch the module params - flat_buffer_size = sum(p.ds_numel_aligned for p in params) - flat_tensor = torch.zeros(flat_buffer_size, - dtype=get_only_unique_item(p.ds_tensor.dtype for p in params), - device=get_accelerator().current_device_name(), - requires_grad=False) - start_param = 0 - for param in params: - param.data = flat_tensor.narrow(0, start_param, param.ds_numel).view(param.ds_shape) - start = start_param + param.ds_tensor.ds_numel * self.get_partition_rank() - flat_tensor.narrow(0, start, param.ds_tensor.ds_numel).copy_(param.ds_tensor) - - start_param += param.ds_numel - - handle = dist.all_reduce(flat_tensor, group=ds_process_group, async_op=True) - - return AllReduceCoalescedHandle(handle=handle, params=params) else: - if not quantize: - dtype_params = defaultdict(list) - for p in params: - allgather_dtype = get_allgather_dtype(p, p.ds_tensor) - dtype_params[allgather_dtype].append(p) - handles = [] - for dtype in sort_dtypes(dtype_params.keys()): - handles.append( - _all_gather_dtype(dtype_params[dtype], world_size, rank_in_group, ds_process_group, dtype)) - - return MultipleAllGatherHandles(handles) + if self.use_all_reduce_for_fetch_params and not quantize and not use_secondary_tensor: - else: - partition_sz = sum(p.ds_tensor.ds_numel for p in params) - - if use_secondary_tensor: - partition_sz = sum(p.ds_tensor.ds_numel * p.ds_secondary_tensor_num_of_groups for p in params) - - flat_tensor = torch.empty(partition_sz * world_size, - dtype=torch.int8, + # Use all_reduce instead of all_gather to fetch the module params + flat_buffer_size = sum(p.ds_numel_aligned for p in params) + flat_tensor = torch.zeros(flat_buffer_size, + dtype=get_only_unique_item(p.ds_tensor.dtype for p in params), device=get_accelerator().current_device_name(), requires_grad=False) + start_param = 0 + for param in params: + param.data = flat_tensor.narrow(0, start_param, param.ds_numel).view(param.ds_shape) + start = start_param + param.ds_tensor.ds_numel * self.get_partition_rank() + flat_tensor.narrow(0, start, param.ds_tensor.ds_numel).copy_(param.ds_tensor) - if use_secondary_tensor: - if hasattr(params[0].ds_secondary_tensor, "ds_quant_scale"): - quantized_param = instrument_w_nvtx(torch.cat)([ - p.ds_secondary_tensor.data.to(get_accelerator().current_device_name()) for p in params - ]) - scales = instrument_w_nvtx(torch.cat)([ - p.ds_secondary_tensor.ds_quant_scale.to(get_accelerator().current_device_name()) - for p in params - ]) - else: - quantized_param, scales = self.quantizer_module.quantize( - instrument_w_nvtx(torch.cat)([ - p.ds_secondary_tensor.to(get_accelerator().current_device_name()) for p in params - ])) - else: - if hasattr(params[0].ds_tensor, "ds_quant_scale"): - quantized_param = instrument_w_nvtx(torch.cat)( - [p.ds_tensor.data.to(get_accelerator().current_device_name()) for p in params]) - scales = instrument_w_nvtx(torch.cat)([ - p.ds_tensor.ds_quant_scale.to(get_accelerator().current_device_name()) for p in params - ]) - else: - quantized_param, scales = self.quantizer_module.quantize( - instrument_w_nvtx(torch.cat)( - [p.ds_tensor.to(get_accelerator().current_device_name()) for p in params])) - quant_scale_buffer = torch.empty( - scales.numel() * world_size, - dtype=torch.float32, - device=get_accelerator().current_device_name(), - requires_grad=False, - ) - handle = _dist_allgather_fn(quantized_param, flat_tensor, ds_process_group) - quant_handle = _dist_allgather_fn(scales, quant_scale_buffer, ds_process_group) - quant_info = QuantizationInfo() - quant_info.quantized_param = flat_tensor - quant_info.backend = self.quantizer_module - quant_info.quant_handle = quant_handle - quant_info.scale_buffer = quant_scale_buffer - quant_info.partition_sz = partition_sz - quant_info.world_size = world_size - return AllGatherCoalescedHandle( - allgather_handle=handle, - params=params, - partitions=None, - world_size=world_size, - use_secondary_tensor=use_secondary_tensor, - quantization=quant_info, - ) - - @instrument_w_nvtx - def all_gather_coalesced(params: Iterable[Parameter], - safe_mode: bool = False, - quantize: bool = False) -> AllGatherCoalescedHandle: - - # fetches from nvme if the partition is not available and in nvme - self._ensure_availability_of_partitioned_params(params) - - if self.num_partitions == 1: - return _no_gather_coalesced(params) - - for param in params: - if param.ds_status != ZeroParamStatus.NOT_AVAILABLE: - raise RuntimeError(param.ds_summary()) - param.ds_status = ZeroParamStatus.INFLIGHT + start_param += param.ds_numel - #use appropriate all gather process group - ds_process_group = self.ds_process_group - rank_in_group = self.rank - world_size = self.dp_world_size - use_secondary_tensor = params[0].ds_secondary_tensor is not None - if self.zero_param_process_group and use_secondary_tensor: - ds_process_group = self.zero_param_process_group #intragroup - rank_in_group = self.rank_in_group - world_size = self.num_ranks_in_param_group + handle = dist.all_reduce(flat_tensor, group=ds_process_group, async_op=True) - #pprint(dir(ds_process_group)) - # ensure that each rank has params in same order. the allgather - # is done by flattening the parameter list into a single tensor that - # can be allgathered in a single call - this means that if each rank - # gives a list of the same parameters in a different order we will - # silently get incorrect parameter values, and have very difficult - # to debug correctness issues. - params = sorted(params, key=lambda p: p.ds_id) - - if logger.isEnabledFor(logging.DEBUG): - debug_rank0(f"-allgather_coalesced: {[p.ds_id for p in params]}") - - if safe_mode: - # ensure that same list (with same ordering) of parameters are - # being allgathered across all ranks, otherwise could mix - # data between tensors. - assert_ints_same_as_other_ranks([p.ds_id for p in params]) - # ensure that tensors from each rank agree on the same ds_numel - # otherwise could mix data between tensors. - assert_ints_same_as_other_ranks([p.ds_tensor.ds_numel for p in params]) + return AllReduceCoalescedHandle(handle=handle, params=params) + else: + if not quantize: + dtype_params = defaultdict(list) + for p in params: + allgather_dtype = get_allgather_dtype(p, p.ds_tensor) + dtype_params[allgather_dtype].append(p) + handles = [] + for dtype in sort_dtypes(dtype_params.keys()): + handles.append( + _all_gather_dtype(dtype_params[dtype], world_size, rank_in_group, ds_process_group, + dtype)) + + return MultipleAllGatherHandles(handles) - if self.allgather_sequential or len(params) == 1: - return _all_gather_sequential(params, world_size, use_secondary_tensor, ds_process_group, quantize) - else: - return _all_gather_coalesced(params, world_size, rank_in_group, use_secondary_tensor, ds_process_group, - quantize) + else: + partition_sz = sum(p.ds_tensor.ds_numel for p in params) + + if use_secondary_tensor: + partition_sz = sum(p.ds_tensor.ds_numel * p.ds_secondary_tensor_num_of_groups + for p in params) + + flat_tensor = torch.empty(partition_sz * world_size, + dtype=torch.int8, + device=get_accelerator().current_device_name(), + requires_grad=False) + + if use_secondary_tensor: + if hasattr(params[0].ds_secondary_tensor, "ds_quant_scale"): + quantized_param = instrument_w_nvtx(torch.cat)([ + p.ds_secondary_tensor.data.to(get_accelerator().current_device_name()) + for p in params + ]) + scales = instrument_w_nvtx(torch.cat)([ + p.ds_secondary_tensor.ds_quant_scale.to(get_accelerator().current_device_name()) + for p in params + ]) + else: + quantized_param, scales = self.quantizer_module.quantize( + instrument_w_nvtx(torch.cat)([ + p.ds_secondary_tensor.to(get_accelerator().current_device_name()) + for p in params + ])) + else: + if hasattr(params[0].ds_tensor, "ds_quant_scale"): + quantized_param = instrument_w_nvtx(torch.cat)( + [p.ds_tensor.data.to(get_accelerator().current_device_name()) for p in params]) + scales = instrument_w_nvtx(torch.cat)([ + p.ds_tensor.ds_quant_scale.to(get_accelerator().current_device_name()) + for p in params + ]) + else: + quantized_param, scales = self.quantizer_module.quantize( + instrument_w_nvtx(torch.cat)( + [p.ds_tensor.to(get_accelerator().current_device_name()) for p in params])) + quant_scale_buffer = torch.empty( + scales.numel() * world_size, + dtype=torch.float32, + device=get_accelerator().current_device_name(), + requires_grad=False, + ) + handle = _dist_allgather_fn(quantized_param, flat_tensor, ds_process_group) + quant_handle = _dist_allgather_fn(scales, quant_scale_buffer, ds_process_group) + quant_info = QuantizationInfo() + quant_info.quantized_param = flat_tensor + quant_info.backend = self.quantizer_module + quant_info.quant_handle = quant_handle + quant_info.scale_buffer = quant_scale_buffer + quant_info.partition_sz = partition_sz + quant_info.world_size = world_size + # Get the original dtype from param's ds_tensor for proper dtype restoration after dequantization + original_dtype = params[0].ds_tensor.dtype if params else None + return AllGatherCoalescedHandle( + allgather_handle=handle, + params=params, + partitions=None, + world_size=world_size, + use_secondary_tensor=use_secondary_tensor, + quantization=quant_info, + original_dtype=original_dtype, + ) def partition(param_list=None, hierarchy=0, has_been_updated=False, free_data=True): cls = param @@ -1620,8 +1633,8 @@ def _all_gather(self, param_list, async_op=False, hierarchy=None): all_gather_list.append(param) # note: param_list may contain params that are already in flight / aviailable. So we need to use all_gather_list if not async_op: - if self.allgather_sequential or len(all_gather_list) == 1: - ret_value = self._allgather_params_sequential(all_gather_list, hierarchy=hierarchy) + if len(all_gather_list) == 1: + ret_value = self._allgather_params(all_gather_list, hierarchy=hierarchy) else: all_gather_quantize_list = [] all_gather_nonquantize_list = [] @@ -1999,61 +2012,83 @@ def _allgather_params_coalesced(self, param_list, hierarchy=0, quantize=False): return None - def _allgather_params_sequential(self, param_list, hierarchy=0): + @torch.no_grad() + def _allgather_params(self, param_list, hierarchy=0): if len(param_list) == 0: return - for param in param_list: - partition_size = param.ds_tensor.ds_numel - tensor_size = partition_size * self.num_partitions + partition_size = sum([param.ds_tensor.ds_numel for param in param_list]) - flat_tensor = torch.empty(tensor_size, dtype=param.ds_tensor.dtype, device=self.local_device) - flat_tensor.requires_grad = False - if self.use_all_gather_into_tensor: - dist.all_gather_into_tensor(flat_tensor, - param.ds_tensor.to(get_accelerator().device_name()), - group=self.get_partition_dp_group(param), - async_op=False) - else: - partitions = [] - for i in range(self.num_partitions): - partitions.append(flat_tensor.narrow(0, partition_size * i, partition_size)) - if i == self.get_partition_rank(): - partitions[i].data.copy_(param.ds_tensor.data, non_blocking=True) - dist.all_gather(partitions, - partitions[self.get_partition_rank()], - group=self.get_partition_dp_group(param), - async_op=False) - - if hasattr(param.ds_tensor, 'ds_quant_scale'): - scale_size = param.ds_tensor.ds_quant_scale.numel() - scale_tensor_size = scale_size * self.num_partitions - flat_scale_tensor = torch.empty(scale_tensor_size, - dtype=param.ds_tensor.ds_quant_scale.dtype, - device=self.local_device) - flat_scale_tensor.requires_grad = False - if self.use_all_gather_into_tensor: - dist.all_gather_into_tensor(flat_scale_tensor, - param.ds_tensor.ds_quant_scale.to(get_accelerator().device_name()), - group=self.get_partition_dp_group(param), - async_op=False) - else: - scale_partitions = [] - for i in range(self.num_partitions): - scale_partitions.append(flat_scale_tensor.narrow(0, scale_size * i, scale_size)) - if i == self.get_partition_rank(): - scale_partitions[i].data.copy_(param.ds_tensor.ds_quant_scale.data, non_blocking=True) - dist.all_gather(scale_partitions, - scale_partitions[self.get_partition_rank()], + tensor_size = partition_size * self.num_partitions + flat_tensor = torch.empty(tensor_size, dtype=param_list[0].ds_tensor.dtype, device=self.local_device) + partitions = [] + for i in range(self.num_partitions): + start = partition_size * i + + partitions.append(flat_tensor.narrow(0, start, partition_size)) + + if i == self.get_partition_rank(): + offset = 0 + for param in param_list: + param_numel = param.ds_tensor.ds_numel + + partitions[i].narrow(0, offset, param_numel).copy_(param.ds_tensor.data) + + offset += param_numel + + if hasattr(param_list[0], 'ds_quant_scale'): + scale_size = sum([param.ds_tensor.ds_quant_scale.numel() for param in param_list]) + scale_tensor_size = scale_size * self.world_size + flat_scale_tensor = torch.empty(scale_tensor_size, + dtype=param_list[0].ds_tensor.ds_quant_scale.dtype, + device=self.local_device) + scale_partitions = [] + for i in range(self.world_size): + start = scale_tensor_size * i + scale_partitions.append(flat_scale_tensor.narrow(0, start, scale_tensor_size)) + if i == self.rank: + offset = 0 + for param in param_list: + param_scale_numel = param.ds_tensor.ds_quant_scale.ds_numel + + scale_partitions[i].narrow(0, offset, + param_scale_numel).copy_(param.ds_tensor.ds_quant_scale.data) + + offset += param_scale_numel + + dist.all_gather_into_tensor(flat_tensor, + partitions[self.get_partition_rank()], group=self.get_partition_dp_group(param), async_op=False) - flat_tensor = self.quantizer_module.dequantize(flat_tensor, flat_scale_tensor) + if hasattr(param_list[0], 'ds_quant_scale'): + dist.all_gather(flat_scale_tensor, + param_list[0].ds_quant_scale, + group=self.get_partition_dp_group(param), + async_op=False) + param_offset = 0 + + for param in param_list: + param_partition_size = param.ds_tensor.ds_numel + param_size = param.ds_numel + replicated_tensor = torch.empty(param.ds_shape, dtype=param.ds_tensor.dtype, device=self.local_device) - param.data = flat_tensor.narrow(0, 0, param.ds_numel).view(param.ds_shape) + for i in range(self.num_partitions): - # guarantee the communication to be completed - if not get_accelerator().resolves_data_dependency(): - get_accelerator().synchronize() + start = i * partition_size + + param_start = i * param_partition_size + + if param_start < param_size: + numel_to_copy = min(param_size - param_start, param_partition_size) + + part_to_copy = partitions[i].narrow(0, param_offset, numel_to_copy) + + replicated_tensor.view(-1).narrow(0, param_start, numel_to_copy).copy_(part_to_copy) + #param_offset += param.data.numel() + param_offset += param.ds_tensor.ds_numel + if hasattr(param_list[0], 'ds_quant_scale'): + replicated_tensor = self.quantizer_module.dequantize(replicated_tensor, flat_scale_tensor) + param.data = replicated_tensor.data return None diff --git a/tests/unit/runtime/zero/test_zero_quant_bf16.py b/tests/unit/runtime/zero/test_zero_quant_bf16.py new file mode 100644 index 000000000000..47045865c400 --- /dev/null +++ b/tests/unit/runtime/zero/test_zero_quant_bf16.py @@ -0,0 +1,63 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch +import deepspeed +from unit.common import DistributedTest +from unit.simple_model import SimpleModel, random_dataloader + + +class TestZeroQuantBF16(DistributedTest): + world_size = 2 + + @pytest.mark.parametrize("zero_quantized_weights", [True]) + def test_bf16_quantized_weights(self, zero_quantized_weights): + if not deepspeed.get_accelerator().is_bf16_supported(): + pytest.skip("bf16 is not supported by this accelerator") + + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "zero_optimization": { + "stage": 3, + "zero_quantized_weights": zero_quantized_weights, + }, + "bf16": { + "enabled": True + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3 + } + } + } + + hidden_dim = 128 + model = SimpleModel(hidden_dim=hidden_dim) + model, _, _, _ = deepspeed.initialize(model=model, config=config_dict) + + # Ensure model is in bf16 + for param in model.parameters(): + assert param.dtype == torch.bfloat16 + + data_loader = random_dataloader(model=model, + total_samples=2, + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.bfloat16) + + for n, batch in enumerate(data_loader): + # This triggers all_gather and dequantization + loss = model(batch[0], batch[1]) + + # Verify that param.data is indeed bfloat16 after all_gather + for name, param in model.named_parameters(): + assert param.data.dtype == torch.bfloat16, \ + f"Parameter {name} data dtype is {param.data.dtype}, expected torch.bfloat16" + + model.backward(loss) + model.step() + break