Skip to content

Conversation

@lucaslie
Copy link
Member

@lucaslie lucaslie commented Oct 29, 2025

Summary by CodeRabbit

Release Notes

  • New Features

    • Added support for NemotronH mixture-of-experts models for inference and export
    • Implemented heuristic-based weight sharding detection as an alternative to factory-based configuration
    • Added streaming state machine (SSM) layer sharding capability
    • Enhanced graph transformation composition with additional operator support
  • Improvements

    • Refined parameter update handling during model sharding operations
    • Extended utility functions for broader model architecture compatibility

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Oct 29, 2025

📝 Walkthrough

Walkthrough

This pull request refactors the auto-deployment sharding system from TP-centric to weight-focused approach. Key changes include: replacing TPShardingInfo with WeightShardingInfo, introducing ShardingSource and ShardingDim enums to control sharding strategy selection, generalizing weight extraction utilities (extract_weight_node, extract_param_names_from_node), adding helper functions for graph traversal and SSM/column-based sharding detection, and updating configuration structures to support both heuristic and factory-based sharding sources.

Changes

Cohort / File(s) Change Summary
Configuration
tensorrt_llm/_torch/auto_deploy/config/default.yaml
Replaced use_sharding_from_factory: false and support_partial_config: false with sharding_source: ['heuristic'] and support_partial_config: true in detect_sharding stage; enables heuristic-based sharding.
Model Patching
tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py
Added patch layer for AutoConfig.from_pretrained to detect NemotronH MoE models and inject base model topology plan (_nemotron_h_base_model_tp_plan); wraps original method with guard checks and augments config.
Transform Interface
tensorrt_llm/_torch/auto_deploy/transform/interface.py
Added __add__ operator to TransformInfo class that delegates to existing __and__ implementation, enabling operator overloading for transform merging.
Sharding Core Library
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
Major refactor: replaced TPShardingInfo with WeightShardingInfo; updated ShardingTransformConfig to use ShardingSource and ShardingDim enums; added detect_ssm_shard, _process_simple_shard, _process_ssm_sharding, _process_column_sharding functions; refactored detect_sharding_from_factory_config to emit WeightShardingInfo transforms; updated ShardingTransformExecutor to process weight_sharding_transforms and parameter_update_transforms.
Transform Library
tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py
Updated BFS calls to handle tuple return values: changed assignments to destructure (result, _) pattern in _find_final_hidden_state_node, _remove_dead_inplace_nodes_in_region, and MatchMoePattern._apply.
GEMM Fusion
tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py
Replaced calls to extract_param_names_from_lin_node with extract_param_names_from_node in _insert_fused_gemm and QuantizationFusionMixin._insert_fused_quant_gemm.
Quantization Transform
tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py
Updated imports and usage of extract_param_names_from_lin_node to extract_param_names_from_node in _insert_quantized_linear function.
Node Utilities
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
Comprehensive refactor: renamed extract_weight_node_mm_node to extract_weight_node(node), extract_param_names_from_lin_node to extract_param_names_from_node(node); extended bfs signature to return Tuple[Node, int] with depth; added get_all_layer_subgraphs, draw_graph, get_layer_after_linear_node, subgraph, is_any_lin_op, num_users_of_weight_node utility functions; generalized weight/parametrized node handling.
Quantization Utilities
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
Updated import and call to extract_param_names_from_node (renamed from extract_param_names_from_lin_node) in should_skip_quantization.
Sharding Utilities
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
Major overhaul: introduced WeightShardingInfo, ParameterUpdateInfo, LayerType, ShardingSource, ShardingDim enums; updated ShardingConfig to use weight_sharding_transforms and parameter_update_transforms; added shard_weight_tensor unified entry point; added _insert_sharded_mamba for Mamba-specific sharding; introduced _update_node_args utility; refactored TP/EP sharding classes to use new unified APIs; updated shape validation for sharded views.
TP Sharding Tests
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
Replaced all TPShardingInfo instantiations with WeightShardingInfo across test expectations; updated pattern detection to read from weight_sharding_transforms instead of tp_transforms.

