@@ -341,6 +341,7 @@ def _insert_sharded_mamba(
         add_dist=False,
         min_local_shape=min_local_shape,
         fused_weight_dims=entry_fused_dims,
+        quantization_cb=quantization_cb,
     )
 
     # Get all weight nodes in the subgraph except for out_proj
@@ -573,6 +574,20 @@ class WeightShardingInfo(ShardingTransformInfo):
     # used for TP sharding of fused weights
     fused_weight_dims: Optional[list] = None
 
+    def quantization_cb(
+        self,
+        gm: GraphModule,
+        submod: nn.Module,
+        node: Node,
+        weight_key: str,
+        weight_new_shape: torch.Size,
+        dim: int,
+        rank: int,
+        world_size: int,
+    ) -> None:
+        """Quantization callback. Default does nothing for non-quantized models."""
+        return None
+
     @classmethod
     def from_node(cls, node: Node, **kwargs) -> "WeightShardingInfo":
         """
@@ -612,6 +627,7 @@ def apply(self, gm: GraphModule, node: Node) -> None:
                 fused_weight_dims=self.fused_weight_dims
                 if isinstance(self.fused_weight_dims, dict)
                 else None,
+                quantization_cb=self.quantization_cb,
             )
         else:
             _shard_parameter_node(
@@ -623,6 +639,7 @@ def apply(self, gm: GraphModule, node: Node) -> None:
                 add_dist=self.dist_op is not None,
                 min_local_shape=self.min_local_shape,
                 fused_weight_dims=self.fused_weight_dims,
+                quantization_cb=self.quantization_cb,
             )
 
 
@@ -741,18 +758,6 @@ def shard_load_hook(
     ) -> None:
         return
 
-    def apply(self, gm: GraphModule, node: Node) -> None:
-        _shard_parameter_node(
-            gm=gm,
-            node=node,
-            dim=self.split_dim.value,
-            rank=self.rank,
-            world_size=self.world_size,
-            add_dist=self.dist_op is not None,
-            min_local_shape=self.min_local_shape,
-            quantization_cb=self.quantization_cb,  # quant callback
-        )
-
 
 def _shard_fp4_weight_scale(weight_scale, sharded_uint8_weight_shape, dim, rank, world_size):
     assert weight_scale.dim() == 1
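Note: `_shard_fp4_weight_scale` itself is unchanged here and only its first lines are visible, but the 1-D assertion plus the `sharded_uint8_weight_shape` argument suggest the flattened block-scale vector is mapped onto the weight's block grid and sliced to match the local shard. A sketch under assumed layout (one scale per block, row-major grid, row sharding at dim=0; none of this is confirmed by the diff):

```python
import torch


def _shard_fp4_scale_sketch(weight_scale, sharded_uint8_weight_shape, rank, world_size):
    # Assumption: one scale per block, laid out row-major over the full weight.
    local_rows = sharded_uint8_weight_shape[0]           # rows this rank keeps
    full_rows = local_rows * world_size                  # rows before sharding
    blocks_per_row = weight_scale.numel() // full_rows
    grid = weight_scale.view(full_rows, blocks_per_row)  # unflatten to block grid
    return grid[rank * local_rows : (rank + 1) * local_rows].reshape(-1)
```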
@@ -809,18 +814,6 @@ def shard_load_hook(
             state_dict[key], weight_shape, dim, rank, world_size
         )
 
-    def apply(self, gm: GraphModule, node: Node) -> None:
-        _shard_parameter_node(
-            gm=gm,
-            node=node,
-            dim=self.split_dim.value,
-            rank=self.rank,
-            world_size=self.world_size,
-            add_dist=self.dist_op is not None,
-            min_local_shape=self.min_local_shape,
-            quantization_cb=self.quantization_cb,  # quant callback
-        )
-
 
 TP_SHARDING_RULES = [
     (lambda n: is_op(n, torch.ops.auto_deploy.torch_fake_quant_fp8_linear), FP8TPShardingInfo),
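Note: the two deleted `apply` overrides above were verbatim copies whose only purpose was to pass `quantization_cb`. Now that `WeightShardingInfo.apply` forwards `quantization_cb=self.quantization_cb` itself, the FP8/FP4 classes keep only their quantization-specific hooks. A sketch of the resulting dispatch; the driver function and the kwargs handed to `from_node` are assumptions for illustration, not taken from this diff:

```python
from torch.fx import GraphModule


def apply_tp_sharding(gm: GraphModule, rank: int, world_size: int) -> None:
    """Illustrative driver; the real transform entry point is not shown in this diff."""
    for n in gm.graph.nodes:
        for matches, info_cls in TP_SHARDING_RULES:
            if matches(n):
                # The inherited WeightShardingInfo.apply now forwards the
                # subclass's quantization_cb, so no per-class override is needed.
                info = info_cls.from_node(n, rank=rank, world_size=world_size)
                info.apply(gm, n)
                break
```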