
Commit 4b3cde4

greg-kwasniewski1 authored and lucaslie committed
Fixed quantized sharding
Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
1 parent 3f3c4a1 commit 4b3cde4

File tree
1 file changed: +17 −24 lines


tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py

Lines changed: 17 additions & 24 deletions
@@ -341,6 +341,7 @@ def _insert_sharded_mamba(
         add_dist=False,
         min_local_shape=min_local_shape,
         fused_weight_dims=entry_fused_dims,
+        quantization_cb=quantization_cb,
     )
 
     # Get all weight nodes in the subgraph except for out_proj
@@ -573,6 +574,20 @@ class WeightShardingInfo(ShardingTransformInfo):
     # used for TP sharding of fused weights
     fused_weight_dims: Optional[list] = None
 
+    def quantization_cb(
+        self,
+        gm: GraphModule,
+        submod: nn.Module,
+        node: Node,
+        weight_key: str,
+        weight_new_shape: torch.Size,
+        dim: int,
+        rank: int,
+        world_size: int,
+    ) -> None:
+        """Quantization callback. Default does nothing for non-quantized models."""
+        return None
+
     @classmethod
     def from_node(cls, node: Node, **kwargs) -> "WeightShardingInfo":
         """
@@ -612,6 +627,7 @@ def apply(self, gm: GraphModule, node: Node) -> None:
                 fused_weight_dims=self.fused_weight_dims
                 if isinstance(self.fused_weight_dims, dict)
                 else None,
+                quantization_cb=self.quantization_cb,
             )
         else:
             _shard_parameter_node(
@@ -623,6 +639,7 @@ def apply(self, gm: GraphModule, node: Node) -> None:
                 add_dist=self.dist_op is not None,
                 min_local_shape=self.min_local_shape,
                 fused_weight_dims=self.fused_weight_dims,
+                quantization_cb=self.quantization_cb,
             )
 
 
@@ -741,18 +758,6 @@ def shard_load_hook(
     ) -> None:
         return
 
-    def apply(self, gm: GraphModule, node: Node) -> None:
-        _shard_parameter_node(
-            gm=gm,
-            node=node,
-            dim=self.split_dim.value,
-            rank=self.rank,
-            world_size=self.world_size,
-            add_dist=self.dist_op is not None,
-            min_local_shape=self.min_local_shape,
-            quantization_cb=self.quantization_cb,  # quant callback
-        )
-
 
 def _shard_fp4_weight_scale(weight_scale, sharded_uint8_weight_shape, dim, rank, world_size):
     assert weight_scale.dim() == 1
@@ -809,18 +814,6 @@ def shard_load_hook(
             state_dict[key], weight_shape, dim, rank, world_size
         )
 
-    def apply(self, gm: GraphModule, node: Node) -> None:
-        _shard_parameter_node(
-            gm=gm,
-            node=node,
-            dim=self.split_dim.value,
-            rank=self.rank,
-            world_size=self.world_size,
-            add_dist=self.dist_op is not None,
-            min_local_shape=self.min_local_shape,
-            quantization_cb=self.quantization_cb,  # quant callback
-        )
-
 
 TP_SHARDING_RULES = [
     (lambda n: is_op(n, torch.ops.auto_deploy.torch_fake_quant_fp8_linear), FP8TPShardingInfo),
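
For context on why the two apply overrides could be deleted: the shared apply shown in the hunks above now forwards self.quantization_cb into _shard_parameter_node, so subclasses only need to customize the hook. A self-contained toy analogue of that pattern, with illustrative names that are not from the library:

from typing import Callable, Optional

def shard_worker(name: str, quantization_cb: Optional[Callable[[str], None]] = None) -> None:
    # Stand-in for _shard_parameter_node: do the shared sharding work, then call the hook.
    print(f"sharding {name}")
    if quantization_cb is not None:
        quantization_cb(name)

class BaseInfo:
    def quantization_cb(self, name: str) -> None:
        # Default hook: no-op, mirroring the new WeightShardingInfo.quantization_cb.
        return None

    def apply(self, name: str) -> None:
        # One shared apply(); only the hook varies per subclass.
        shard_worker(name, quantization_cb=self.quantization_cb)

class Fp8Info(BaseInfo):
    def quantization_cb(self, name: str) -> None:
        print(f"resharding FP8 scales for {name}")

BaseInfo().apply("linear.weight")  # sharding only
Fp8Info().apply("linear.weight")   # sharding, then scale resharding via the hook
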

0 commit comments