Sequence Diagram(s)

sequenceDiagram
    participant Config as Config/Detect
    participant Detect as Sharding Detector
    participant Executor as Transform Executor
    participant Utils as Weight Utils
    participant Graph as Graph/Model

    Config->>Detect: sharding_source: [heuristic]
    
    rect rgb(240, 248, 255)
    Note over Detect: New Weight-Based Flow
    Detect->>Detect: detect_ssm_shard() / detect_sharding_from_factory_config()
    Detect->>Utils: extract_weight_node(node)
    Utils->>Graph: identify parametrized layers
    Utils-->>Detect: WeightShardingInfo[]
    end
    
    Detect->>Executor: weight_sharding_transforms: WeightShardingInfo[]
    Executor->>Utils: shard_weight_tensor() (unified entry)
    Utils->>Utils: handle fused weights, load hooks
    Utils->>Graph: apply parameter updates
    Executor->>Graph: mark nodes as sharded
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

  • Critical areas requiring attention:
    • tensorrt_llm/_torch/auto_deploy/utils/node_utils.py — Extensive API surface changes; refactoring of weight/node extraction logic with new return types (Tuple[Node, int] for bfs); multiple new utility functions; impacts downstream consumers across the codebase.
    • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py — Major structural overhaul with new enum types, data classes (WeightShardingInfo, ParameterUpdateInfo, LayerType, ShardingSource, ShardingDim), refactored ShardingConfig interface, and unified sharding path (shard_weight_tensor); significant logic density around Mamba-specific sharding and parameter update application.
    • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py — Large migration from TP-centric to weight-based transforms; new detection functions (detect_ssm_shard, process*_sharding) with complex graph traversal and subgraph extraction logic; changes to ShardingTransformExecutor._apply control flow.
    • tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py — BFS return value destructuring pattern changes require verification that discarded second element doesn't lose critical state; verify logic consistency in MatchMoePattern._apply where selected_experts assignment and subsequent usage align.

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
Check name Status Explanation Resolution
Description Check ⚠️ Warning The PR description is substantially incomplete. While the template structure is present, the critical sections are not filled in—the "## Description" and "## Test Coverage" sections contain only template placeholder comments with no substantive content explaining what changes are being made, why they are necessary, or which tests validate the implementation. Although a single checklist item is marked as complete, this does not compensate for the missing core explanation and test coverage information, which are essential for understanding and reviewing the PR's scope and validation strategy. Please fill in the Description section with a clear explanation of the problem being solved and how the changes address it, specifically describing the Nemotron H MoE sharding feature and the refactoring of the sharding infrastructure from TP-centric to weight-focused approaches. Additionally, provide a detailed Test Coverage section that lists the relevant tests (both new and existing) that validate these changes, such as tests in test_tp_sharding.py and any new MoE-specific sharding tests.
Docstring Coverage ⚠️ Warning Docstring coverage is 57.35% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (1 passed)
Check name Status Explanation
Title Check ✅ Passed The PR title "[TRTLLM-8201][feat] Nemotron H MoE Sharding" directly and specifically summarizes the main changes in this pull request. The title follows the required format with a valid JIRA ticket identifier, a feature type indicator, and a clear summary. The raw_summary confirms that the core changes involve adding MoE (Mixture of Experts) sharding support for Nemotron H models through a new patch file (nemotron_h.py) and significant refactoring of the sharding infrastructure to support weight-based sharding instead of tensor-parallel-only approaches. The title is concise, specific enough for developers to understand the primary change when scanning history, and fully captures the intent of the changeset.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 8

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (7)
tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py (1)

456-461: Fix incorrect tuple unpacking order.

The BFS function returns a tuple (Node, int), so the tuple unpacking must occur before accessing .args[0]. The current code attempts to unpack bfs(...).args[0], which will fail at runtime because you're trying to access .args[0] on the tuple returned by bfs, not on the node.

Apply this diff to fix the unpacking order:

