Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions test/prototype/moe_training/mxfp8/test_mxfp8_a2a.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def test_a2a_fwd_bwd(self):
tokens_per_ep_rank,
dim,
device=self.device,
dtype=torch.float32,
dtype=torch.bfloat16,
requires_grad=True,
)
ref_input_tensor = input_tensor.detach().clone().requires_grad_(True)
Expand Down Expand Up @@ -107,7 +107,7 @@ def test_a2a_fwd_bwd(self):
total_tokens_on_rank_after_a2a,
dim,
device=self.device,
dtype=torch.float32,
dtype=torch.bfloat16,
)

# Do the actual all_to_all_single
Expand Down Expand Up @@ -188,7 +188,7 @@ def test_a2a_fwd_bwd(self):
tokens_per_ep_rank,
dim,
device=self.device,
dtype=torch.float32,
dtype=torch.bfloat16,
requires_grad=True,
)
ref_input_tensor = input_tensor.detach().clone().requires_grad_(True)
Expand Down
6 changes: 3 additions & 3 deletions torchao/prototype/moe_training/kernels/mxfp8/quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def compute_blocked_scale_offsets_for_M_groups(offsets: torch.Tensor):
- starting_row_after_padding: 1D integer tensor representing the starting row of each group after padding each group to blocked format.
"""
# Calculate group sizes
zero = torch.tensor([0], dtype=offsets.dtype, device=offsets.device)
zero = torch.zeros(1, dtype=offsets.dtype, device=offsets.device)
group_sizes = torch.diff(offsets, prepend=zero)

# Round each group size up to the nearest multiple of 128
Expand Down Expand Up @@ -203,8 +203,8 @@ def compute_blocked_scale_offsets_for_K_groups(
- starting_col_after_padding: 1D integer tensor representing the starting column of each group after padding each group to blocked format.
"""
# Calculate group sizes
zero = torch.tensor(
[0], dtype=scale_group_offsets.dtype, device=scale_group_offsets.device
zero = torch.zeros(
1, dtype=scale_group_offsets.dtype, device=scale_group_offsets.device
)
group_sizes = torch.diff(scale_group_offsets, prepend=zero)

Expand Down
Loading