24 changes: 24 additions & 0 deletions tritonbench/operators/fp32_to_mx4/operator.py
@@ -2,7 +2,9 @@
from typing import Callable, Generator, List, Optional, Tuple

import torch
from tritonbench.utils.jagged_utils import GIGABYTES_PER_BYTE
from tritonbench.utils.python_utils import try_import
from tritonbench.utils.triton_op import BenchmarkOperatorMetrics, register_metric

# We are benchmarking the kernel used inside quantize_comm. As such, we use the fp32_to_mx4 fbgemm API rather than the quantize_mx API.
with try_import("HAS_FBGEMM"):
@@ -16,6 +18,8 @@


class Operator(BenchmarkOperator):
    is_compute_bound: bool = False

    def __init__(
        self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
    ):
@@ -47,3 +51,23 @@ def get_x_val(self, example_inputs) -> Tuple[int, int, int, int, RoundingMode, i
            rounding_mode,
            stochastic_casting,
        )

    @register_metric()
    def gbps(
        self,
        fn,
        example_inputs: Tuple[torch.Tensor, int, int, int, RoundingMode, bool],
        metrics: BenchmarkOperatorMetrics,
    ) -> float:
        # fp32_to_mx4: a[M] -> out[M / 2 + M / group_size] (int8)
        return (
            (
                # read: M fp32 input elements
                example_inputs[0].element_size() * example_inputs[0].numel()
                # write: two 4-bit values packed per int8 byte
                + (example_inputs[0].numel() + 1) // 2
                # write: one shared-exponent byte per group
                + (example_inputs[0].numel() + example_inputs[1] - 1)
                // example_inputs[1]
            )
            / metrics.latency
            * 1e3
            * GIGABYTES_PER_BYTE
        )
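For reference, a minimal standalone sketch of the byte accounting this metric performs. The helper name and example numbers below are illustrative, not part of the PR; the sketch assumes metrics.latency is reported in milliseconds and GIGABYTES_PER_BYTE equals 1e-9, consistent with the * 1e3 conversion above.

# Sketch only: mirrors the gbps() arithmetic above under the stated assumptions.
def fp32_to_mx4_gbps(num_elements: int, group_size: int, latency_ms: float) -> float:
    input_bytes = 4 * num_elements              # read: M fp32 elements
    packed_bytes = (num_elements + 1) // 2      # write: two 4-bit values per byte
    exponent_bytes = (num_elements + group_size - 1) // group_size  # write: one byte per group
    total_bytes = input_bytes + packed_bytes + exponent_bytes
    return total_bytes / latency_ms * 1e3 * 1e-9  # bytes/ms -> bytes/s -> GB/s

# Example: M = 1_048_576, group_size = 32, latency = 0.05 ms
# -> 4_194_304 + 524_288 + 32_768 = 4_751_360 bytes moved
print(fp32_to_mx4_gbps(1_048_576, 32, 0.05))  # ~95.03 GB/s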
26 changes: 26 additions & 0 deletions tritonbench/operators/mx4_to_fp32/operator.py
@@ -2,13 +2,16 @@
from typing import Callable, Generator, List, Optional, Tuple

import torch
from tritonbench.utils.jagged_utils import GIGABYTES_PER_BYTE

# We are benchmarking the kernel used inside quantize_comm. As such, we use the mx4_to_fp32 fbgemm API rather than the quantize_mx API.

from tritonbench.utils.python_utils import try_import
from tritonbench.utils.triton_op import (
    BenchmarkOperator,
    BenchmarkOperatorMetrics,
    register_benchmark,
    register_metric,
    register_x_val,
)

@@ -17,6 +20,8 @@


class Operator(BenchmarkOperator):
    is_compute_bound: bool = False

    def __init__(
        self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
    ):
@@ -53,3 +58,24 @@ def fbgemm_mx4_to_fp32(
    def get_x_val(self, example_inputs) -> Tuple[int, int, int, int]:
        input_tensor, group_size, ebits, mbits = example_inputs
        return (input_tensor.numel(), group_size, ebits, mbits)

    @register_metric()
    def gbps(
        self,
        fn,
        example_inputs: Tuple[torch.Tensor, int, int, int],
        metrics: BenchmarkOperatorMetrics,
    ) -> float:
        # mx4_to_fp32: a[M / 2 + M / group_size] (int8) -> out[M]
        # each packed group holds group_size / 2 payload bytes plus one
        # shared-exponent byte
        packed_group_size = example_inputs[1] // 2 + 1
        num_groups = example_inputs[0].numel() // packed_group_size
        out_size = num_groups * example_inputs[1]
        return (
            (
                # read: packed int8 input; write: M fp32 output elements
                example_inputs[0].element_size() * example_inputs[0].numel()
                + out_size * 4
            )
            / metrics.latency
            * 1e3
            * GIGABYTES_PER_BYTE
        )
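And a sketch of the inverse direction, under the same assumptions as before (illustrative helper name, latency in milliseconds, GIGABYTES_PER_BYTE == 1e-9):

# Sketch only: mirrors the mx4_to_fp32 gbps() arithmetic above.
def mx4_to_fp32_gbps(packed_numel: int, group_size: int, latency_ms: float) -> float:
    packed_group_size = group_size // 2 + 1   # payload bytes + shared-exponent byte
    num_groups = packed_numel // packed_group_size
    out_elements = num_groups * group_size    # M fp32 outputs
    total_bytes = packed_numel + out_elements * 4  # int8 read + fp32 write
    return total_bytes / latency_ms * 1e3 * 1e-9   # bytes/ms -> bytes/s -> GB/s

# Example: 557_056 packed bytes with group_size = 32 -> 32_768 groups and
# 1_048_576 fp32 outputs; at 0.05 ms the same ~4.75 MB moves as in the
# quantize direction above.
print(mx4_to_fp32_gbps(557_056, 32, 0.05))  # ~95.03 GB/s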