-            selected_experts, _ = bfs(
+            selected_experts_node, _ = bfs(
                 common_ancessor2,
                 lambda node: is_op(node, torch.ops.aten.one_hot),
                 attr_next="all_input_nodes",
                 boundary=start_boundary,
-            ).args[0]
+            )
+            selected_experts = selected_experts_node.args[0]
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (2)

1-1: Add NVIDIA Apache-2.0 header (2025).

All source files must start with the NVIDIA Apache-2.0 copyright header per guidelines.

As per coding guidelines

Apply this patch:

+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

159-185: Minor: assertion message wording.

The assertion says “linear node” though function supports other parametrized nodes (e.g., bmm). Adjust text to avoid confusion.

-    assert weight_node, "Cannot identify weight parameter of linear node."
+    assert weight_node, "Cannot identify weight parameter of the given node."
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (2)

1-1: Add NVIDIA Apache-2.0 header (2025).

Please prepend the required copyright header.

As per coding guidelines

+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

904-907: Inference parameters should not require gradients.

Set requires_grad=False for parameter slices in BMM sharding.

-                param_new = nn.Parameter(slice_tensor(param).detach().clone(), requires_grad=True)
+                param_new = nn.Parameter(slice_tensor(param).detach().clone(), requires_grad=False)
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (2)

1-1: Add NVIDIA Apache-2.0 header (2025).

Please prepend the required copyright header.

As per coding guidelines

+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

689-701: Enum comparison bug prevents EP/BMM heuristic runs in partial config path.

Comparing to strings won't match List[ShardingDim]. Use enum members.

-        if "ep" in sharding_config.sharding_dims:
+        if ShardingDim.EP in sharding_config.sharding_dims:
@@
-        if "bmm" in sharding_config.sharding_dims:
+        if ShardingDim.BMM in sharding_config.sharding_dims:
🧹 Nitpick comments (11)
tensorrt_llm/_torch/auto_deploy/transform/interface.py (1)

176-178: Consider documenting the __add__ semantics.

The __add__ operator delegates to __and__, which applies AND-semantics for merging (both is_clean and has_valid_shapes must be True). While this is consistent with combining transform results sequentially, the + operator traditionally suggests accumulation rather than logical AND. Consider adding a docstring to clarify this behavior for maintainers.

+    def __add__(self, other: "TransformInfo") -> "TransformInfo":
+        """Merge transform info using AND-semantics (alias for __and__).
+        
+        This is useful for sequential transform composition where both
+        transforms must succeed for the combined result to be considered clean.
+        """
+        return self.__and__(other)
-    # implement + addition operator for TransformInfo
-    def __add__(self, other: "TransformInfo") -> "TransformInfo":
-        return self.__and__(other)
tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py (1)

147-169: Consider integrating into the export patch system.

The configuration patching logic correctly injects the base_model_tp_plan for NemotronH MoE models. However, as noted in the TODO comment, this manual monkey-patching approach should eventually be integrated into a more systematic export patch system.

For future work, consider:

  1. Creating a unified patching registry similar to CUSTOM_MODULE_PATCHES
  2. Applying patches through a decorator pattern rather than direct assignment
  3. Adding proper cleanup/restoration mechanisms for patches
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (2)

648-689: Potential infinite loop if terminating linear not found.

While loop increments start_lin_index until exactly one torch_linear_simple is found; if graph doesn’t contain this op, this can loop past bounds before the break condition is evaluated later through terminating_indices. Add a max-iterations guard or bail out when start_lin_index >= len(linear_nodes).

-    while len(lin_nodes_in_subgraph) != 1:
+    max_steps = len(linear_nodes)
+    while len(lin_nodes_in_subgraph) != 1:
         forward_subgraph = subgraph(
             sources=[linear_nodes[start_lin_index]], boundary_condition=is_linear_op
         )
         lin_nodes_in_subgraph = list(
             filtered_nodes(forward_subgraph, ops=torch.ops.auto_deploy.torch_linear_simple)
         )
         start_lin_index += 1
+        if start_lin_index >= max_steps:
+            raise RuntimeError("Failed to find terminating linear node (torch_linear_simple).")

413-447: LGTM with a small nit.

