4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/auto_deploy/config/default.yaml
@@ -77,8 +77,8 @@ transforms:
   detect_sharding:
     stage: sharding
     simple_shard_only: false
-    use_sharding_from_factory: false
-    support_partial_config: false
+    sharding_source: ['heuristic']
+    support_partial_config: true
     sharding_dims: ['tp', 'ep', 'bmm']
     requires_shape_prop: true
     # TODO: (hg) need to ensure run_shape_prop after sharding.
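The list-valued `sharding_source` replaces the old `use_sharding_from_factory` flag, and `support_partial_config` is now enabled by default. A minimal sketch (not part of the change) of reading the updated keys back with PyYAML, assuming a repo-relative path:

import yaml

with open("tensorrt_llm/_torch/auto_deploy/config/default.yaml") as f:
    cfg = yaml.safe_load(f)

detect_sharding = cfg["transforms"]["detect_sharding"]
print(detect_sharding["sharding_source"])         # ['heuristic']
print(detect_sharding["support_partial_config"])  # True
print(detect_sharding["sharding_dims"])           # ['tp', 'ep', 'bmm']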
30 changes: 29 additions & 1 deletion tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py
@@ -7,7 +7,7 @@
 import torch
 import torch.nn.functional as F
 from einops import rearrange
-from transformers import AutoModelForCausalLM
+from transformers import AutoConfig, AutoModelForCausalLM

 from tensorrt_llm._torch.auto_deploy.models.patches.bamba import _bamba_mixer_torch_forward

@@ -144,6 +144,34 @@ def get_model_from_config_patched(config, **kwargs):
 # TODO: figure out how this can be incorporated into the export patch system
 AutoModelForCausalLM.from_config = get_model_from_config_patched

+_config_from_pretrained_original = AutoConfig.from_pretrained
+_nemotron_h_base_model_tp_plan = {
+    "in_proj": "mamba",
+    "out_proj": "rowwise",
+    "q_proj": "colwise",
+    "k_proj": "colwise",
+    "v_proj": "colwise",
+    "o_proj": "rowwise",
+    "up_proj": "colwise",
+    "down_proj": "rowwise",
+    # "*": "gather",
+}
+
+
+def get_config_from_pretrained_patched(*args, **kwargs):
+    ret = _config_from_pretrained_original(*args, **kwargs)
+    config = ret[0] if isinstance(ret, tuple) else ret
+    # heuristic to check if it's a NemotronH MoE Model
+    model_type = getattr(config, "model_type", None)
+    num_moe_layers = getattr(config, "layers_block_type", []).count("moe")
+    if model_type == "nemotron_h" and num_moe_layers > 0:
+        config.base_model_tp_plan = _nemotron_h_base_model_tp_plan
+    return (config, *ret[1:]) if isinstance(ret, tuple) else config
+
+
+# TODO: figure out how this can be incorporated into the export patch system
+AutoConfig.from_pretrained = get_config_from_pretrained_patched
+
 # TODO: figure out how this can be incorporated into the export patch system
 # Only patch if the module isn't available
 _mamba_ssm_module = "mamba_ssm"
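The new patch wraps `AutoConfig.from_pretrained` so NemotronH MoE configs pick up the tensor-parallel plan defined above. A hedged usage sketch; the checkpoint name is a placeholder, and the patch only applies once this patches module has been imported:

from transformers import AutoConfig

# Importing the patch module applies the AutoConfig patch as a side effect.
import tensorrt_llm._torch.auto_deploy.models.patches.nemotron_h  # noqa: F401

config = AutoConfig.from_pretrained("<nemotron-h-moe-checkpoint>", trust_remote_code=True)
if getattr(config, "model_type", None) == "nemotron_h":
    # For MoE variants (layers_block_type contains "moe") the TP plan is attached.
    print(getattr(config, "base_model_tp_plan", None))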
4 changes: 4 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/transform/interface.py
@@ -173,6 +173,10 @@ def __and__(self, other: "TransformInfo") -> "TransformInfo":
             has_valid_shapes=self.has_valid_shapes and other.has_valid_shapes,
         )

+    # implement + addition operator for TransformInfo
+    def __add__(self, other: "TransformInfo") -> "TransformInfo":
+        return self.__and__(other)
+

 TransformHistory = Dict[str, TransformInfo]

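With `__add__` delegating to `__and__`, a sequence of per-transform infos can be folded with `+`. A minimal stand-in sketch (the real TransformInfo has more fields than this diff shows):

from dataclasses import dataclass
from functools import reduce


@dataclass(frozen=True)
class InfoStandIn:
    has_valid_shapes: bool

    def __and__(self, other: "InfoStandIn") -> "InfoStandIn":
        return InfoStandIn(self.has_valid_shapes and other.has_valid_shapes)

    def __add__(self, other: "InfoStandIn") -> "InfoStandIn":
        return self.__and__(other)


infos = [InfoStandIn(True), InfoStandIn(False), InfoStandIn(True)]
combined = reduce(lambda a, b: a + b, infos)  # equivalent to chaining `&`
print(combined.has_valid_shapes)  # False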
Changes in an additional transform library file (file path not captured in this view):
@@ -317,7 +317,7 @@ def _find_final_hidden_state_node(
     if not (hasattr(mul_node, "args") and len(mul_node.args) >= 2):
         return None
     index_node = mul_node.args[1]
-    index_add_node = bfs(
+    index_add_node, _ = bfs(
         index_node, lambda n: is_op(n, torch.ops.aten.index_add_), boundary=end_boundary
     )
     if not index_add_node:
@@ -383,7 +383,7 @@ def target(n: torch.fx.Node) -> bool:
         return is_op(n, {torch.ops.aten.index_add_}) and len(n.users) == 0

     try:
-        node_to_remove = bfs(start_boundary, target, attr_next="users", boundary=end_boundary)
+        node_to_remove, _ = bfs(start_boundary, target, attr_next="users", boundary=end_boundary)
         graph.erase_node(node_to_remove)
         return True
     except RuntimeError:
@@ -458,7 +458,7 @@ def _apply(
                 lambda node: is_op(node, torch.ops.aten.one_hot),
                 attr_next="all_input_nodes",
                 boundary=start_boundary,
-            ).args[0]
+            )[0].args[0]
             if not selected_experts:
                 continue

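All three call sites now unpack a 2-tuple from `bfs(...)`, so the node_utils helper evidently returns the matched node plus a second value not shown in this diff. A sketch of a BFS helper with that calling convention, assuming the second element is a traversal depth (an assumption, not the library's actual return value):

from collections import deque
from typing import Callable, Optional, Tuple

import torch.fx


def bfs_sketch(
    start: torch.fx.Node,
    target: Callable[[torch.fx.Node], bool],
    attr_next: str = "users",
    boundary: Optional[torch.fx.Node] = None,
) -> Tuple[torch.fx.Node, int]:
    # Breadth-first walk over the FX graph following `attr_next` edges.
    queue = deque([(start, 0)])
    visited = {start}
    while queue:
        node, depth = queue.popleft()
        if target(node):
            return node, depth
        if node is boundary:
            continue
        for nxt in getattr(node, attr_next):
            if nxt not in visited:
                visited.add(nxt)
                queue.append((nxt, depth + 1))
    raise RuntimeError("no matching node found")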
6 changes: 3 additions & 3 deletions tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py
@@ -13,7 +13,7 @@
 from ...shim.interface import CachedSequenceInterface
 from ...utils.cuda_mem_tracker import cuda_memory_tracker
 from ...utils.logger import ad_logger
-from ...utils.node_utils import extract_param_names_from_lin_node, is_linear_op, is_op
+from ...utils.node_utils import extract_param_names_from_node, is_linear_op, is_op
 from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry


@@ -36,7 +36,7 @@ def _insert_fused_gemm(gm: GraphModule, idx: int, parent_node: Node, linear_node
         y2 = y[:, out1:out1+out2]
     """
     # some info we need
-    keys_unfused = [extract_param_names_from_lin_node(n)[0] for n in linear_nodes]
+    keys_unfused = [extract_param_names_from_node(n)[0] for n in linear_nodes]
     params_unfused = [gm.get_parameter(k) for k in keys_unfused]
     sizes_unfused = [p.size(0) for p in params_unfused]
     key_fused = f"fused_weight_{idx}"
@@ -128,7 +128,7 @@ def build_custom_args_for_linear(self, scale_getattrs: Dict[str, Node]) -> Tuple
     def _insert_fused_quant_gemm(
         self, gm: GraphModule, idx: int, parent_node: Node, linear_nodes: List[Node]
     ):
-        keys_unfused = [extract_param_names_from_lin_node(n)[0] for n in linear_nodes]
+        keys_unfused = [extract_param_names_from_node(n)[0] for n in linear_nodes]
         params_unfused = [gm.get_parameter(k) for k in keys_unfused]
         sizes_unfused = [p.size(0) for p in params_unfused]
         key_fused = f"fused_weight_{idx}"
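`extract_param_names_from_node` replaces the linear-specific helper, but the fusion itself is unchanged: stack the unfused weights along the output dimension, run one GEMM, and split the result as in the docstring above. A plain-torch sketch with illustrative shapes:

import torch

w1 = torch.randn(128, 512)  # first linear weight: out1 x in
w2 = torch.randn(256, 512)  # second linear weight: out2 x in
x = torch.randn(4, 512)

fused_weight = torch.cat([w1, w2], dim=0)  # (out1 + out2) x in
y = x @ fused_weight.t()                   # one fused GEMM instead of two
y1, y2 = torch.split(y, [w1.size(0), w2.size(0)], dim=1)

assert torch.allclose(y1, x @ w1.t(), atol=1e-4)
assert torch.allclose(y2, x @ w2.t(), atol=1e-4)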
Changes in an additional file (file path not captured in this view):
@@ -14,7 +14,7 @@
 from ...models.factory import ModelFactory
 from ...shim.interface import CachedSequenceInterface
 from ...utils.node_utils import (
-    extract_param_names_from_lin_node,
+    extract_param_names_from_node,
     get_quantization_params_from_linear_node,
     is_bmm_op,
     is_linear_op,
@@ -136,7 +136,7 @@ def _insert_quantized_linear(

         The state_dict is also updated to contain the sharded weights.
         """
-        param_name, _ = extract_param_names_from_lin_node(node)
+        param_name, _ = extract_param_names_from_node(node)
         original_weight = gm.get_parameter(param_name)
         new_param = nn.Parameter(self.quantize_weight(original_weight), requires_grad=False)
         modname, _, attrname = param_name.rpartition(".")
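The renamed helper returns the dotted parameter name, which the pass splits with `rpartition` to swap the quantized parameter onto its owning submodule. A standalone sketch of that swap pattern (the cast to float16 stands in for `self.quantize_weight`):

import torch
import torch.nn as nn

gm = nn.Sequential(nn.Linear(16, 32))  # stand-in for the traced GraphModule
param_name = "0.weight"                # dotted name, as returned for a linear node

original_weight = gm.get_parameter(param_name)
new_param = nn.Parameter(original_weight.to(torch.float16), requires_grad=False)

modname, _, attrname = param_name.rpartition(".")
setattr(gm.get_submodule(modname), attrname, new_param)

assert gm.get_parameter(param_name).dtype == torch.float16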