diff --git a/tritonbench/operators/fp32_to_mx4/operator.py b/tritonbench/operators/fp32_to_mx4/operator.py
index e4a8ab8b4..06f049074 100644
--- a/tritonbench/operators/fp32_to_mx4/operator.py
+++ b/tritonbench/operators/fp32_to_mx4/operator.py
@@ -2,7 +2,9 @@
 from typing import Callable, Generator, List, Optional, Tuple
 
 import torch
+from tritonbench.utils.jagged_utils import GIGABYTES_PER_BYTE
 from tritonbench.utils.python_utils import try_import
+from tritonbench.utils.triton_op import BenchmarkOperatorMetrics, register_metric
 
 # We are benchmarking the kernel used inside quantize_comm. Insofar, we are using the fp32_to_mx4 fbgemm API rather than the quantize_mx API.
 with try_import("HAS_FBGEMM"):
@@ -16,6 +18,8 @@
 
 
 class Operator(BenchmarkOperator):
+    is_compute_bound: bool = False
+
     def __init__(
         self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
     ):
@@ -47,3 +51,23 @@ def get_x_val(self, example_inputs) -> Tuple[int, int, int, int, RoundingMode, i
             rounding_mode,
             stochastic_casting,
         )
+
+    @register_metric()
+    def gbps(
+        self,
+        fn,
+        example_inputs: Tuple[torch.Tensor, int, int, int, RoundingMode, bool],
+        metrics: BenchmarkOperatorMetrics,
+    ) -> float:
+        # fp32_to_mx4: a[M] -> out[M / 2 + M / group_size] (int8)
+        return (
+            (
+                example_inputs[0].element_size() * example_inputs[0].numel()
+                + (example_inputs[0].numel() + 1) // 2
+                + (example_inputs[0].numel() + example_inputs[1] - 1)
+                // example_inputs[1]
+            )
+            / metrics.latency
+            * 1e3
+            * GIGABYTES_PER_BYTE
+        )
diff --git a/tritonbench/operators/mx4_to_fp32/operator.py b/tritonbench/operators/mx4_to_fp32/operator.py
index 73c1983c4..d463e8e1a 100644
--- a/tritonbench/operators/mx4_to_fp32/operator.py
+++ b/tritonbench/operators/mx4_to_fp32/operator.py
@@ -2,13 +2,16 @@
 from typing import Callable, Generator, List, Optional, Tuple
 
 import torch
+from tritonbench.utils.jagged_utils import GIGABYTES_PER_BYTE
 
 # We are benchmarking the kernel used inside quantize_comm. Insofar, we are using the fp32_to_mx4 fbgemm API rather than the quantize_mx API.
 from tritonbench.utils.python_utils import try_import
 
 from tritonbench.utils.triton_op import (
     BenchmarkOperator,
+    BenchmarkOperatorMetrics,
     register_benchmark,
+    register_metric,
     register_x_val,
 )
 
@@ -17,6 +20,8 @@
 
 
 class Operator(BenchmarkOperator):
+    is_compute_bound: bool = False
+
     def __init__(
         self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
     ):
@@ -53,3 +58,24 @@ def fbgemm_mx4_to_fp32(
     def get_x_val(self, example_inputs) -> Tuple[int, int, int, int]:
         input_tensor, group_size, ebits, mbits = example_inputs
         return (input_tensor.numel(), group_size, ebits, mbits)
+
+    @register_metric()
+    def gbps(
+        self,
+        fn,
+        example_inputs: Tuple[torch.Tensor, int, int, int],
+        metrics: BenchmarkOperatorMetrics,
+    ) -> float:
+        # mx4_to_fp32: a[M / 2 + M / group_size] (int8) -> out[M]
+        packed_group_size = example_inputs[1] // 2 + 1
+        num_groups = example_inputs[0].numel() // packed_group_size
+        out_size = num_groups * example_inputs[1]
+        return (
+            (
+                example_inputs[0].element_size() * example_inputs[0].numel()
+                + out_size * 4
+            )
+            / metrics.latency
+            * 1e3
+            * GIGABYTES_PER_BYTE
+        )
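
Note (reviewer sketch, not part of the patch): both gbps metrics count total bytes moved under fbgemm's MX4 layout, in which each group of group_size fp32 elements quantizes to group_size / 2 packed int8 mantissa bytes plus one shared-exponent byte per group. The hypothetical helpers below just re-derive that arithmetic; they assume metrics.latency is reported in milliseconds (hence the * 1e3) and that GIGABYTES_PER_BYTE is the bytes-to-gigabytes factor 1e-9.

    def fp32_to_mx4_bytes(numel: int, group_size: int) -> int:
        # Bytes read (fp32 input) plus bytes written (packed int8 output):
        # two 4-bit mantissas per byte, one shared-exponent byte per group.
        input_bytes = 4 * numel
        mantissa_bytes = (numel + 1) // 2
        exponent_bytes = (numel + group_size - 1) // group_size
        return input_bytes + mantissa_bytes + exponent_bytes


    def mx4_to_fp32_bytes(packed_numel: int, group_size: int) -> int:
        # Bytes read (packed int8 input) plus bytes written (fp32 output),
        # mirroring the packed_group_size / num_groups logic in the patch.
        packed_group_size = group_size // 2 + 1
        num_groups = packed_numel // packed_group_size
        out_numel = num_groups * group_size
        return packed_numel + 4 * out_numel


    # 1024 fp32 values at group_size=32 pack into 1024/2 + 1024/32 = 544 bytes,
    # so both kernels move 4 * 1024 + 544 = 4640 bytes in total.
    assert fp32_to_mx4_bytes(1024, 32) == 4640
    assert mx4_to_fp32_bytes(544, 32) == 4640

With those byte counts, gbps follows as bytes / metrics.latency * 1e3 * GIGABYTES_PER_BYTE, i.e. bytes per second scaled to gigabytes, which is what both new metrics compute.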