diff --git a/tritonbench/operators/addmm/operator.py b/tritonbench/operators/addmm/operator.py index b1f51a102..ca245bef1 100644 --- a/tritonbench/operators/addmm/operator.py +++ b/tritonbench/operators/addmm/operator.py @@ -157,6 +157,68 @@ def flops( flops = (2 * m * k * n) + (m * n) return flops + @register_metric() + def op_gflops( + self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics + ) -> float: + """Report the raw number of GFLOPS (not GFLOPS/sec) for the addmm operation.""" + _, mat1, mat2 = example_inputs + m, k = mat1.size() + k, n = mat2.size() + flops = (2 * m * k * n) + (m * n) + # Convert FLOPS to GFLOPS (divide by 10^9) + gflops = flops / 1e9 + return gflops + + @register_metric() + def op_gbytes( + self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics + ) -> float: + """Report the raw number of gigabytes of I/O (not GB/sec) for the addmm operation.""" + a, mat1, mat2 = example_inputs + numel = ( + a.numel() + + mat1.numel() + + mat2.numel() + + (torch.addmm(a, mat1, mat2).numel()) + ) + # Convert bytes to gigabytes (divide by 1e9) + gbytes = numel * a.element_size() / 1e9 + return gbytes + + @register_metric() + def grid_size( + self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics + ) -> float: + """Report the total grid size (number of thread blocks) for the addmm operation.""" + _, mat1, mat2 = example_inputs + m, k = mat1.size() + k, n = mat2.size() + + # Automatically ensure best_config is in required_metrics + if "best_config" not in self.required_metrics: + self.required_metrics.append("best_config") + + # Return None if best_config is not available (e.g., for baseline implementations) + if metrics.best_config is None: + return None + + # Extract actual block sizes from the best configuration + config = metrics.best_config + BLOCK_M = config.get("BLOCK_M") + BLOCK_N = config.get("BLOCK_N") + + # Return None if block sizes are not available in the config + if BLOCK_M is None or BLOCK_N is None: + return None + + # Calculate grid size using triton.cdiv for consistency with hstu.py + grid_m = triton.cdiv(m, BLOCK_M) + grid_n = triton.cdiv(n, BLOCK_N) + total_grid_size = grid_m * grid_n + + return float(total_grid_size) + @register_x_val(label="(M, N, K)") def get_x_val(self, example_inputs) -> Tuple[int, int, int]: # x-value: computation intensity