bfs API/behavior is clear. Optionally add cur_node to visited to prevent re-processing across level boundaries.

tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (3)

46-54: Use dict.get and early-return to avoid double lookup and KeyError.

Simplifies code and fixes Ruff RUF019.

-    key = prefix + param_key
-    ad_logger.debug(f"Sharder LOAD hook is called for '{key}'")
-    if key not in state_dict:
-        return
-    p_to_load = state_dict[key]
+    key = prefix + param_key
+    ad_logger.debug(f"Sharder LOAD hook is called for '{key}'")
+    p_to_load = state_dict.get(key)
+    if p_to_load is None:
+        return

501-509: Preserve original args in debug log and simplify meta check.

Small correctness/clarity improvement and fixes RUF019 pattern.

-    if "sharded" in node.meta and node.meta["sharded"]:
+    if node.meta.get("sharded"):
         return
-    node.args = args
-    node.meta["sharded"] = True
-    ad_logger.debug(
-        f"Updated node {node}: replaced original arguments {node.args} with sharded arguments {args}."
-    )
+    old_args = node.args
+    node.args = args
+    node.meta["sharded"] = True
+    ad_logger.debug(f"Updated node {node}: replaced original arguments {old_args} with sharded arguments {args}.")

417-485: LGTM with a caution on num_users gating.

Skipping nodes with num_users==0 can occur if quantized graphs obscure get_attr. Consider logging node/weight names at debug to aid triage.

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (4)

321-324: SSM path: include aten.conv1d variant too.

Match both auto_deploy and aten conv1d variants, consistent with utils.

-    conv1d_nodes = [
-        n for n in subgraph_nodes if is_op(n, [torch.ops.auto_deploy.torch_causal_conv1d])
-    ]
+    conv1d_nodes = [
+        n
+        for n in subgraph_nodes
+        if is_op(n, [torch.ops.aten.conv1d, torch.ops.auto_deploy.torch_causal_conv1d])
+    ]

452-456: Guard BFS boundary lookup.

If next linear isn’t found, subgraph() will raise later. Match the try/except pattern used elsewhere.

-    next_lin_node, _ = bfs(linear_nodes[0], is_any_lin_op, include_root=False)
+    try:
+        next_lin_node, _ = bfs(linear_nodes[0], is_any_lin_op, include_root=False)
+    except RuntimeError:
+        ad_logger.warning("Could not find next linear node after entry_node; skipping column shard updates")
+        return

379-381: Nit: Prefer next(iter(...)) over list(...)[0].

Saves allocation and matches Ruff suggestion.

