Skip to content

Commit dec85f5

Browse files
committed
Apply ruff formatting fix
Signed-off-by: Benji Beck <benjibeck@meta.com>
1 parent 450d4f6 commit dec85f5

File tree

3 files changed

+27
-18
lines changed

3 files changed

+27
-18
lines changed

test/quantization/quantize_/workflows/float8/test_float8_semi_sparse.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import unittest
8-
from torchao.quantization.quantize_.workflows.float8.float8_semi_sparse_tensor import Float8SemiSparseTensor
8+
from torchao.quantization.quantize_.workflows.float8.float8_semi_sparse_tensor import (
9+
Float8SemiSparseTensor,
10+
)
911
from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor
1012
from torchao.float8.inference import Float8MMConfig
1113
import torch
@@ -44,20 +46,25 @@ def test_sparse_vs_dense_fp8(self, sizes):
4446
linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
4547

4648
apply_fake_sparsity(linear)
47-
49+
4850
mm_config = Float8MMConfig(use_fast_accum=True)
49-
input_fp8 = Float8Tensor.from_hp(input, float8_dtype=torch.float8_e4m3fn, mm_config=mm_config)
50-
51-
weight_fp8 = Float8Tensor.from_hp(linear.weight.data, float8_dtype=torch.float8_e4m3fn, mm_config=mm_config)
51+
input_fp8 = Float8Tensor.from_hp(
52+
input, float8_dtype=torch.float8_e4m3fn, mm_config=mm_config
53+
)
54+
55+
weight_fp8 = Float8Tensor.from_hp(
56+
linear.weight.data, float8_dtype=torch.float8_e4m3fn, mm_config=mm_config
57+
)
5258
dense_output = torch.nn.functional.linear(input_fp8, weight_fp8, linear.bias)
53-
59+
5460
weight_sparse_fp8 = Float8SemiSparseTensor.from_hp(linear.weight.data, [1, K])
55-
sparse_output = torch.nn.functional.linear(input_fp8, weight_sparse_fp8, linear.bias)
56-
57-
torch.testing.assert_close(
58-
dense_output, sparse_output, atol=3e-1, rtol=3e-1
61+
sparse_output = torch.nn.functional.linear(
62+
input_fp8, weight_sparse_fp8, linear.bias
5963
)
6064

65+
torch.testing.assert_close(dense_output, sparse_output, atol=3e-1, rtol=3e-1)
66+
67+
6168
instantiate_parametrized_tests(TestFloat8SemiSparseTensor)
6269

6370

torchao/quantization/quantize_/workflows/float8/float8_semi_sparse_tensor.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def from_hp(
8686
@implements(aten.t.default)
8787
def _(func, types, args, kwargs):
8888
from torch.utils._python_dispatch import return_and_correct_aliasing
89-
89+
9090
self = args[0]
9191
new = Float8SemiSparseTensor(
9292
sparse=self.sparse,
@@ -98,8 +98,10 @@ def _(func, types, args, kwargs):
9898

9999
def _linear_fp8_semi_sparse(input_tensor, weight_tensor, bias):
100100
from torchao.ops import rowwise_scaled_linear_sparse_cutlass_f8f8
101-
from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor
102-
101+
from torchao.quantization.quantize_.workflows.float8.float8_tensor import (
102+
Float8Tensor,
103+
)
104+
103105
if isinstance(input_tensor, Float8Tensor):
104106
input = input_tensor.qdata
105107
input_scale = input_tensor.scale
@@ -108,16 +110,16 @@ def _linear_fp8_semi_sparse(input_tensor, weight_tensor, bias):
108110
input = input_tensor.qdata
109111
input_scale = input_tensor.scale
110112
out_dtype = input_tensor.dtype
111-
113+
112114
weight = weight_tensor.sparse
113115
weight_meta = weight_tensor.meta
114116
weight_scale = weight_tensor.scale
115-
117+
116118
# Reshape input_scale if needed: kernel expects scale to match input shape minus last dim
117119
# For input [B, K], scale should be [B] not [B, 1]
118120
if input_scale.dim() > input.dim() - 1:
119121
input_scale = input_scale.squeeze(-1)
120-
122+
121123
return rowwise_scaled_linear_sparse_cutlass_f8f8(
122124
input, input_scale, weight, weight_meta, weight_scale, bias, out_dtype
123125
)
@@ -130,7 +132,7 @@ def _(func, types, args, kwargs):
130132
else: # aten.mm.default
131133
input_tensor, weight_tensor = args
132134
bias = None
133-
135+
134136
return _linear_fp8_semi_sparse(input_tensor, weight_tensor, bias)
135137

136138

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def _(func, types, args, kwargs):
256256
args[1],
257257
args[2] if len(args) > 2 else None,
258258
)
259-
259+
260260
# If weight is not Float8Tensor, return NotImplemented to allow weight's dispatch to handle it
261261
if not isinstance(weight_tensor, Float8Tensor):
262262
return NotImplemented

0 commit comments

Comments
 (0)