Skip to content

Commit 3f3c4a1

Browse files
greg-kwasniewski1 authored and lucaslie committed
Fixed QKVO col-row sharding
Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
1 parent 92b808b commit 3f3c4a1

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -148,9 +148,13 @@ def get_model_from_config_patched(config, **kwargs):
148148
_nemotron_h_base_model_tp_plan = {
149149
"in_proj": "mamba",
150150
"out_proj": "rowwise",
151+
"q_proj": "colwise",
152+
"k_proj": "colwise",
153+
"v_proj": "colwise",
154+
"o_proj": "rowwise",
151155
"up_proj": "colwise",
152156
"down_proj": "rowwise",
153-
"*": "gather",
157+
# "*": "gather",
154158
}
155159

156160

tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -83,8 +83,8 @@ def _validate_sharded_shapes(
8383
next_lin_node, _ = bfs(node, is_any_lin_op, include_root=False)
8484
nodes_to_validate = subgraph(
8585
[node],
86-
[next_lin_node],
8786
include=lambda n: is_op(n, [torch.ops.aten.view, torch.ops.aten.reshape]),
87+
boundary_condition=is_any_lin_op,
8888
)
8989
for view_node in nodes_to_validate:
9090
if len(view_node.args) < 2:
@@ -96,7 +96,7 @@ def _validate_sharded_shapes(
9696
continue
9797
if len(view_shape) >= 3 and isinstance(view_shape[2], int) and view_shape[2] != -1:
9898
args = list(view_node.args)
99-
view_shape[2] = view_shape[2] // world_size
99+
view_shape[2] = -1 # view_shape[2] // world_size
100100
args[1] = tuple(view_shape)
101101
view_node.args = tuple(args)
102102
view_node.meta["sharded"] = True

0 commit comments

Comments
 (0)