-            WeightShardingInfo.from_node(
-                list(weight_node.users)[0],
+            WeightShardingInfo.from_node(
+                next(iter(weight_node.users)),

Based on static analysis hints


920-922: Nit: Prefer next(iter(...)) over first list element.

Small efficiency/readability win.

-        nodes_to_column_shard = list(nodes_linear.values())[0]
+        nodes_to_column_shard = next(iter(nodes_linear.values()))

Based on static analysis hints

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f2faf28 and e1f848e.

📒 Files selected for processing (11)
  • tensorrt_llm/_torch/auto_deploy/config/default.yaml (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/interface.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py (3 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (3 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (20 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (6 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (18 hunks)
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (5 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use only spaces, no tabs; indent with 4 spaces.

Files:

  • tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py
  • tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py
  • tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
  • tensorrt_llm/_torch/auto_deploy/transform/interface.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py
  • tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.

Files:

  • tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py
  • tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py
  • tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
  • tensorrt_llm/_torch/auto_deploy/transform/interface.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py
  • tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).

Files:

  • tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py
  • tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py
  • tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
  • tensorrt_llm/_torch/auto_deploy/transform/interface.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py
  • tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
🧬 Code graph analysis (9)
tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (1)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (3)
  • extract_param_names_from_node (159-184)
  • is_linear_op (278-288)
  • is_op (197-220)
tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py (1)
tensorrt_llm/models/automodel.py (1)
  • AutoConfig (10-49)
tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (1)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
  • extract_param_names_from_node (159-184)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (1)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
  • extract_param_names_from_node (159-184)
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (1)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (1)
  • WeightShardingInfo (566-626)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (2)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (6)
  • bfs (414-447)
  • extract_param_names_from_node (159-184)
  • is_any_lin_op (274-275)
  • is_op (197-220)
  • num_users_of_weight_node (153-156)
  • subgraph (562-645)
tensorrt_llm/logger.py (2)
  • debug (144-145)
  • warning (132-133)
tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py (1)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
  • bfs (414-447)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py (1)
  • target (382-383)
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (3)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (6)
  • bfs (414-447)
  • extract_weight_node (109-150)
  • filtered_nodes (223-271)
  • is_any_lin_op (274-275)
  • is_op (197-220)
  • subgraph (562-645)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (15)
  • BMMShardingInfo (841-940)
  • EPShardingInfo (1101-1124)
  • LayerType (559-563)
  • ParameterUpdateInfo (629-643)
  • ShardingConfig (1208-1317)
  • ShardingDim (1199-1205)
  • ShardingSource (1192-1196)
  • ShardingTransformInfo (523-556)
  • SplitDimension (512-520)
  • WeightShardingInfo (566-626)
  • get_all_weights_in_subgraph (218-224)
  • add (1247-1266)
  • from_node (577-582)
  • from_node (1108-1113)
  • validate_config (1268-1314)
tensorrt_llm/_torch/auto_deploy/transform/interface.py (2)
  • TransformInfo (121-178)
  • get (523-525)
🪛 Ruff (0.14.2)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py

92-92: Unnecessary key check before dictionary access

Replace with dict.get

(RUF019)


135-135: Unused function argument: custom_shard_fn

(ARG001)


233-233: Unused function argument: add_dist

(ARG001)


238-238: Unused function argument: quantization_cb

(ARG001)


264-264: Unpacked variable depth is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


503-503: Unnecessary key check before dictionary access

Replace with dict.get

(RUF019)


637-637: Unused method argument: gm

(ARG002)


641-641: Unused method argument: gm

(ARG002)


1259-1259: Avoid specifying long messages outside the exception class

(TRY003)

tensorrt_llm/_torch/auto_deploy/utils/node_utils.py

619-619: Avoid specifying long messages outside the exception class

(TRY003)

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py

379-379: Prefer next(iter(weight_node.users)) over single element slice

Replace with next(iter(weight_node.users))

(RUF015)


428-428: Unused function argument: gm

(ARG001)


484-484: Prefer next(iter(node.users)) over single element slice

Replace with next(iter(node.users))

(RUF015)


920-920: Prefer next(iter(nodes_linear.values())) over single element slice

Replace with next(iter(nodes_linear.values()))

(RUF015)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (28)
tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py (2)

320-322: LGTM!

Correctly unpacks the BFS tuple return value (Node, int) and discards the depth.


386-386: LGTM!

Correctly unpacks the BFS tuple return value (Node, int) and discards the depth.

tensorrt_llm/_torch/auto_deploy/config/default.yaml (1)

80-81: LGTM!

The configuration updates enable heuristic-based sharding and partial config support, aligning with the weight-focused sharding approach introduced in this PR.

tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (3)

16-16: LGTM!

The import update aligns with the generalized parameter extraction API that now supports any parametrized node, not just linear nodes.


39-39: LGTM!

Correctly uses the renamed extract_param_names_from_node function.


131-131: LGTM!

Correctly uses the renamed extract_param_names_from_node function.

tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (2)

17-17: LGTM!

The import update aligns with the generalized parameter extraction API.


139-139: LGTM!

Correctly uses the renamed extract_param_names_from_node function.

tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (2)

11-11: LGTM!

The import update aligns with the generalized parameter extraction API.


120-120: LGTM!

Correctly uses the renamed extract_param_names_from_node function.

tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (5)

18-18: LGTM!

The import update reflects the migration from TP-centric to weight-focused sharding model.


275-283: LGTM!

Correctly constructs WeightShardingInfo with all required fields, preserving the expected test behavior with the new API.


296-304: LGTM!

Correctly constructs WeightShardingInfo with all required fields, preserving the expected test behavior with the new API.


310-318: LGTM!

Correctly constructs WeightShardingInfo with all required fields, preserving the expected test behavior with the new API.


354-354: LGTM!

The property access correctly updates from tp_transforms to weight_sharding_transforms to align with the new weight-based sharding model.

tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py (1)

10-10: LGTM!

The AutoConfig import is required for the new configuration patching logic.

tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (5)

153-157: LGTM.

num_users_of_weight_node correctly delegates and guards None.


274-276: LGTM.

is_any_lin_op provides a clean predicate wrapper.


521-534: LGTM.

predecessors now dedups via seen; good.


548-560: LGTM.

successors uses seen to avoid duplicates.


691-699: LGTM.

draw_graph helper is fine; IO guarded via with.

tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (4)

566-627: LGTM.

WeightShardingInfo surface and apply routing look good.


628-644: LGTM.

ParameterUpdateInfo provides a clean post-shard arg update hook.


1191-1206: LGTM.

New enums for sharding sources and dims are clear.


1228-1266: Transform de-dup is sensible.

add() prevents duplicate targets per transform class; good.

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (3)

91-106: LGTM.

Executor applies TP/EP/BMM and post-update transforms cleanly; warnings for invalid updates are useful.


114-138: LGTM.

_simple_shard path uses ShardingConfig.add() effectively to avoid duplicates.


214-242: LGTM.

Source routing (FACTORY vs HEURISTIC) is clear and ordered.

@lucaslie lucaslie force-pushed the ll/gk/sharding_mamba_rebased branch from 36b38d8 to 92b808b Compare October 29, 2025 05:09
@lucaslie
Copy link
Member Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #22849 [ run ] triggered by Bot. Commit: 92b808b

@greg-kwasniewski1 greg-kwasniewski1 requested a review from a team as a code owner October 29, 2025 10:01
@tensorrt-cicd
Copy link
Collaborator

PR_Github #22849 [ run ] completed with state SUCCESS. Commit: 92b808b
/LLM/main/L0_MergeRequest_PR pipeline #17233 completed with status: 'FAILURE'

@lucaslie lucaslie force-pushed the ll/gk/sharding_mamba_rebased branch from 80727ac to 4b3cde4 Compare October 30, 2025 00:12
@lucaslie
Copy link
Member Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #22938 [ run ] triggered by Bot. Commit: 4b3cde4

@tensorrt-cicd
Copy link
Collaborator

PR_Github #22938 [ run ] completed with state SUCCESS. Commit: 4b3cde4
/LLM/main/L0_MergeRequest_PR pipeline #17298 completed with status: 'FAILURE'

lucaslie and others added 8 commits October 30, 2025 12:40
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
@lucaslie lucaslie force-pushed the ll/gk/sharding_mamba_rebased branch from 4b3cde4 to 47ffe26 Compare October 30, 2025 19:41
@lucaslie
Copy link
Member Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #23085 [ run ] triggered by Bot. Commit: 47ffe26

@tensorrt-cicd
Copy link
Collaborator

PR_Github #23085 [ run ] completed with state SUCCESS. Commit: 47ffe26
/LLM/main/L0_MergeRequest_PR pipeline #17410 completed with status: 'FAILURE'

@lucaslie
Copy link
Member Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #23121 [ run ] triggered by Bot. Commit: 47ffe26

@tensorrt-cicd
Copy link
Collaborator

PR_Github #23121 [ run ] completed with state FAILURE. Commit: 47ffe26
/LLM/main/L0_MergeRequest_PR pipeline #17436 completed with status: 'FAILURE'

@lucaslie
Copy link
Member Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #23128 [ run ] triggered by Bot. Commit: 47ffe26

@tensorrt-cicd
Copy link
Collaborator

PR_Github #23128 [ run ] completed with state SUCCESS. Commit: 47ffe26
/LLM/main/L0_MergeRequest_PR pipeline #17443 completed with status: 'FAILURE'

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants