Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions tritonbench/operators/addmm/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,68 @@ def flops(
flops = (2 * m * k * n) + (m * n)
return flops

@register_metric()
def op_gflops(
    self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
) -> float:
    """Return the total work of one addmm call, in units of 1e9 FLOPs.

    This is an operation count, not a throughput: the result is not divided
    by any measured latency.

    Args:
        fn_name: Name of the benchmarked implementation (unused here).
        example_inputs: A 3-tuple ``(input, mat1, mat2)``; only the shapes
            of ``mat1`` (m, k) and ``mat2`` (k, n) are read.
        metrics: Collected metrics for this run (unused here).

    Returns:
        ``(2*m*k*n + m*n) / 1e9`` as a float.
    """
    _, lhs, rhs = example_inputs
    rows, inner = lhs.size()
    # The inner dimension is re-read from mat2, matching the original order.
    inner, cols = rhs.size()
    # 2*m*k*n for the matrix product; the m*n term is presumably the
    # element-wise addition of the input term.
    return ((2 * rows * inner * cols) + (rows * cols)) / 1e9

@register_metric()
def op_gbytes(
    self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
) -> float:
    """Return the total tensor I/O of one addmm call, in units of 1e9 bytes.

    This is a byte count, not a bandwidth: the result is not divided by any
    measured latency. The traffic is modelled as reading the three inputs
    once and writing the output once.

    Args:
        fn_name: Name of the benchmarked implementation (unused here).
        example_inputs: A 3-tuple ``(a, mat1, mat2)`` where ``mat1`` is
            (m, k) and ``mat2`` is (k, n).
        metrics: Collected metrics for this run (unused here).

    Returns:
        ``(a.numel() + mat1.numel() + mat2.numel() + m*n)
        * a.element_size() / 1e9`` as a float.
    """
    a, mat1, mat2 = example_inputs
    # addmm always produces an (m, n) result, so derive the output element
    # count from the operand shapes instead of executing the operator (and
    # allocating its output) just to call .numel() on the result.
    m, _ = mat1.size()
    _, n = mat2.size()
    numel = a.numel() + mat1.numel() + mat2.numel() + m * n
    # All tensors are sized with a's element size, i.e. they are assumed to
    # share a's dtype.
    # Convert bytes to gigabytes (divide by 1e9)
    gbytes = numel * a.element_size() / 1e9
    return gbytes

@register_metric()
def grid_size(
    self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
) -> float:
    """Return the number of (BLOCK_M x BLOCK_N) tiles covering the output.

    The tile sizes are read from ``metrics.best_config``. As a side effect,
    ``"best_config"`` is appended to ``self.required_metrics`` if it is not
    already listed.

    Args:
        fn_name: Name of the benchmarked implementation (unused here).
        example_inputs: A 3-tuple ``(input, mat1, mat2)``; only the shapes
            of ``mat1`` (m, k) and ``mat2`` (k, n) are read.
        metrics: Collected metrics; ``best_config`` is expected to expose
            ``.get("BLOCK_M")`` and ``.get("BLOCK_N")``.

    Returns:
        ``cdiv(m, BLOCK_M) * cdiv(n, BLOCK_N)`` as a float, or ``None`` when
        ``best_config`` is missing or lacks either block size (e.g. for
        baseline implementations).
    """
    _, lhs, rhs = example_inputs
    rows, _ = lhs.size()
    _, cols = rhs.size()

    # Make sure best_config is requested alongside this metric.
    wanted = self.required_metrics
    if "best_config" not in wanted:
        wanted.append("best_config")

    best = metrics.best_config
    if best is None:
        # Nothing was tuned, so there is no tile shape to derive a grid from.
        return None

    tile_m = best.get("BLOCK_M")
    tile_n = best.get("BLOCK_N")
    if not (tile_m is not None and tile_n is not None):
        # The config does not carry both tile dimensions.
        return None

    # Ceil-divide each output dimension by its tile size.
    tiles_along_m = triton.cdiv(rows, tile_m)
    tiles_along_n = triton.cdiv(cols, tile_n)
    return float(tiles_along_m * tiles_along_n)

@register_x_val(label="(M, N, K)")
def get_x_val(self, example_inputs) -> Tuple[int, int, int]:
# x-value: computation intensity
Expand Down