[WIP] Move float8 cutlass sparse layout to Float8SemiSparseTensor #3182
base: main
Conversation
Summary: Moving float8 cutlass sparse layout into its own class: https://github.com/pytorch/ao/blob/main/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py
Differential Revision: D84467190
Signed-off-by: Benji Beck <benjibeck@meta.com>
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3182
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure as of commit 960bc91 with merge base 30082cb. The following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
    needed for the rest of the system to understand the specific format that's adopted.
    """
    OPAQUE = "opaque"
    # todo: add semi-sparse
@jerryzh168 It seems we may want to add a packing format for sparse. Wondering if there's a preference between adding it here or in a separate file (similar to int4) for float8?
Do we need a packing format if we have a separate config? It looks like the packing format is mostly there to support the different Int4WeightOnlyConfig kernel options (tinygemm, sparse marlin, etc.).
Good point. I noticed that we seem to replace the dense weight with the quantized semi-sparse weight in the transform. Would it make more sense to integrate Float8SemiSparseTensor there rather than gating on packing format as I proposed previously? cc @jerryzh168
Yup I think that the transform should call your subclass.
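For concreteness, a minimal sketch of what that could look like, assuming a simple per-module transform and the `from_hp(weight, [1, K])` constructor used in this PR's tests (the function name and config argument are hypothetical):

```python
import torch
from torch import nn

# Hypothetical transform: replace the dense weight of an nn.Linear with the
# Float8SemiSparseTensor subclass added in this PR (the import path depends
# on where the class lands in torchao).
def _float8_semi_sparse_transform(module: nn.Module, config) -> nn.Module:
    assert isinstance(module, nn.Linear)
    k = module.weight.shape[1]
    sparse_weight = Float8SemiSparseTensor.from_hp(module.weight.detach(), [1, k])
    module.weight = nn.Parameter(sparse_weight, requires_grad=False)
    return module
```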
    from torchao.testing.utils import skip_if_rocm
    from torchao.utils import torch_version_at_least

    BF16_ACT_CONFIG = Float8WeightOnlyConfig(
I don't think this config makes sense; it's not something we support. From what I understand this is a bf16 activation + fp8 sparse weight? We only have kernel support for fp8 x fp8 + 2:4 sparse matmul; there is no support for mixed input dtypes currently.
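For reference, the combination that does have kernel support is fp8 dynamic activations with an fp8 2:4 sparse weight; a minimal sketch of exercising that path (import paths assumed):

```python
import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8SemiSparseWeightConfig,
    quantize_,
)

# fp8 activation x fp8 2:4 sparse weight is the supported matmul path;
# a bf16 activation with an fp8 sparse weight has no kernel today.
model = torch.nn.Sequential(torch.nn.Linear(128, 256)).cuda().to(torch.bfloat16)
quantize_(model, Float8DynamicActivationFloat8SemiSparseWeightConfig())
```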
You're right; it seems I should be mirroring test_fp8_cutlass_sparse (from test_sparse_api.py) instead, with the difference being that it uses the new flag/config which exposes the tensor subclass being added?
I think Float8DynamicActivationFloat8SemiSparseWeightConfig should eventually resolve to your subclass.
But I would like to hold off on that until
- We're out of the QRT period (cc @RandySheriff just FYI on refactor plans)
- We have the same functionality (addmm/mm support)
Sounds good. In that case, should we sequence the changes as follows?
- Land Float8SemiSparseTensor with linear support
- Add mm / addmm ops for feature parity
- Integrate into Float8DynamicActivationFloat8SemiSparseWeightConfig after QRT
@jcaip @jerryzh168 Mind confirming that you're onboard with this direction?
Yeah, that sounds good to me.
    implements_torch_function = Float8SemiSparseTensor.implements_torch_function

    @implements(aten.linear.default)
We'll also need to make sure mm and addmm are supported ops. The arg order is different from linear, but it should be the same logic.
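Roughly, only the way the operands arrive differs; a sketch, assuming the same registration helper as the linear implementation in this diff and a shared `_sparse_fp8_linear_impl` helper (hypothetical name):

```python
# aten.linear.default(input, weight, bias)  -> input @ weight.T + bias
# aten.mm.default(input, weight)            -> input @ weight
# aten.addmm.default(bias, input, weight)   -> bias + input @ weight
@implements([aten.mm.default, aten.addmm.default])
def _(func, types, args, kwargs):
    if func is aten.addmm.default:
        bias, input_tensor, weight_tensor = args[0], args[1], args[2]
    else:  # aten.mm.default
        input_tensor, weight_tensor = args[0], args[1]
        bias = None
    # from here on, reuse the same core logic as the linear implementation
    return _sparse_fp8_linear_impl(input_tensor, weight_tensor, bias)
```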
Sounds good, I'm onboard with that. Mind if I add those ops in a follow-up diff after this lands?
Yup that's fine with me :)
            ((2, 32, 128), 256, 128),
        ],
    )
    def test_sparse_vs_dense_fp8(self, sizes):
@jcaip Updated testing to follow the style of [test_sparse_api.py](https://fburl.com/18n157bf). For now I'm omitting any config-related changes until QRT; however, this diff does include all ops (linear, addmm, mm) so that integration can be done as a follow-up.
Could I get feedback on the construction of the implementations and test before adding similar ones for addmm and mm, and before polishing?
cc @bbeckca I think the implementation and test look good, left a couple of comments.
    )
    dense_output = torch.nn.functional.linear(input_fp8, weight_fp8, linear.bias)

    weight_sparse_fp8 = Float8SemiSparseTensor.from_hp(linear.weight.data, [1, K])
nit: .detach() instead of .data?
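i.e. the same call as the quoted line above, just without touching `.data`:

```python
weight_sparse_fp8 = Float8SemiSparseTensor.from_hp(linear.weight.detach(), [1, K])
```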
    class Float8SemiSparseTensor(TorchAOBaseTensor):
        tensor_data_names = ["sparse", "meta", "scale"]
nit: we should use [ compressed_values, metadata ] instead of sparse and meta here.
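i.e., assuming the rest of the class stays the same:

```python
class Float8SemiSparseTensor(TorchAOBaseTensor):
    # descriptive names for the CUTLASS 2:4 compressed representation
    tensor_data_names = ["compressed_values", "metadata", "scale"]
```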
    implements_torch_function = Float8SemiSparseTensor.implements_torch_function

    @implements(aten.t.default)
Why do you have to implement transpose? Transpose on a sparse matrix is kind of tricky; we should probably throw a ValueError if it's called.
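Something along these lines, assuming the same registration helper used elsewhere in this file:

```python
@implements(aten.t.default)
def _(func, types, args, kwargs):
    # Transposing the CUTLASS-compressed 2:4 representation is non-trivial,
    # so fail loudly instead of returning a silently wrong layout.
    raise ValueError(
        "transpose is not supported on Float8SemiSparseTensor; "
        "dequantize to a dense tensor first"
    )
```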
        Float8Tensor,
    )

    if isinstance(input_tensor, Float8Tensor):
nit: what does this conditional do?
    else:  # aten.mm.default
        input_tensor, weight_tensor = args
        bias = None
I think you may need to do some transpose trickery here to support mm and addmm. My understanding is that linear(x, w, bias) returns x @ w.T + bias, so for mm / addmm you need to pass in a transposed weight.
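A small dense-tensor illustration of that identity (plain torch, no subclass involved):

```python
import torch

x = torch.randn(4, 8)
w = torch.randn(16, 8)  # nn.Linear stores weight as [out_features, in_features]
b = torch.randn(16)

out_linear = torch.nn.functional.linear(x, w, b)  # x @ w.T + b
out_addmm = torch.addmm(b, x, w.t())              # addmm(bias, mat1, mat2) = bias + mat1 @ mat2
out_mm = torch.mm(x, w.t())                       # mat1 @ mat2, no bias term

assert torch.allclose(out_linear, out_addmm, atol=1e-6)
assert torch.allclose(out_linear - b, out_mm, atol=1e-6)
```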
cc @jerryzh168 @jcaip @danielvegamyhre