diff --git a/tripy/.github/skills/tripy-compilation/SKILL.md b/tripy/.github/skills/tripy-compilation/SKILL.md new file mode 100644 index 000000000..2292c3976 --- /dev/null +++ b/tripy/.github/skills/tripy-compilation/SKILL.md @@ -0,0 +1,223 @@ +--- +name: tripy-compilation +description: 'Work with the nvtripy compilation pipeline. Use when: using tp.compile, creating InputInfo or DimensionInputInfo, understanding the Trace → MLIR → TensorRT flow, configuring optimization levels, working with Executable objects, debugging compilation, using dynamic shapes or NamedDimension.' +--- + +# nvtripy Compilation Pipeline + +## When to Use + +- Compiling functions or modules with `tp.compile()` +- Defining runtime inputs with `InputInfo` and `DimensionInputInfo` +- Working with compiled `Executable` objects +- Configuring compilation options (optimization level, timing cache) +- Understanding the Trace → MLIR → TensorRT flow +- Using dynamic shapes with min/opt/max bounds +- Debugging compilation failures + +## Compilation Flow + +``` +User Function → Trace (graph) → MLIR (IR) → TensorRT (engine) → Executable +``` + +1. **Trace**: The function is called with tracer tensors to record the computation graph +2. **MLIR**: The trace graph is lowered to MLIR using the `tensorrt` dialect +3. **TensorRT**: MLIR is compiled to a TensorRT engine +4. 
**Executable**: The engine is wrapped in a callable `Executable` object + +## `tp.compile()` — The Main Entry Point + +```python +compiled_fn = tp.compile( + func, # Function or Module to compile + optimization_level=3, # 0-5, higher = better runtime, longer compile + args=[...], # Positional arguments + kwargs={...}, # Keyword arguments +) +``` + +### Argument Types + +| Argument Type | Behavior | +|---------------|----------| +| `InputInfo(shape, dtype)` | Becomes a runtime input to the executable | +| `DimensionInputInfo(value_bounds)` | Becomes a runtime scalar dimension input | +| `Tensor` | Baked in as a compile-time constant | +| Any other type | Baked in as a compile-time constant | + +The compiled `Executable` only accepts parameters that were `InputInfo`/`DimensionInputInfo` in the original `compile()` call. + +## `InputInfo` — Tensor Runtime Inputs + +```python +# Static shape +inp = tp.InputInfo(shape=(2, 4), dtype=tp.float32) +# shape_bounds: min=(2,4), opt=(2,4), max=(2,4) + +# Dynamic dimensions (min, opt, max) +inp = tp.InputInfo(shape=((1, 2, 3), 4), dtype=tp.float32) +# First dim: min=1, opt=2, max=3; second dim: fixed at 4 +# shape_bounds: min=(1,4), opt=(2,4), max=(3,4) + +# Named dimensions (must be equal at runtime) +window = tp.NamedDimension("window", 3, 5, 7) +inp = tp.InputInfo(shape=(1, window, window), dtype=tp.float32) +# Both dims named "window" must have the same value at runtime +``` + +### `DimensionInputInfo` — Scalar Dimension Inputs + +For functions that take scalar shape values as parameters: + +```python +dim_info = tp.DimensionInputInfo(value_bounds=(1, 2, 4)) +# min=1, opt=2, max=4 +``` + +Used when a function parameter controls a reshape or dynamic shape operation. 
+ +## `Executable` — Running Compiled Functions + +```python +# The executable's signature matches the InputInfo parameters +compiled_fn = tp.compile(add, args=[ + tp.InputInfo((2, 4), dtype=tp.float32), # "a" + tp.InputInfo((2, 4), dtype=tp.float32), # "b" +]) + +# Call with evaluated tensors +a = tp.ones((2, 4), dtype=tp.float32).eval() +b = tp.ones((2, 4), dtype=tp.float32).eval() +result = compiled_fn(a, b) +``` + +### Key `Executable` properties + +- `input_infos`: Dict of parameter name → `InputInfo` +- `stream`: The CUDA stream used for execution +- `__signature__`: Compatible with `inspect.signature()` for introspection + +### Important: `.eval()` for inputs + +Runtime inputs to compiled functions should be evaluated tensors (not lazy). Use `.eval()` to force evaluation before passing to the executable. + +## Compiling Modules + +```python +class MyModel(tp.Module): + def __init__(self): + super().__init__() + self.linear = tp.Linear(3, 4) + + def forward(self, x): + return self.linear(x) + +model = MyModel() +# Load real weights before compiling +model.linear.weight = tp.Tensor(weight_data) +model.linear.bias = tp.Tensor(bias_data) + +compiled_model = tp.compile( + model, + args=[tp.InputInfo(shape=(2, 3), dtype=tp.float32)], +) +``` + +When compiling a `Module`: +- The module's `state_dict()` entries are named for readable traces +- Weights become compile-time constants (baked into the engine) +- Only `InputInfo` arguments become runtime inputs + +## Dynamic Shapes + +### Basic dynamic dimensions + +```python +compiled_add = tp.compile( + add, + args=[ + tp.InputInfo(shape=((1, 2, 3), 2), dtype=tp.float32), + tp.InputInfo(shape=((1, 2, 3), 2), dtype=tp.float32), + ], +) + +# Works for any first-dim size in [1, 3]: +small = compiled_add(tp.ones((1, 2)).eval(), tp.ones((1, 2)).eval()) +big = compiled_add(tp.ones((3, 2)).eval(), tp.ones((3, 2)).eval()) +``` + +### Named dimensions for constraints + +```python +window_size = tp.NamedDimension("window_size", 
3, 5, 7) +inp = tp.InputInfo((1, window_size, window_size), dtype=tp.float32) +# Both dimensions named "window_size" must be equal at runtime +``` + +### Scalar dimension inputs + +```python +def dynamic_reshape(x, s): + return tp.reshape(x, (-1, s)) + +compiled_reshape = tp.compile( + dynamic_reshape, + args=[ + tp.InputInfo(shape=(3, (2, 4, 6)), dtype=tp.float32), + tp.DimensionInputInfo(value_bounds=(1, 2, 4)), + ], +) + +result = compiled_reshape(tp.ones((3, 4)).eval(), tp.DimensionSize(2)) +assert result.shape == (6, 2) +``` + +## Compilation Options + +### Optimization Level + +| Level | Description | +|-------|-------------| +| 0 | Minimal optimization, fastest compile | +| 1–2 | Moderate optimization | +| 3 | Default — good balance | +| 4–5 | Maximum optimization, slowest compile | + +### Timing Cache + +```python +tp.config.timing_cache_file_path = "/path/to/cache" +``` + +The timing cache stores kernel profiling data across compilations, significantly speeding up repeated compilations with similar operations. + +## Compiler Internals + +The MLIR compiler (`nvtripy/backend/mlir/compiler.py`) uses these options: + +- `--tensorrt-timing-cache-path`: Path to timing cache +- `--tensorrt-builder-opt-level`: Optimization level (0-5) +- `--force-entrypoints-return-allocs`: Memory management +- `--mlir-elide-elementsattrs-if-larger`: Debug readability +- `--tensorrt-layer-info-dir`: TensorRT layer debug info + +## Function Requirements for `tp.compile` + +The function passed to `compile()` must: + +1. **Be pure** — no side effects (`print`, `assert`, file I/O) +2. **Return Tensor(s)** — only `Tensor` return types supported +3. **No collection inputs** — `List[Tensor]` or `Dict[str, Tensor]` will be frozen as constants +4. 
**No variadic args** — `*args` and `**kwargs` are frozen at compile time + +## Checklist + +- [ ] `InputInfo` used for all runtime tensor inputs +- [ ] `DimensionInputInfo` used for scalar shape parameters +- [ ] Dynamic dimension bounds specified as `(min, opt, max)` tuples +- [ ] Module weights loaded before calling `tp.compile()` +- [ ] Function is pure (no side effects) +- [ ] Runtime inputs `.eval()`'d before passing to executable +- [ ] Timing cache configured for repeated compilations +- [ ] Optimization level appropriate for use case diff --git a/tripy/.github/skills/tripy-constraints/SKILL.md b/tripy/.github/skills/tripy-constraints/SKILL.md new file mode 100644 index 000000000..67675abe3 --- /dev/null +++ b/tripy/.github/skills/tripy-constraints/SKILL.md @@ -0,0 +1,186 @@ +--- +name: tripy-constraints +description: 'Author input/output constraints for nvtripy operations using the declarative constraint DSL. Use when: defining input_requirements or output_guarantees, writing @wrappers.interface decorators, auto-casting dtypes, using GetInput/GetReturn/OneOf/If/Equal, debugging constraint validation errors.' 
+--- + +# Authoring Constraints for nvtripy Operations + +## When to Use + +- Defining type constraints for a new or existing operation +- Writing `input_requirements` or `output_guarantees` for `@wrappers.interface` +- Debugging constraint validation errors at runtime +- Understanding auto-type-casting behavior + +## Architecture Overview + +The constraint system lives in `nvtripy/frontend/constraints/` and consists of: + +- **Fetchers** (`fetcher.py`): Extract values from function arguments or return values +- **Logic** (`logic.py`): Compose constraints with boolean operators +- **Base** (`base.py`): Abstract base class for all constraints +- **Wrappers** (`nvtripy/frontend/wrappers.py`): The `@interface` decorator that applies constraints + +## Core Components + +### Fetchers — Extracting Values + +```python +from nvtripy.frontend.constraints import GetInput, GetReturn + +# Get a function parameter by name +GetInput("input") # The parameter named "input" +GetInput("dtype") # The parameter named "dtype" +GetInput("input").dtype # The dtype of the "input" parameter (uses GetDataType) + +# Get a return value by index +GetReturn(0) # First return value +GetReturn(0).dtype # Dtype of first return value +``` + +### Logic — Composing Constraints + +```python +from nvtripy.frontend.constraints import OneOf, If, GetInput, GetReturn + +# OneOf: value must be in a set +OneOf(GetInput("dtype"), [dt.float32, dt.float16, dt.bfloat16]) + +# Equal: two values must match +GetInput("weight").dtype == GetInput("input").dtype +GetReturn(0).dtype == GetInput("input").dtype + +# NotEqual +GetInput("dtype") != None + +# And: combine with & +OneOf(GetInput("input").dtype, [dt.float32, dt.float16]) +& (GetInput("weight").dtype == GetInput("input").dtype) + +# Or: combine with | +OneOf(GetInput("dtype"), [dt.float32]) | OneOf(GetInput("dtype"), [dt.float16]) + +# If: conditional constraint +If( + GetInput("dtype") != None, # condition + OneOf(GetInput("dtype"), [dt.float32]), # then: 
applied when condition is true + # else branch is optional +) + +# Invert with ~ +~OneOf(GetInput("dtype"), [dt.float32]) # dtype must NOT be float32 +``` + +### All Available Logic Classes + +| Class | Usage | Description | +|-------|-------|-------------| +| `OneOf(fetcher, options)` | `OneOf(GetInput("x").dtype, [dt.float32, dt.float16])` | Value must be in the list | +| `Equal` | `GetInput("a").dtype == GetInput("b").dtype` | Two values must be equal (created via `==`) | +| `NotEqual` | `GetInput("dtype") != None` | Two values must not be equal (created via `!=`) | +| `And` | `constraint1 & constraint2` | Both must be satisfied (created via `&`) | +| `Or` | `constraint1 \| constraint2` | At least one must be satisfied (created via `\|`) | +| `If(cond, then, else_)` | `If(GetInput("dtype") != None, then_constraint)` | Conditional constraint | +| `AlwaysTrue` | `AlwaysTrue()` | Always passes | +| `AlwaysFalse` | `AlwaysFalse()` | Always fails | + +## Using `@wrappers.interface` + +The `@wrappers.interface` decorator from `nvtripy/frontend/wrappers.py` accepts: + +```python +@wrappers.interface( + input_requirements=..., # Pre-execution: a constraint that validates inputs + output_guarantees=..., # Post-execution: a constraint that validates outputs + convert_to_tensors=True, # Auto-convert TensorLike to Tensor + conversion_preprocess_func=None, # Custom preprocessing before conversion +) +``` + +- **`input_requirements`**: Checked BEFORE the function runs. If a dtype mismatch is found and auto-casting can fix it, the system will automatically cast inputs. +- **`output_guarantees`**: Checked AFTER the function runs. Verifies the output properties match expectations. 
+ +## Common Patterns + +### Simple dtype restriction + +```python +@wrappers.interface( + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, +) +def my_op(input: "nvtripy.Tensor") -> "nvtripy.Tensor": +``` + +### Multiple inputs with matching dtypes + +```python +@wrappers.interface( + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16]) + & (GetInput("weight").dtype == GetInput("input").dtype) + & (GetInput("bias").dtype == GetInput("input").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, +) +def layernorm(input, weight, bias, eps): +``` + +### Optional dtype parameter + +```python +@wrappers.interface( + input_requirements=OneOf( + GetInput("input").dtype, + [dt.float32, dt.float16, dt.bfloat16, dt.float8, dt.int8, dt.int32, dt.int64, dt.bool], + ) + & If( + GetInput("dtype") != None, + OneOf(GetInput("dtype"), [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool]), + ), + output_guarantees=If( + GetInput("dtype") != None, + GetReturn(0).dtype == GetInput("dtype"), + GetReturn(0).dtype == GetInput("input").dtype, + ), +) +def ones_like(input, dtype=None): +``` + +### Initializer ops (no tensor inputs, just dtype) + +```python +@wrappers.interface( + input_requirements=OneOf( + GetInput("dtype"), [dt.float32, dt.float16, dt.bfloat16, dt.int8, dt.int32, dt.int64, dt.bool] + ), + output_guarantees=GetReturn(0).dtype == GetInput("dtype"), +) +def ones(shape, dtype=dt.float32): +``` + +## How Auto-Casting Works + +When `input_requirements` include dtype constraints via `OneOf`: + +1. The system checks if all inputs satisfy constraints +2. If a dtype mismatch is found, it looks for a valid target dtype from the `OneOf` options +3. 
Inputs are automatically cast to the matching dtype before the function executes + +This means users don't need to manually cast, e.g., `tp.ones((2,), dtype=tp.float16) + tp.ones((2,), dtype=tp.float32)` will auto-cast. + +## Constraint Error Messages + +When constraints fail, the system generates an error like: +``` +Expected 'input' to be one of [float32, float16, bfloat16] (but it was 'int32') +``` + +The error text comes from the `__str__` and `doc_str` methods of each `Logic` class. + +## Checklist + +- [ ] `input_requirements` covers all valid input dtypes with `OneOf` +- [ ] Multi-input ops require matching dtypes with `==` constraints +- [ ] Optional parameters guarded with `If(GetInput("x") != None, ...)` +- [ ] `output_guarantees` specify the output dtype relationship +- [ ] `&` used to combine multiple requirements (not nested `And()` calls) +- [ ] Test both valid and invalid dtype combinations diff --git a/tripy/.github/skills/tripy-debugging/SKILL.md b/tripy/.github/skills/tripy-debugging/SKILL.md new file mode 100644 index 000000000..da08a81ff --- /dev/null +++ b/tripy/.github/skills/tripy-debugging/SKILL.md @@ -0,0 +1,210 @@ +--- +name: tripy-debugging +description: 'Debug and diagnose errors in nvtripy code. Use when: interpreting TripyException stack traces, enabling MLIR/TensorRT debug output, understanding error reporting with stack_info, using raise_error, configuring debug environment variables, tracing compilation failures.' 
+--- + +# Debugging and Error Reporting in nvtripy + +## When to Use + +- Interpreting `TripyException` error messages and stack traces +- Enabling debug output for MLIR or TensorRT compilation +- Understanding the stack info system for precise error locations +- Adding error handling to new ops or modules +- Diagnosing compilation or runtime failures + +## Error Reporting System + +### `raise_error` — The Primary Error Function + +From `nvtripy/common/exception.py`: + +```python +from nvtripy.common.exception import raise_error + +raise_error( + "Brief description of the error.", + details=[ + "Additional context line 1.", + "Additional context line 2.", + some_tensor, # Will include the tensor's creation stack info + ], +) +``` + +- First argument: The main error message (string) +- `details`: A list of strings and/or tensors. Tensors will have their stack info rendered. +- Raises `TripyException` + +### Stack Info System + +Every tensor captures its creation stack trace (`_stack_info`) for precise error reporting. This is managed in `nvtripy/utils/stack_info.py`. + +When a tensor appears in `raise_error` details, the system renders the exact line and column where the tensor was created, helping users trace back to the problematic code. 
+ +Key points: +- `Tensor.from_trace_tensor(out, include_code_index=stack_depth)` — sets the stack depth for error reporting +- `STACK_DEPTH_OF_FROM_TRACE_TENSOR = 4` — the default depth in `create_op` +- `stack_depth_offset` parameter in `create_op` adjusts for wrapper functions + +### Exception Hierarchy + +``` +TripyException (main user-facing exception) +└── Raised by raise_error() with formatted stack info +``` + +## Debug Configuration + +### Environment Variables + +Set these before running to enable debug output: + +| Variable | Default | Description | +|----------|---------|-------------| +| `TRIPY_MLIR_DEBUG_ENABLED` | `"0"` | Enable MLIR debug output | +| `TRIPY_MLIR_DEBUG_TYPES` | `"-translate-to-tensorrt"` | Comma-separated MLIR pass types to debug | +| `TRIPY_MLIR_DEBUG_PATH` | `"/tripy/mlir-dumps"` | Directory for MLIR debug dumps | +| `TRIPY_TRT_DEBUG_ENABLED` | `"0"` | Enable TensorRT debug output | +| `TRIPY_TRT_DEBUG_PATH` | `"/tripy/tensorrt-dumps"` | Directory for TensorRT debug dumps | +| `TRIPY_EXTRA_ERROR_INFORMATION` | `""` | Comma-separated extra error info | + +### Runtime Configuration + +From `nvtripy/config.py`: + +```python +import nvtripy as tp + +# Timing cache (speeds up repeated compilations) +tp.config.timing_cache_file_path # Default: /tmp/tripy-cache + +# Input validation (disable for performance in production) +tp.config.enable_input_validation = False + +# Extra error information +tp.config.extra_error_information = ["detailed"] +``` + +### Test Helper for Config Changes + +```python +from tests import helper + +with helper.config("enable_input_validation", False): + # Code runs with validation disabled + ... 
+# Automatically restored after the block +``` + +## Logging System + +The logging system (`nvtripy/logging/`) provides granular control: + +```python +import nvtripy as tp + +# The global logger +logger = tp.logger + +# Set verbosity for specific modules +logger.verbosity_trie.set("nvtripy.backend", "verbose") +``` + +The `VerbosityTrie` allows setting different log levels for different module paths, using a trie data structure for efficient prefix matching. + +## Diagnosing Common Issues + +### Constraint Validation Errors + +Error pattern: `Expected 'input' to be one of [...] (but it was '...')` + +This comes from the constraint system. Check: +1. The `input_requirements` in the `@wrappers.interface` decorator +2. The actual dtypes of the inputs being passed +3. Whether auto-casting should handle this case + +### Compilation Errors + +Enable MLIR debug output: +```bash +TRIPY_MLIR_DEBUG_ENABLED=1 python my_script.py +``` + +Check the dumps in `/tripy/mlir-dumps/` for the MLIR IR at each pass. + +### Shape Mismatch Errors + +The trace system tracks shapes through `infer_rank` on trace ops. Check: +1. The `infer_rank` policy on the trace op +2. Whether broadcasting is handled correctly +3. 
Dynamic dimensions (`DYNAMIC_DIM = -1`) vs static shapes + +### Runtime Errors from Compiled Functions + +Enable TensorRT debug output: +```bash +TRIPY_TRT_DEBUG_ENABLED=1 python my_script.py +``` + +Check: +- Input shapes fall within the `InputInfo` bounds +- Dynamic shapes are correctly configured with min/opt/max + +## Adding Error Handling to New Code + +### In Frontend Ops + +```python +def my_op(input, dim): + if input.rank < 2: + raise_error( + f"Input must have rank >= 2, but got rank: {input.rank}", + details=[ + "Input is expected to have shape (N, *) where N is the batch size.", + input, # This renders the tensor's creation location + ], + ) +``` + +### In Modules + +```python +def forward(self, x): + if self.quant_dtype is not None and self.weight_quant_dim == 1: + raise_error( + "Unsupported quantization parameters.", + [ + "weight_quant_dim cannot be 1 when input_scale is provided.", + f"input_scale={self.input_scale}, weight_quant_dim={self.weight_quant_dim}", + ], + ) +``` + +## Testing Errors + +Use the `helper.raises` context manager: + +```python +from tests import helper +import nvtripy as tp + +def test_invalid_dtype_fails(): + a = tp.Tensor([1.0, 2.0]) + b = tp.ones((2,), dtype=tp.float16) + with helper.raises(tp.TripyException, match="Expected.*one of"): + c = a + b +``` + +The `raises` helper supports: +- `ExcType`: Expected exception type +- `match`: Regex pattern to match against error message +- `has_stack_info_for`: Verify that specific tensors' stack info appears in the error + +## Checklist + +- [ ] Use `raise_error()` instead of raw `raise` for user-facing errors +- [ ] Include relevant tensors in `details` for stack info rendering +- [ ] Error message is actionable (says what went wrong AND what to do) +- [ ] Test error cases with `helper.raises(tp.TripyException, match=...)` +- [ ] Debug env vars documented if adding new debug output diff --git a/tripy/.github/skills/tripy-documentation/SKILL.md 
b/tripy/.github/skills/tripy-documentation/SKILL.md new file mode 100644 index 000000000..4dc458452 --- /dev/null +++ b/tripy/.github/skills/tripy-documentation/SKILL.md @@ -0,0 +1,176 @@ +--- +name: tripy-documentation +description: 'Write API documentation for nvtripy following project conventions. Use when: writing docstrings for ops or modules, adding code examples, using @export.public_api document_under paths, creating Sphinx RST cross-references, understanding the docs build pipeline.' +--- + +# API Documentation for nvtripy + +## When to Use + +- Writing or updating docstrings for public API functions, classes, or modules +- Adding working code examples to documentation +- Choosing the correct `document_under` path for `@export.public_api` +- Understanding how documentation is generated and built + +## Documentation Pipeline + +1. **`@export.public_api(document_under="...")`** registers APIs in `PUBLIC_APIS` list +2. **`docs/generate_rsts.py`** reads `PUBLIC_APIS` and generates `.rst` files in the docs hierarchy +3. **Sphinx** builds the final HTML docs from those `.rst` files +4. Docstring code examples are extracted and validated during testing + +## `@export.public_api` Parameters + +```python +@export.public_api( + document_under="operations/functions", # Doc hierarchy path + autodoc_options=[":special-members:"], # Sphinx autodoc options + bypass_dispatch=True, # Skip function registry overhead +) +``` + +### `document_under` Path Conventions + +| Path | Use For | +|------|---------| +| `"operations/functions"` | General tensor operations (softmax, reshape, etc.) | +| `"operations/initializers"` | Tensor creation (ones, zeros, full, arange) | +| `"operations/modules"` | Neural network modules (Linear, LayerNorm, etc.) 
| +| `"compiling_code/compile.rst"` | Compilation-related APIs | +| `"compiling_code/input_info/index.rst"` | InputInfo and related classes | +| `"config.rst"` | Configuration variables | + +The path creates a directory structure: `"operations/functions"` → `operations/functions/.rst`. + +APIs targeting the same `.rst` file render on the same page. + +### `autodoc_options` + +- `[":special-members:"]` — Include `__init__`, `__call__`, etc. +- `[":no-members:", ":no-special-members:"]` — Show only the class/module itself +- `[":no-value:"]` — Hide the default value of a variable + +### `bypass_dispatch` + +- `True` on a function: Disables the function registry's overload dispatch and type-checking (performance optimization) +- `True` on a class: Bypass dispatch for ALL methods +- `["__init__", "__call__"]`: Bypass only for listed methods + +## Docstring Format + +### Functions + +```python +def my_op(input: "nvtripy.Tensor", dim: int = 0) -> "nvtripy.Tensor": + r""" + Brief one-line description of what the function does. + + Longer description with math if applicable: + + :math:`\text{my_op}(x) = f(x)` + + Args: + input: The input tensor. + dim: The dimension to operate along. + + Returns: + A tensor of the same shape as the input. + + .. code-block:: python + :linenos: + + input = tp.iota([2, 3], dtype=tp.float32) + output = tp.my_op(input, dim=0) + + assert tp.allclose(output, expected) + + .. seealso:: :func:`related_func`, :class:`RelatedClass` + """ +``` + +### Classes (Modules) + +```python +class MyModule(Module): + r""" + Brief description with math notation. + + :math:`\text{MyModule}(x) = xW^T + b` + """ + + dtype: datatype.dtype + r"""The data type used to perform the operation.""" + + weight: Tensor + r"""The :math:`W` parameter of shape :math:`[\text{out}, \text{in}]`.""" + + def __init__(self, features: int, dtype: datatype.dtype = datatype.float32) -> None: + r""" + Args: + features: Size of the feature dimension. 
+ dtype: The data type for parameters. + + .. code-block:: python + :linenos: + + module = tp.MyModule(3) + module.weight = tp.iota(module.weight.shape) + + input = tp.iota((2, 3)) + output = module(input) + + assert cp.from_dlpack(output).get().shape == (2, 3) + + torch_out = torch.nn.functional.my_op(torch.from_dlpack(input)) # doc: omit + assert np.allclose(cp.from_dlpack(output).get(), cp.from_dlpack(torch_out).get()) + """ +``` + +## Code Example Conventions + +### Required elements + +1. **Setup**: Create inputs using `tp.iota()`, `tp.ones()`, `tp.zeros()`, or `tp.Tensor()` +2. **Operation**: Call the function/module under test +3. **Assertion**: Verify the result with `assert` (using `tp.allclose`, `np.array_equal`, or shape checks) + +### Available imports in code blocks + +Code examples automatically have access to: +- `tp` (nvtripy) +- `np` (numpy) +- `cp` (cupy) +- `torch` + +### Special directives + +- `# doc: omit` — Line is excluded from rendered documentation but still executes +- `# doc: no-print-locals ` — Suppresses automatic printing of the variable +- `:linenos:` — Always include for numbered lines + +### Cross-references + +- `:func:`function_name`` — Link to a function +- `:class:`ClassName`` — Link to a class +- `:math:`expression`` — Inline LaTeX math +- `.. seealso:: :func:`related`" — "See also" section at the end + +### Math notation + +Use `r"""` raw strings for docstrings containing `:math:` to avoid backslash issues. + +LaTeX examples: +- Inline: `:math:`\text{softmax}(x_{i})`` +- Block: Use `\Large` / `\normalsize` for fraction sizing +- Common: `:math:`\bar{x}`` (mean), `:math:`\sigma^2`` (variance), `:math:`\epsilon`` (epsilon) + +## Checklist + +- [ ] `@export.public_api(document_under="...")` with correct hierarchy path +- [ ] Docstring uses `r"""` if it contains `:math:` directives +- [ ] Args section documents all parameters with types +- [ ] Returns section describes output shape/type +- [ ] `.. 
code-block:: python` with `:linenos:` and working assertions +- [ ] `.. seealso::` links to related functions/classes +- [ ] Field docstrings for all dataclass fields on modules +- [ ] `# doc: omit` for verification-only lines that shouldn't appear in docs diff --git a/tripy/.github/skills/tripy-new-module/SKILL.md b/tripy/.github/skills/tripy-new-module/SKILL.md new file mode 100644 index 000000000..e20ab84bc --- /dev/null +++ b/tripy/.github/skills/tripy-new-module/SKILL.md @@ -0,0 +1,241 @@ +--- +name: tripy-new-module +description: 'Add a new neural network module to nvtripy. Use when: creating an nn layer, implementing a Module subclass, adding a new layer like Linear/LayerNorm/Conv, defining parameters with DefaultParameter or OptionalParameter, using constant_fields decorator.' +--- + +# Adding a New Module to nvtripy + +## When to Use + +- Creating a new neural network layer (e.g., normalization, attention, convolution) +- Implementing a `Module` subclass with learnable parameters +- Adding a module that wraps existing ops into a reusable component + +## Architecture Overview + +Modules live in `nvtripy/frontend/module/` and follow this pattern: + +1. **Optional helper function**: A standalone function (not exported) that implements the math, decorated with `@wrappers.interface` for constraints. +2. **Module class**: A `@dataclass` subclass of `Module` with `@export.public_api` and `@constant_fields`. +3. **Parameters**: Use `DefaultParameter` (must be set before use) or `OptionalParameter` (can be None). 
+ +## Procedure + +### Step 1: Create the Module File + +Create `nvtripy/frontend/module/.py`: + +```python +from dataclasses import dataclass +from typing import Optional, Sequence, Union + +from nvtripy import export, utils +from nvtripy.common import datatype +from nvtripy.frontend import wrappers +from nvtripy.frontend.module.module import Module +from nvtripy.frontend.module.parameter import DefaultParameter, OptionalParameter +from nvtripy.frontend.tensor import Tensor +from nvtripy.frontend.wrappers import constant_fields +from nvtripy.frontend.ops import utils as op_utils + +# If needed, import the trace op: +from nvtripy.trace.ops.my_op import MyOp + +from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf + + +# Optional: standalone function with constraints (used by the module's forward()) +@wrappers.interface( + input_requirements=OneOf(GetInput("input").dtype, [datatype.float32, datatype.float16, datatype.bfloat16]) + & (GetInput("weight").dtype == GetInput("input").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, +) +def my_layer_func( + input: "nvtripy.Tensor", + weight: "nvtripy.Tensor", + bias: "nvtripy.Tensor", + eps: float, +) -> "nvtripy.Tensor": + # Implementation using existing ops or create_op + return op_utils.create_op(MyOp, [input, weight, bias], eps=eps) + + +@export.public_api(document_under="operations/modules") +@dataclass +@constant_fields(["dtype"]) +class MyLayer(Module): + r""" + Brief math description of the layer. 
+ + :math:`\text{MyLayer}(x) = f(x, W, b)` + """ + + dtype: datatype.dtype + r"""The data type used to perform the operation.""" + + weight: Tensor + r"""The weight parameter of shape :math:`[\text{features}]`.""" + + bias: Optional[Tensor] + r"""The bias parameter of shape :math:`[\text{features}]`.""" + + eps: float + """A small value for numerical stability.""" + + def __init__( + self, + features: int, + bias: bool = True, + dtype: datatype.dtype = datatype.float32, + eps: float = 1e-5, + ) -> None: + r""" + Args: + features: Size of the feature dimension. + bias: Whether to include a bias term. + dtype: The data type for parameters. + eps: Small constant for numerical stability. + + .. code-block:: python + :linenos: + + layer = tp.MyLayer(3) + + layer.weight = tp.iota(layer.weight.shape) + layer.bias = tp.iota(layer.bias.shape) + + input = tp.iota((2, 3), dim=1) + output = layer(input) + + assert cp.from_dlpack(output).get().shape == (2, 3) + """ + super().__init__() + + self.dtype = dtype + self.weight = DefaultParameter((features,), dtype=dtype) + + self.bias = None + if bias: + self.bias = DefaultParameter((features,), dtype=dtype) + + self.eps = eps + + def forward(self, x: "nvtripy.Tensor") -> "nvtripy.Tensor": + r""" + Args: + x: The input tensor. + + Returns: + The output tensor. + """ + return my_layer_func(x, self.weight, self.bias, self.eps) +``` + +### Step 2: Understand Parameter Types + +**`DefaultParameter(shape, dtype)`**: Creates a placeholder that MUST be replaced with real data before the module runs. Used for required weights: + +```python +self.weight = DefaultParameter((out_features, in_features), dtype=dtype) +``` + +**`OptionalParameter(shape, dtype)`**: Can be `None` — used for optional weights like quantization scales: + +```python +self.input_scale = OptionalParameter(shape=[], dtype=dtype) +``` + +### Step 3: Use `@constant_fields` + +The `@constant_fields(["field1", "field2"])` decorator marks fields as compile-time constants. 
These fields will be baked into the compiled graph and cannot change at runtime. Use for: + +- `dtype` — data type configuration +- `normalized_shape` — shape parameters that define the layer structure +- `quant_dtype` — quantization configuration + +### Step 4: Register the Module + +The `@export.public_api(document_under="operations/modules")` decorator handles registration. The module will be accessible as `tp.MyLayer(...)`. + +Ensure the module file is imported in `nvtripy/frontend/module/__init__.py`. + +## Complete Example: LayerNorm + +```python +# Helper function with constraints +@wrappers.interface( + input_requirements=OneOf(GetInput("input").dtype, [datatype.float32, datatype.float16, datatype.bfloat16]) + & (GetInput("weight").dtype == GetInput("input").dtype) + & (GetInput("bias").dtype == GetInput("input").dtype), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, +) +def layernorm(input, weight, bias, eps): + normalized_shape = weight.shape + D = len(normalized_shape) + input_rank = input.rank + + if input_rank < 2: + raise_error(f"Input must have rank >= 2, got {input.rank}") + + if input_rank > D: + broadcast_shape = (1,) * (input_rank - D) + normalized_shape + weight = reshape(weight, broadcast_shape) + bias = reshape(bias, broadcast_shape) + + return op_utils.create_op(LayerNormOp, [input, weight, bias], + normalized_shape=normalized_shape, eps=eps) + + +@export.public_api(document_under="operations/modules") +@dataclass +@constant_fields(["dtype", "normalized_shape"]) +class LayerNorm(Module): + dtype: datatype.dtype + normalized_shape: Sequence[int] + weight: Tensor + bias: Tensor + eps: float + + def __init__(self, normalized_shape, dtype=datatype.float32, eps=1e-5): + super().__init__() + self.dtype = dtype + if isinstance(normalized_shape, int): + normalized_shape = (normalized_shape,) + self.normalized_shape = normalized_shape + self.weight = DefaultParameter(normalized_shape, dtype=dtype) + self.bias = 
DefaultParameter(normalized_shape, dtype=dtype) + self.eps = eps + + def forward(self, x): + return layernorm(x, self.weight, self.bias, self.eps) +``` + +## Complete Example: Linear + +Key patterns from `Linear`: + +- Weight shape: `(out_features, in_features)` — transposed in `forward()` +- Optional bias with sentinel: `self.bias = DefaultParameter(...) if bias else None` +- Quantization support with `OptionalParameter` for scales +- Uses `@constant_fields(["dtype", "quant_dtype"])` for compile-time config + +## Module Base Class Features + +The `Module` base class (`nvtripy/frontend/module/module.py`) provides: + +- `state_dict()`: Recursively collects all `Tensor` parameters (supports nested modules, lists, dicts) +- `load_state_dict(state_dict, strict=True)`: Loads parameters with shape/dtype validation +- `__setattr__`: Validates parameter compatibility on assignment +- `__call__`: Calls `forward()` — modules are callable like functions + +## Checklist + +- [ ] Module file created in `nvtripy/frontend/module/` +- [ ] Inherits from `Module`, decorated with `@dataclass` and `@export.public_api` +- [ ] `@constant_fields` applied for compile-time configuration fields +- [ ] `__init__` calls `super().__init__()` and uses `DefaultParameter`/`OptionalParameter` +- [ ] `forward()` method implemented +- [ ] Docstrings with math notation and working code examples +- [ ] Helper function with `@wrappers.interface` constraints (if applicable) +- [ ] Registered in `nvtripy/frontend/module/__init__.py` +- [ ] Tests added in `tests/frontend/module/` diff --git a/tripy/.github/skills/tripy-new-operation/SKILL.md b/tripy/.github/skills/tripy-new-operation/SKILL.md new file mode 100644 index 000000000..c0c3101f8 --- /dev/null +++ b/tripy/.github/skills/tripy-new-operation/SKILL.md @@ -0,0 +1,178 @@ +--- +name: tripy-new-operation +description: 'Add a new operation to nvtripy. 
Use when: implementing a new op, adding a frontend op, creating a trace op, registering an op in the API. Covers the full Frontend → Trace → MLIR pipeline including export decorators, constraint definitions, and init registration.'
+---
+
+# Adding a New Operation to nvtripy
+
+## When to Use
+
+- Adding a new mathematical, tensor, or neural network operation
+- Creating a new frontend function that maps to TensorRT/MLIR ops
+- Extending the op registry with unary, binary, or custom operations
+
+## Architecture Overview
+
+Operations in nvtripy follow a **Frontend → Trace → MLIR** pipeline:
+
+1. **Trace Op** (`nvtripy/trace/ops/`): Defines the computational graph node — rank inference, dtype inference, and MLIR code generation.
+2. **Frontend Op** (`nvtripy/frontend/ops/`): The public API function — exports, constraints, docstring, and bridges to the trace op via `create_op()`.
+3. **Registration**: Both `__init__.py` files must be updated so the op is discoverable.
+
+## Procedure
+
+### Step 1: Create the Trace Operation
+
+Create a file in `nvtripy/trace/ops/<op_name>.py`:
+
+```python
+from dataclasses import dataclass
+
+import nvtripy.trace.ops.utils as op_utils
+from mlir_tensorrt.compiler.dialects import tensorrt
+from nvtripy.trace.ops.base import TraceOp
+
+
+@dataclass(repr=False)
+class MyOp(TraceOp):
+    # Add any op-specific parameters as dataclass fields:
+    dim: int
+
+    # Choose a rank inference policy:
+    infer_rank = op_utils.InferRankPolicies.same_as_input()
+
+    def to_mlir(self, inputs, outputs):
+        # Generate MLIR using the tensorrt dialect:
+        return [tensorrt.some_op(inputs[0], self.dim)]
+```
+
+**Key base class requirements** (from `TraceOp`):
+
+- `infer_rank` (required): Set output rank. Use policies from `InferRankPolicies`:
+  - `same_as_input(idx=0)` — output rank matches input[idx]
+  - `same_shape_as_input(idx=0)` — output has same shape (not just rank)
+  - `same_as_shape_of_shape_input(idx=0)` — rank from a shape tensor
+  - `max_of_inputs()` — rank is max across all inputs
+  - Or define a custom function
+- `to_mlir(self, inputs, outputs)` (required): Return list of MLIR operations
+- `infer_dtypes()` (optional): Default propagates from `inputs[0]`. Override for multi-dtype ops.
+- `infer_devices()` (optional): Default sets all outputs to GPU.
+- `get_num_outputs()` (optional): Default is 1. Override for multi-output ops.
+- `str_skip_fields()` (optional): Fields to omit from string representation.
+
+**Factory pattern** for families of similar ops (see `trace/ops/unary.py`):
+
+```python
+def make_unary_op(name, attr_name):
+    @dataclass(repr=False)
+    class UnaryOp(TraceOp):
+        infer_rank = op_utils.InferRankPolicies.same_as_input()
+
+        def to_mlir(self, inputs, outputs):
+            return [tensorrt.unary(inputs[0], tensorrt.UnaryOperationAttr.get(attr_name))]
+
+    UnaryOp.__name__ = name
+    return UnaryOp
+
+Exp = make_unary_op("Exp", "kEXP")
+```
+
+### Step 2: Create the Frontend Operation
+
+Create a file in `nvtripy/frontend/ops/<op_name>.py`:
+
+```python
+from typing import Optional
+
+from nvtripy import export
+from nvtripy.common import datatype as dt
+from nvtripy.frontend import wrappers
+from nvtripy.frontend.constraints import GetInput, GetReturn, OneOf
+from nvtripy.frontend.ops import utils as op_utils
+from nvtripy.trace.ops.my_op import MyOp
+
+
+@export.public_api(document_under="operations/functions")
+@wrappers.interface(
+    input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16]),
+    output_guarantees=GetReturn(0).dtype == GetInput("input").dtype,
+)
+def my_op(input: "nvtripy.Tensor", dim: Optional[int] = None) -> "nvtripy.Tensor":
+    r"""
+    Brief description of what the op does.
+ + Args: + input: The input tensor. + dim: The dimension to operate on. + + Returns: + A tensor of the same shape as the input. + + .. code-block:: python + :linenos: + + input = tp.iota([2, 3], dtype=tp.float32) + output = tp.my_op(input, dim=0) + + assert tp.allclose(output, expected_tensor) + """ + dim = op_utils.process_dim(dim, input.rank) + return op_utils.create_op(MyOp, [input], dim=dim) +``` + +**Key decorator details:** + +- `@export.public_api(document_under="...")`: Registers in public API and docs hierarchy. Common paths: + - `"operations/functions"` — general tensor ops + - `"operations/initializers"` — tensor creation ops (ones, zeros, full) + - `"operations/modules"` — nn module classes +- `@wrappers.interface(...)`: Defines input constraints and output guarantees (see constraint skill) +- Bridge to trace via `op_utils.create_op(TraceOpClass, [inputs], **kwargs)` + +### Step 3: Register in `__init__.py` Files + +**`nvtripy/frontend/ops/__init__.py`**: Add import so auto-discovery finds the module. + +**`nvtripy/trace/ops/__init__.py`**: Usually empty — trace ops are imported directly by frontend ops. + +### Step 4: Add as Tensor Method (Optional) + +If the op should be callable as `tensor.my_op()`, register it in the `TENSOR_METHOD_REGISTRY` via the frontend tensor metaclass system. Check `nvtripy/frontend/tensor.py` for the pattern. 
+ +## Complete Example: Softmax + +**Trace op** (`nvtripy/trace/ops/softmax.py`): + +```python +@dataclass(repr=False) +class Softmax(TraceOp): + dim: int + infer_rank = op_utils.InferRankPolicies.same_as_input() + + def to_mlir(self, inputs, outputs): + return [tensorrt.softmax(inputs[0], self.dim)] +``` + +**Frontend op** (`nvtripy/frontend/ops/softmax.py`): + +```python +@export.public_api(document_under="operations/functions") +@wrappers.interface( + input_requirements=OneOf(GetInput("input").dtype, [dt.float32, dt.float16, dt.bfloat16]), + output_guarantees=GetReturn(0).dtype == GetInput("input").dtype, +) +def softmax(input: "nvtripy.Tensor", dim: Optional[int] = None) -> "nvtripy.Tensor": + # Handle None dim by flattening + # Handle rank < 2 by unsqueezing (TensorRT requirement) + dim = op_utils.process_dim(dim, input.rank) + return op_utils.create_op(Softmax, [input], dim=dim) +``` + +## Checklist + +- [ ] Trace op created in `nvtripy/trace/ops/` with `infer_rank` and `to_mlir` +- [ ] Frontend op created in `nvtripy/frontend/ops/` with `@export.public_api` and `@wrappers.interface` +- [ ] Constraints defined for valid dtypes and output guarantees +- [ ] Docstring includes Args, Returns, and a working `.. code-block:: python` example +- [ ] `__init__.py` updated if needed for auto-discovery +- [ ] Tests added in `tests/frontend/ops/` and `tests/trace/ops/` (see testing skill) diff --git a/tripy/.github/skills/tripy-testing/SKILL.md b/tripy/.github/skills/tripy-testing/SKILL.md new file mode 100644 index 000000000..ba028cf78 --- /dev/null +++ b/tripy/.github/skills/tripy-testing/SKILL.md @@ -0,0 +1,277 @@ +--- +name: tripy-testing +description: 'Write tests for nvtripy following project conventions. Use when: adding tests for ops, modules, trace operations, or compilation, using pytest parametrize, testing error cases with helper.raises, testing dtype combinations, understanding test directory structure.' 
+--- + +# Testing Patterns for nvtripy + +## When to Use + +- Adding tests for new operations, modules, or features +- Understanding the test directory structure and conventions +- Testing error cases and dtype validation +- Writing parametrized tests for multiple configurations + +## Test Directory Structure + +The test directory mirrors the source tree: + +``` +tests/ +├── frontend/ +│ ├── ops/ +│ │ ├── test_binary.py +│ │ ├── test_softmax.py +│ │ └── ... +│ ├── module/ +│ │ ├── test_linear.py +│ │ ├── test_layernorm.py +│ │ └── ... +│ ├── test_tensor.py +│ └── ... +├── trace/ +│ ├── ops/ +│ │ ├── test_binary.py +│ │ └── ... +│ └── ... +├── backend/ +│ └── ... +├── integration/ +│ └── ... +├── helper.py # Test utilities +└── conftest.py # Shared fixtures +``` + +## Test Utilities (`tests/helper.py`) + +### `helper.raises` — Error Testing + +```python +from tests import helper +import nvtripy as tp + +# Basic error test +with helper.raises(tp.TripyException): + result = bad_operation() + +# With message matching (regex) +with helper.raises(tp.TripyException, match="Expected.*one of"): + result = bad_operation() + +# With stack info verification +a = tp.Tensor([1.0]) +with helper.raises(tp.TripyException, has_stack_info_for=[a]): + result = bad_operation(a) +``` + +### `helper.config` — Temporary Config Changes + +```python +with helper.config("enable_input_validation", False): + # Validation disabled in this block + result = operation() +# Automatically restored +``` + +### `NUMPY_TO_TRIPY` — Dtype Mapping + +```python +from tests.helper import NUMPY_TO_TRIPY + +# Maps numpy dtypes to tripy dtypes: +# bool → tp.bool, np.int8 → tp.int8, np.int32 → tp.int32, +# np.int64 → tp.int64, np.float16 → tp.float16, np.float32 → tp.float32 +``` + +## Common Test Patterns + +### Basic Op Test + +```python +import cupy as cp +import numpy as np +import nvtripy as tp + + +class TestMyOp: + def test_basic(self): + input = tp.Tensor([1.0, 2.0, 3.0]) + output = tp.my_op(input) + + 
expected = np.array([...]) # Compute expected result + assert np.array_equal(cp.from_dlpack(output).get(), expected) + + def test_with_specific_dim(self): + input = tp.iota([2, 3], dtype=tp.float32) + output = tp.my_op(input, dim=1) + + assert output.shape == [2, 3] +``` + +### Parametrized Dtype Tests + +```python +import pytest + +class TestMyOp: + @pytest.mark.parametrize("dtype", [tp.float32, tp.float16, tp.bfloat16]) + def test_supported_dtypes(self, dtype): + input = tp.ones([2, 3], dtype=dtype) + output = tp.my_op(input) + assert output.dtype == dtype + + @pytest.mark.parametrize( + "dtype", + [tp.int8, tp.int32], + ids=["int8", "int32"], + ) + def test_unsupported_dtypes_fail(self, dtype): + input = tp.ones([2, 3], dtype=dtype) + with helper.raises(tp.TripyException, match="Expected.*one of"): + tp.my_op(input) +``` + +### Testing from NumPy Data + +```python +from tests.helper import NUMPY_TO_TRIPY + +class TestTensor: + @pytest.mark.parametrize("dtype", list(NUMPY_TO_TRIPY.keys())) + def test_dtype_from_numpy(self, dtype): + np_array = np.array([1, 2, 3], dtype=dtype) + tensor = tp.Tensor(np_array) + assert tensor.dtype == NUMPY_TO_TRIPY[dtype] +``` + +### Mismatched Dtype Error Tests + +```python +class TestBinaryOps: + def test_mismatched_dtypes_fails(self): + a = tp.Tensor([1.0, 2.0]) + b = tp.ones((2,), dtype=tp.float16) + with helper.raises(tp.TripyException): + c = a + b +``` + +### Module Tests + +```python +class TestLinear: + def test_basic(self): + linear = tp.Linear(3, 4) + linear.weight = tp.iota(linear.weight.shape) + linear.bias = tp.iota(linear.bias.shape) + + input = tp.iota((2, 3)) + output = linear(input) + + assert cp.from_dlpack(output).get().shape == (2, 4) + + def test_no_bias(self): + linear = tp.Linear(3, 4, bias=False) + linear.weight = tp.iota(linear.weight.shape) + + input = tp.iota((2, 3)) + output = linear(input) + assert output.shape == [2, 4] + + def test_state_dict(self): + linear = tp.Linear(3, 4) + sd = 
linear.state_dict() + assert "weight" in sd + assert "bias" in sd +``` + +### Trace Op Tests + +```python +from nvtripy.trace.ops.my_op import MyOp +from nvtripy.trace.ops.base import TraceOp + +class TestMyTraceOp: + def test_creates_correct_op(self): + input = tp.Tensor([1.0, 2.0]) + output = tp.my_op(input) + + assert isinstance(output.trace_tensor.producer, MyOp) + + def test_infer_rank(self): + input = tp.Tensor([1.0, 2.0]) + output = tp.my_op(input) + + assert output.trace_tensor.rank == 1 +``` + +### Allclose Comparisons + +```python +class TestSoftmax: + def test_matches_torch(self): + input = tp.iota([2, 3], dtype=tp.float32) + output = tp.softmax(input, dim=1) + + torch_input = torch.from_dlpack(input) + torch_output = torch.softmax(torch_input, dim=1) + + assert tp.allclose(output, tp.Tensor(torch_output)) +``` + +### Parametrize with IDs + +```python +@pytest.mark.parametrize( + "tensor_a, tensor_b, rtol, atol, expected", + [ + (tp.Tensor([1.0]), tp.Tensor([1.0]), 1e-5, 1e-8, True), + (tp.Tensor([1.0]), tp.Tensor([2.0]), 1e-5, 1e-8, False), + ], + ids=["equal", "not_equal"], +) +def test_allclose(self, tensor_a, tensor_b, rtol, atol, expected): + result = tp.allclose(tensor_a, tensor_b, rtol=rtol, atol=atol) + assert result == expected +``` + +## Verifying Outputs + +| Method | Use When | +|--------|----------| +| `cp.from_dlpack(output).get()` | Convert tripy tensor → numpy array (via cupy) | +| `np.array_equal(a, b)` | Exact equality for integer/bool results | +| `np.allclose(a, b)` | Approximate equality for float results | +| `tp.allclose(a, b)` | Compare two tripy tensors directly | +| `torch.from_dlpack(tensor)` | Convert tripy tensor → torch tensor | +| `output.shape` | Check output shape | +| `output.dtype` | Check output dtype | + +## Running Tests + +```bash +# Run all tests +pytest tests/ + +# Run specific test file +pytest tests/frontend/ops/test_softmax.py + +# Run specific test +pytest 
tests/frontend/ops/test_softmax.py::TestSoftmax::test_basic + +# Run with verbose output +pytest -v tests/frontend/ops/test_softmax.py + +# Run tests matching a pattern +pytest -k "softmax" +``` + +## Checklist + +- [ ] Test file created at `tests//test_.py` +- [ ] Tests organized in a class (e.g., `TestMyOp`) +- [ ] Basic functionality test with assertion +- [ ] Parametrized dtype tests for all supported dtypes +- [ ] Error case tests using `helper.raises(tp.TripyException)` +- [ ] Shape validation tests +- [ ] Comparison against reference implementation (torch/numpy) where applicable +- [ ] Test IDs provided for parametrized tests to make failures readable diff --git a/tripy/nvtripy/__init__.py b/tripy/nvtripy/__init__.py index 9ac751b6f..57a351ec9 100644 --- a/tripy/nvtripy/__init__.py +++ b/tripy/nvtripy/__init__.py @@ -15,7 +15,7 @@ # limitations under the License. # -__version__ = "0.1.6" +__version__ = "0.1.7" # Import TensorRT to make sure all dependent libraries are loaded first. import tensorrt diff --git a/tripy/nvtripy/backend/mlir/compiler.py b/tripy/nvtripy/backend/mlir/compiler.py index e8e77843b..85e3928e0 100644 --- a/tripy/nvtripy/backend/mlir/compiler.py +++ b/tripy/nvtripy/backend/mlir/compiler.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,6 +15,7 @@ # limitations under the License. # +import threading from typing import Optional, Tuple import mlir_tensorrt.compiler.api as compiler @@ -26,21 +27,22 @@ G_COMPILER_CLIENT = None G_TIMING_CACHE_FILE = cfg.timing_cache_file_path +_COMPILER_LOCK = threading.Lock() # Avoid instantiating the compiler more than once. 
def _get_compiler_objects() -> Tuple[ir.Context, compiler.CompilerClient]: global G_COMPILER_CLIENT, G_TIMING_CACHE_FILE - if G_TIMING_CACHE_FILE != cfg.timing_cache_file_path: - # Reinitialize the compiler if the timing cache file path has changed. - global G_COMPILER_CLIENT - G_COMPILER_CLIENT = None - G_TIMING_CACHE_FILE = cfg.timing_cache_file_path + with _COMPILER_LOCK: + if G_TIMING_CACHE_FILE != cfg.timing_cache_file_path: + # Reinitialize the compiler if the timing cache file path has changed. + G_COMPILER_CLIENT = None + G_TIMING_CACHE_FILE = cfg.timing_cache_file_path - ctx = make_ir_context() - if G_COMPILER_CLIENT is None: - G_COMPILER_CLIENT = compiler.CompilerClient(ctx) - return ctx, G_COMPILER_CLIENT + ctx = make_ir_context() + if G_COMPILER_CLIENT is None: + G_COMPILER_CLIENT = compiler.CompilerClient(ctx) + return ctx, G_COMPILER_CLIENT class Compiler: diff --git a/tripy/nvtripy/backend/mlir/utils.py b/tripy/nvtripy/backend/mlir/utils.py index d2d445753..f967e1acf 100644 --- a/tripy/nvtripy/backend/mlir/utils.py +++ b/tripy/nvtripy/backend/mlir/utils.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -259,14 +259,16 @@ def starts_with_any(line, *starts): # file descriptor to something we can intercept. `contextlib.redirect_stderr` does not do this. 
@contextlib.contextmanager def redirect_stderr() -> BinaryIO: + f = tempfile.NamedTemporaryFile() try: - f = tempfile.NamedTemporaryFile() sys.stderr.flush() original_stderr = os.dup(2) new_stderr = os.dup(2) - os.dup2(os.open(f.name, os.O_WRONLY | os.O_TRUNC | os.O_CREAT), 2) + temp_fd = os.open(f.name, os.O_WRONLY | os.O_TRUNC | os.O_CREAT) + os.dup2(temp_fd, 2) + os.close(temp_fd) sys.stderr = os.fdopen(new_stderr, "w") yield f diff --git a/tripy/nvtripy/common/exception.py b/tripy/nvtripy/common/exception.py index 33799d742..e6cd06265 100644 --- a/tripy/nvtripy/common/exception.py +++ b/tripy/nvtripy/common/exception.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -66,7 +66,7 @@ def apply_color(inp, color): # won't include it in that case. try: candidate_column_offsets = utils.ast.get_candidate_column_offsets(source_info, callee_info) - except: + except Exception: pass else: if len(candidate_column_offsets) == 1: @@ -120,7 +120,7 @@ def should_exclude(source_info): return None -def raise_error(summary: str, details: List[Any] = []): +def raise_error(summary: str, details: Optional[List[Any]] = None): """ Raises a Tripy exception with a formatted message. @@ -142,6 +142,7 @@ def raise_error(summary: str, details: List[Any] = []): """ pre_summary = "" + details = utils.utils.default(details, []) stack_info = utils.stack_info.get_stack_info() user_frame_index = stack_info.get_first_user_frame_index() if user_frame_index is not None: @@ -191,24 +192,7 @@ def search_for_missing_attr(module_name: str, name: str, look_in: List[Tuple[Any # then call `search_for_missing_attr` ad infinitum. 
stack = inspect.stack() - stack_modules = [] - stack_classes = [] - for frame in stack: - module = inspect.getmodule(frame.frame) - if module: - stack_modules.append(module) - - self_arg = frame.frame.f_locals.get("self") - if self_arg is not None: - try: - class_type = self_arg.__class__ - except: - pass - else: - stack_classes.append(class_type) - stack_modules = list(filter(lambda mod: mod is not None, [inspect.getmodule(frame.frame) for frame in stack])) - stack_classes = list([]) msg = f"Module: '{module_name}' does not have attribute: '{name}'" # If a symbol isn't found in the top-level, we'll look at specific classes/modules @@ -216,13 +200,18 @@ def search_for_missing_attr(module_name: str, name: str, look_in: List[Tuple[Any # We provide the names as well since the object name will be the fully qualified name, # which is not necessarily what the user uses. + # Unsupported dtypes mapped to their closest supported alternatives. + _DTYPE_SUGGESTIONS = { + "float64": "float32", + "int16": "int32", + } + for obj, obj_name in look_in: # Avoid infinite recursion - see comment above. - if obj in stack_modules + stack_classes: - if name == "float64": - msg += f". Did you mean: 'float32'?" - if name == "int16": - msg += f". Did you mean: 'int32'?" + if obj in stack_modules: + suggestion = _DTYPE_SUGGESTIONS.get(name) + if suggestion: + msg += f". Did you mean: '{suggestion}'?" 
continue if hasattr(obj, name): diff --git a/tripy/nvtripy/frontend/tensor.py b/tripy/nvtripy/frontend/tensor.py index bbc117ff1..b273b8491 100644 --- a/tripy/nvtripy/frontend/tensor.py +++ b/tripy/nvtripy/frontend/tensor.py @@ -90,7 +90,8 @@ def __init__( tensor = tp.Tensor([1.0, 2.0, 3.0]) """ # We use None internally but users should not be permitted to do it - assert data is not None, "Data argument to Tensor must not be None" + if data is None: + raise_error("Data argument to Tensor must not be None.") if isinstance(data, Tensor): raise_error( "Cannot initialize Tensor with another Tensor.", [f"Note: `data` argument was defined here:", data] diff --git a/tripy/nvtripy/trace/ops/base.py b/tripy/nvtripy/trace/ops/base.py index d527cbc30..774762c0b 100644 --- a/tripy/nvtripy/trace/ops/base.py +++ b/tripy/nvtripy/trace/ops/base.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -19,19 +19,18 @@ from dataclasses import dataclass, field from typing import List, Set +import itertools + from nvtripy import utils from nvtripy.common.device import device +from nvtripy.common.exception import raise_error from nvtripy.trace.tensor import TraceTensor -_COUNT = 0 +_COUNT = itertools.count() def _get_unique_name(): - global _COUNT - - name = f"%t{_COUNT}" - _COUNT += 1 - return name + return f"%t{next(_COUNT)}" @dataclass(repr=False) @@ -77,13 +76,17 @@ def infer_dtypes(self): """ Infers dtypes for the operation and updates output tensor dtypes accordingly. """ - assert self.inputs, "Default implementation cannot handle cases where there are no inputs. Please override." - assert ( - len(self.outputs) == 1 - ), f"Default implementation expects exactly one output, but got {len(self.outputs)}. 
Please override." - assert all( - inp.dtype == self.inputs[0].dtype for inp in self.inputs - ), f"Default implementation cannot handle cases where inputs have different dtypes, but got {[inp.dtype for inp in self.inputs]}. Please override." + if not self.inputs: + raise_error("Default implementation cannot handle cases where there are no inputs. Please override.") + if len(self.outputs) != 1: + raise_error( + f"Default implementation expects exactly one output, but got {len(self.outputs)}. Please override." + ) + if not all(inp.dtype == self.inputs[0].dtype for inp in self.inputs): + raise_error( + f"Default implementation cannot handle cases where inputs have different dtypes, " + f"but got {[inp.dtype for inp in self.inputs]}. Please override." + ) self.outputs[0].dtype = self.inputs[0].dtype diff --git a/tripy/nvtripy/utils/ast.py b/tripy/nvtripy/utils/ast.py index 79830b5a5..7781f80f3 100644 --- a/tripy/nvtripy/utils/ast.py +++ b/tripy/nvtripy/utils/ast.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -179,7 +179,7 @@ def get_candidate_column_offsets(cur_frame: SourceInfo, callee: SourceInfo) -> L try: ast_node_name = get_ast_node_func_name(node) - except: + except Exception: continue if ast_node_name is None: diff --git a/tripy/nvtripy/utils/result.py b/tripy/nvtripy/utils/result.py index 89d2efae5..ad681765c 100644 --- a/tripy/nvtripy/utils/result.py +++ b/tripy/nvtripy/utils/result.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -44,9 +44,11 @@ def __bool__(self) -> bool: def __getattribute__(self, name: str) -> Any: if name == "value": - assert self.is_ok, "Cannot retrieve value of an error result" + if not super().__getattribute__("is_ok"): + raise RuntimeError("Cannot retrieve value of an error result") if name == "error_details": - assert not self.is_ok, "Cannot retrieve error details of an ok result" + if super().__getattribute__("is_ok"): + raise RuntimeError("Cannot retrieve error details of an ok result") return super().__getattribute__(name) diff --git a/tripy/nvtripy/utils/utils.py b/tripy/nvtripy/utils/utils.py index 26190b1da..0b9c5a209 100644 --- a/tripy/nvtripy/utils/utils.py +++ b/tripy/nvtripy/utils/utils.py @@ -21,6 +21,7 @@ import inspect import math import os +import threading import time import typing from typing import Any, List, Optional, Sequence, Tuple, Union @@ -55,12 +56,14 @@ def call_once(func): Decorator that makes it so that the decorated function can only be called once. Any subsequent calls will do nothing. 
""" + lock = threading.Lock() @functools.wraps(func) def wrapper(*args, **kwargs): - if wrapper.never_run: - wrapper.never_run = False - return func(*args, **kwargs) + with lock: + if wrapper.never_run: + wrapper.never_run = False + return func(*args, **kwargs) wrapper.never_run = True return wrapper @@ -278,7 +281,7 @@ def save_file( os.fsync(dest.fileno()) try: content_bytes = len(contents.encode()) - except: + except Exception: pass else: if bytes_written != content_bytes: diff --git a/tripy/pyproject.toml b/tripy/pyproject.toml index f417d2685..0b6ae8547 100644 --- a/tripy/pyproject.toml +++ b/tripy/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "nvtripy" -version = "0.1.6" +version = "0.1.7" authors = [{ name = "NVIDIA", email = "svc_tensorrt@nvidia.com" }] description = "Tripy: A Python Programming Model For TensorRT" readme = "README.md" diff --git a/tripy/tests/backend/mlir/test_compiler.py b/tripy/tests/backend/mlir/test_compiler.py index 33cb44749..ed749d7b2 100644 --- a/tripy/tests/backend/mlir/test_compiler.py +++ b/tripy/tests/backend/mlir/test_compiler.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -21,6 +21,22 @@ import nvtripy as tp +class TestCompilerClient: + def test_concurrent_compiler_init(self): + import concurrent.futures + from nvtripy.backend.mlir.compiler import _get_compiler_objects + + results = [] + with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor: + futures = [executor.submit(_get_compiler_objects) for _ in range(4)] + for f in concurrent.futures.as_completed(futures): + results.append(f.result()) + + # All threads should get back the same compiler client instance. 
+ clients = [client for _, client in results] + assert all(c is clients[0] for c in clients), "All threads should share the same CompilerClient" + + # Tests to ensure that we're able to map errors from MLIR-TRT back to the Python code cleanly. class TestErrorMapping: def test_invalid_slice(self): diff --git a/tripy/tests/trace/ops/test_general.py b/tripy/tests/trace/ops/test_general.py index 6268a2fc1..dbe4b085d 100644 --- a/tripy/tests/trace/ops/test_general.py +++ b/tripy/tests/trace/ops/test_general.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -47,3 +47,15 @@ def test_has_no_dataclass_repr(self, OpType): assert ( OpType.__repr__ is TraceOp.__repr__ ), "Use @dataclass(repr=False) to avoid extremely verbose __repr__ implementations" + + +class TestGetUniqueName: + def test_concurrent_unique_names(self): + import concurrent.futures + from nvtripy.trace.ops.base import _get_unique_name + + with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor: + futures = [executor.submit(_get_unique_name) for _ in range(100)] + names = [f.result() for f in futures] + + assert len(names) == len(set(names)), "All generated names should be unique" diff --git a/tripy/tests/utils/test_result.py b/tripy/tests/utils/test_result.py index 0e647e513..b783e1a5f 100644 --- a/tripy/tests/utils/test_result.py +++ b/tripy/tests/utils/test_result.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -22,10 +22,10 @@ class TestResult: def test_cannot_retrieve_value_of_error(self): result: Result[int] = Result.err(["error!"]) - with helper.raises(AssertionError): + with helper.raises(RuntimeError): result.value def test_cannot_retrieve_error_details_of_ok(self): result: Result[int] = Result.ok(0) - with helper.raises(AssertionError): + with helper.raises(RuntimeError): result.error_details diff --git a/tripy/tests/utils/test_utils.py b/tripy/tests/utils/test_utils.py index 298e4f166..a1adf4ae5 100644 --- a/tripy/tests/utils/test_utils.py +++ b/tripy/tests/utils/test_utils.py @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -141,3 +141,34 @@ def func(a, *args, b=10): assert all_args == [("a", 1)] assert omitted_args == [("b", 10)] assert var_arg_info is None + + +class TestCallOnce: + def test_runs_only_once(self): + counter = 0 + + @utils.utils.call_once + def increment(): + nonlocal counter + counter += 1 + + increment() + increment() + increment() + assert counter == 1 + + def test_concurrent_calls_run_only_once(self): + import concurrent.futures + + counter = 0 + + @utils.utils.call_once + def increment(): + nonlocal counter + counter += 1 + + with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor: + futures = [executor.submit(increment) for _ in range(8)] + concurrent.futures.wait(futures) + + assert counter == 1