2 changes: 1 addition & 1 deletion .ci/helion/install.sh
@@ -37,4 +37,4 @@ cd ${tritonbench_dir}

python install.py --helion
# Helion requires tritonbench installed as a library
pip install -e .
pip install -e .
2 changes: 1 addition & 1 deletion benchmarks/gen_metadata/run.py
@@ -87,7 +87,7 @@ def run(args: argparse.Namespace):
DTYPE_OPERATORS[op] = op_bench.DEFAULT_PRECISION
if baseline := op_bench.has_baseline():
BASELINE_OPERATORS[op] = baseline
if op_bench.has_metric("tflops") and not op in TFLOPS_SKIP_OPERATORS:
if op_bench.has_metric("tflops") and op not in TFLOPS_SKIP_OPERATORS:
TFLOPS_OPERATORS.append(op)
if op_bench.has_bwd():
BACKWARD_OPERATORS.append(op)
10 changes: 2 additions & 8 deletions benchmarks/mojo_matmul/run.py
@@ -4,14 +4,10 @@
pip install --pre modular --index-url https://dl.modular.com/public/nightly/python/simple/
"""

import argparse
import json
import logging
import os
import sys

from os.path import abspath, exists
from typing import Dict, List


def setup_tritonbench_cwd():
@@ -36,15 +32,13 @@ def setup_tritonbench_cwd()
from typing import Callable

import max.graph as mg
import torch

from max import driver, engine
from max.graph import DeviceRef, Graph, ops, TensorType, TensorValue
from max.graph.type import DType, Shape, ShapeLike
from max.graph import DeviceRef, ops, TensorType
from max.graph.type import DType

from tritonbench.operators import load_opbench_by_name
from tritonbench.utils.parser import get_parser
from tritonbench.utils.triton_op import register_benchmark


def promote_mojo_tensor_to_fp32(mojo_tensor, dtype):
2 changes: 1 addition & 1 deletion benchmarks/nightly/gen.py
@@ -59,7 +59,7 @@ def gen_run(operators: List[str], bwd: bool = False) -> Dict[str, Any]:
cmd.append("--bwd")
# add backends
run_backends = list(TRITON_OPS[op].keys())
if _has_meaningful_baseline(op) and not BASELINE_OPS[op] in run_backends:
if _has_meaningful_baseline(op) and BASELINE_OPS[op] not in run_backends:
run_backends.append(BASELINE_OPS[op])
cmd.extend(["--only", ",".join(run_backends)])
out[run_name] = {}
2 changes: 1 addition & 1 deletion benchmarks/nightly/run.py
@@ -107,7 +107,7 @@ def run():
logger.info(f"[nightly] logging result json file to {result_json_file}.")
if args.log_scuba:
log_benchmark(aggregated_obj)
logger.info(f"[nightly] logging results to scuba.")
logger.info("[nightly] logging results to scuba.")


if __name__ == "__main__":
1 change: 0 additions & 1 deletion benchmarks/power_analysis/run.py
@@ -7,7 +7,6 @@
import os
import sys

import torch

CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))

4 changes: 4 additions & 0 deletions pyproject.toml
@@ -15,6 +15,10 @@ dependencies = [
[tool.setuptools.packages.find]
include = ["tritonbench*"]

[tool.ruff]
fix = true
exclude = ["submodules"]

[tool.ufmt]
formatter = "ruff-api"
sorter = "usort"
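A minimal usage sketch for the new tooling config (an assumed local workflow, not part of this change; it presumes ruff and ufmt, with the ruff-api backend, are installed):

ruff check .   # fix = true applies autofixes (e.g. removing unused imports); the "submodules" directory is excluded
ufmt format .  # formats with the ruff-api formatter and sorts imports with usort per [tool.ufmt]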
1 change: 0 additions & 1 deletion tools/flash_attn/install.py
@@ -1,6 +1,5 @@
import os
import subprocess
import sys

from pathlib import Path

3 changes: 0 additions & 3 deletions tritonbench/components/power/charts.py
@@ -1,9 +1,6 @@
import csv
import logging
import os
import signal
import subprocess
import time

import matplotlib.pyplot as plt

3 changes: 0 additions & 3 deletions tritonbench/components/power/power_manager.py
@@ -1,5 +1,4 @@
import csv
import dataclasses
import os
import threading
import time
@@ -14,12 +13,10 @@
NVML_CLOCK_SM,
NVML_FI_DEV_POWER_CURRENT_LIMIT,
NVML_FI_DEV_POWER_INSTANT,
NVML_SUCCESS,
NVML_TEMPERATURE_GPU,
nvmlDeviceGetClock,
nvmlDeviceGetFieldValues,
nvmlDeviceGetHandleByIndex,
nvmlDeviceGetPerformanceState,
nvmlDeviceGetTemperature,
nvmlInit,
nvmlShutdown,
3 changes: 0 additions & 3 deletions tritonbench/components/tasks/manager.py
@@ -35,11 +35,8 @@ def make_instance(
class_name: str,
) -> None:
import importlib
import os
import traceback

# required as this is in child process
from tritonbench.components.power.power_manager import PowerManager

module = importlib.import_module(module_path, package=package)
Ctor = getattr(module, class_name)
3 changes: 0 additions & 3 deletions tritonbench/kernels/blackwell_attention_utils.py
@@ -3,9 +3,6 @@
generic attention kernels.
"""

import os
from functools import lru_cache

import torch
import triton

1 change: 0 additions & 1 deletion tritonbench/kernels/blackwell_triton_fused_attention.py
@@ -24,7 +24,6 @@
is_blackwell,
is_cuda,
is_hip,
is_hopper,
supports_host_descriptor,
)

2 changes: 0 additions & 2 deletions tritonbench/kernels/gluon_attention_forward.py
@@ -1,5 +1,3 @@
import itertools

import torch
import triton
import triton.language as tl
1 change: 0 additions & 1 deletion tritonbench/kernels/gluon_attention_persistent_forward.py
@@ -1,5 +1,4 @@
import copy
import itertools

import torch
import triton
2 changes: 0 additions & 2 deletions tritonbench/operators/addmm/hstu.py
@@ -1,5 +1,3 @@
import importlib

from typing import Tuple

import torch
1 change: 0 additions & 1 deletion tritonbench/operators/addmm/operator.py
@@ -16,7 +16,6 @@
except ImportError:
streamk_matmul = None

from tritonbench.operators.gemm import stream_k
from tritonbench.utils.triton_op import (
BenchmarkOperator,
BenchmarkOperatorMetrics,
2 changes: 0 additions & 2 deletions tritonbench/operators/bf16xint16_gemm/bf16xint16_gemm.py
@@ -10,13 +10,11 @@
"""

import argparse
import statistics

from typing import Any, List, Optional

import torch
import triton
import triton.language as tl

from tritonbench.utils.triton_op import (
BenchmarkOperator,
2 changes: 1 addition & 1 deletion tritonbench/operators/blackwell_attentions/operator.py
@@ -187,7 +187,7 @@ def _is_sdpa_cudnn_attention_available():
try:
_sdpa_cudnn_attention(q, k, v)
return True
except RuntimeError as e:
except RuntimeError:
return False


3 changes: 1 addition & 2 deletions tritonbench/operators/flex_attention/operator.py
@@ -26,7 +26,6 @@
except ImportError:
pass

from tritonbench.utils.input import input_filter
from tritonbench.utils.triton_op import (
BenchmarkOperator,
BenchmarkOperatorMetrics,
@@ -379,7 +378,7 @@ def flash_v3(
)
elif mod_type == "document_mask":
# Document mask requires special handling with varlen function
print(f"[SKIP] Flash Attention v3 document_mask not implemented yet")
print("[SKIP] Flash Attention v3 document_mask not implemented yet")
raise NotImplementedError(
"Flash Attention v3 document_mask not implemented yet"
)
2 changes: 1 addition & 1 deletion tritonbench/operators/fp8_attention/operator.py
@@ -6,7 +6,7 @@
import argparse
import math

from typing import Any, Callable, Generator, List, Optional, Tuple
from typing import Any, Callable, Generator, List, Optional

import torch

2 changes: 0 additions & 2 deletions tritonbench/operators/fp8_gemm/persistent.py
@@ -1,5 +1,3 @@
from functools import lru_cache

from typing import Optional

import torch
4 changes: 0 additions & 4 deletions tritonbench/operators/fp8_gemm_blockwise/operator.py
@@ -42,14 +42,10 @@ def parse_args(args: List[str]) -> argparse.Namespace:
HAS_CUTLASS = False
if is_cuda():
try:
import fbgemm_gpu.experimental.gen_ai

cutlass_fp8_block = torch.ops.llama_cpp.fp8_blockwise_matmul
HAS_CUTLASS = True
except:
try:
import fbgemm_gpu.experimental.gen_ai

cutlass_fp8_block = torch.ops.fbgemm.f8f8bf16_blockwise
HAS_CUTLASS = True
except:
3 changes: 0 additions & 3 deletions tritonbench/operators/fp8_gemm_rowwise_grouped/operator.py
@@ -48,14 +48,11 @@

# Import necessary libraries and modules
import argparse
import random
from typing import Any, Callable, Generator, List, Optional, Tuple

import torch
import triton

from tritonbench.utils.data_utils import get_production_shapes

from tritonbench.utils.triton_op import (
BenchmarkOperator,
BenchmarkOperatorMetrics,
1 change: 0 additions & 1 deletion tritonbench/operators/gdpa/gdpa.py
@@ -19,7 +19,6 @@
from typing import Tuple

import torch
import torch.nn.functional as F

import triton # @manual=//triton:triton
import triton.language as tl # @manual=//triton:triton
2 changes: 1 addition & 1 deletion tritonbench/operators/gdpa/gdpa_blackwell_tlx.py
@@ -10,7 +10,7 @@
from triton.tools.tensor_descriptor import TensorDescriptor

from .gdpa_utils import get_num_sms
from .math import activation_string_to_int, fast_gelu_grad, gelu, gelu_grad
from .math import activation_string_to_int, gelu, gelu_grad


def _host_descriptor_pre_hook(nargs):
2 changes: 1 addition & 1 deletion tritonbench/operators/gdpa/gdpa_utils.py
@@ -3,7 +3,7 @@
# pyre-strict
import math
from functools import lru_cache
from typing import Any, List, Optional
from typing import Any, Optional

import torch
import triton # @manual=//triton:triton
2 changes: 1 addition & 1 deletion tritonbench/operators/gdpa/operator.py
@@ -133,7 +133,7 @@ def parse_args(args):
"--kv_len",
default=None,
type=int,
help=f"Sequence length for K/V, if None, the tensor will be jagged and have the same length as Q",
help="Sequence length for K/V, if None, the tensor will be jagged and have the same length as Q",
)
parser.add_argument(
"--activation",
2 changes: 1 addition & 1 deletion tritonbench/operators/grouped_gemm/cutedsl/kernels.py
@@ -1996,7 +1996,7 @@ def compile_cutedsl_grouped_gemm(
C_cpu = torch.zeros((m, n, 1), dtype=torch.float32)
torch_fp32_tensors_abc_seed.append([A_cpu, B_cpu, C_cpu])

print(f"Running Blackwell Grouped GEMM test with:")
print("Running Blackwell Grouped GEMM test with:")
print(f"{num_groups} groups")
for i, (m, n, k, l) in enumerate(problem_sizes_mnkl):
print(f"Group {i}: {m}x{n}x{k}x{l}")
4 changes: 1 addition & 3 deletions tritonbench/operators/int4_gemm/int4_gemm.py
@@ -6,13 +6,11 @@
"""

import argparse
import statistics

from typing import Any, List, Optional

import torch
import triton
import triton.language as tl

from tritonbench.utils.triton_op import (
BenchmarkOperator,
@@ -21,7 +19,7 @@
register_metric,
)

from .kernel import _group_quantize_tensor, matmul, matmul_kernel, pack_2xint4
from .kernel import _group_quantize_tensor, matmul, pack_2xint4


class Operator(BenchmarkOperator):
3 changes: 0 additions & 3 deletions tritonbench/operators/jagged_layer_norm/operator.py
@@ -1,8 +1,5 @@
import argparse
import itertools
import math
import os
import random
from typing import Callable, Generator, List, Optional, Tuple

import torch
3 changes: 0 additions & 3 deletions tritonbench/operators/jagged_softmax/operator.py
@@ -1,8 +1,5 @@
import argparse
import itertools
import math
import os
import random
from typing import Callable, Generator, List, Optional, Tuple

import torch
9 changes: 1 addition & 8 deletions tritonbench/operators/launch_latency/operator.py
@@ -1,16 +1,9 @@
import triton.language as tl
from torch import zeros
from torch._C import _cuda_getCurrentRawStream as get_raw_stream

from torch._inductor.utils import triton_version_uses_attrs_dict
from triton.compiler import CompiledKernel

from tritonbench.utils.triton_op import (
BenchmarkOperator,
BenchmarkOperatorMetrics,
register_benchmark,
register_metric,
)
from tritonbench.utils.triton_op import BenchmarkOperator, register_benchmark

from .kernels import get_trivial_add_kernel, nop_kernel, nop_with_args_kernel

3 changes: 0 additions & 3 deletions tritonbench/operators/low_mem_dropout/kernels.py
@@ -1,6 +1,3 @@
import tabulate
import torch

import triton
import triton.language as tl

7 changes: 1 addition & 6 deletions tritonbench/operators/mamba2_chunk_scan/operator.py
@@ -1,22 +1,17 @@
import argparse
import functools
import itertools
import os
import sys
from contextlib import nullcontext
from itertools import chain

from typing import Any, Callable, Generator, List, Optional
from typing import Any, Generator, List, Optional

import torch

from tritonbench.utils.input import input_filter
from tritonbench.utils.python_utils import try_import

from tritonbench.utils.triton_op import (
BenchmarkOperator,
BenchmarkOperatorMetrics,
Mode as BenchmarkMode,
register_benchmark,
register_metric,
register_x_val,