Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions sakana_kernels/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Reproducing Sakana's Result

We focus on invesigating and understanding the results of Sakana's kernels.

We have thoroughly examined 2 problems. There might be more, and we will continue to update.
* Level 1 Problem 15: `15_Matmul_for_lower_triangular_matrices`
* Level 2 Problem 23: `23_Conv3d_GroupNorm_Mean`

For each problem, we put the kernel code in a folder with the following structure:
* We have the original code from Sakana, which is `_sakana.cu`. This is pure CUDA code and then bind to the model `forward` function using `pybind11`.
* We have the code in the KernelBench format `ModelNew`, which is `_kernelbench.py`. This is a PyTorch module with custom inline CUDA kernel, which is the KernelBench task format.

### Note on Sakana's Eval System

⚠️ **To be clear** ⚠️: There are many differnces between Sakana's eval system and our eval system -- while our eval system is not completely robust, there are some important differences to discuss. Here is an example of the Sakana eval, provided by one of the [Sakana paper authors](https://x.com/RobertTLange/status/1892489402070220989). A huge difference is how they wrap their inline CUDA code -- we query the model to generate an entirely new model and forward function, while they choose to overwrite the forward function of a fixed model. These differences change the behavior of some of the caching hacks that the Sakana model was able to use (notably, the infamous Matmul for TriLower matrices that gets a 150x speedup fails the correctness checks on our eval). Furthermore, we use synchronization markers (CUDA events) in our eval to prevent hacky solutions from passing -- these are not the most robust ways to time kernels (which we want to address too) and may even add some extra unwanted overhead, but at the very least it mitigates some hacky solutions.

You can use `scripts/run_and_check.py` to evaluate **using the KernelBench Eval code**.

### Level 1 Problem 15: `15_Matmul_for_lower_triangular_matrices`

In this problem, it was discovered online that the runtime numbers were incorrect (see [this X thread](https://x.com/main_horse/status/1892446384910987718)). It turned
out that the model-generated kernel was doing nothing (effectively a no-op), and was caching results from the PyTorch reference outputs and using them as the
solution.

To use the KernelBench eval on this problem, you can run the following command:
```
python3 scripts/run_and_check.py ref_origin=kernelbench level=1 problem_id=15 kernel_src_path=sakana_kernels/level1_problem15/15_Matmul_for_lower_triangular_matrices_kernelbench.py
```

For this problem, the CUDA kernel is initialized with a 1D grid as follows:
```
const int threadsPerBlock = 256; // Increased thread count per block
const int numBlocks = N;
triangular_mm_kernel<<<numBlocks, threadsPerBlock>>>(
A.data_ptr<float>(),
B.data_ptr<float>(),
C.data_ptr<float>(),
N
);
```
However, in the actual kernel, we compute the row and column a thread computes as if we're using a 2D grid & block:
```
const int row = blockIdx.y * blockDim.y + threadIdx.y;
const int col = blockIdx.x * blockDim.x + threadIdx.x;
```
In this case: `blockIdx.y` will always be 0, `blockDim.y` will always be 1, and `threadIdx.y` will always be 0. So, the value of `row` will always be 0. So it
actually only computes values for the first row. Instead, the hypothesized reason why this kernel passes correctness checks is that it grabs values from
the same location of allocated memory (using `torch.empty_like`, similar to `malloc` as opposed to `torch.zeros_like` which writes over the values in memory) as the PyTorch reference kernel (which is run first). So this kernel actually is "cheating", but interestingly in the code there's no indication that the model is intentionally doing this. The fix to the first problem is by configuring a 2D grid/block instead:
```
dim3 block(16, 16);
dim3 grid((N + block.x - 1) / block.x, (N + block.y - 1) / block.y);
triangular_mm_kernel<<<grid, block>>>(
A.data_ptr<float>(),
B.data_ptr<float>(),
C.data_ptr<float>(),
N
)
```

To address the hacky "copying" problem, we need to fix the overall eval to address these issues. Notably, on the KernelBench eval this kernel does not pass the correctness checks (but still passes 4/5 tests!). The most obvious solution is calling `torch.cuda.empty_cache` between correctness runs to prevent grabbing any previous solutions. To keep results consistent between the eval and our paper, we choose to add this only for correctness tests to prevent these solutions from passing without influencing runtime numbers. For the future, we also plan to add more rigorous checking during benchmarking as well to prevent convoluted and hacky solutions. We also will call the model generated kernel first to prevent any kind of "stealing solutions"-esque approaches.

### Level 2 Problem 23: `23_Conv3d_GroupNorm_Mean`

In this problem, we have a batch (128) of 1536 elements that are group-normed (you can think of this as being mean 0, with low variance). It turns out
by a (rather hand-wavy) central limit theorem and a further division by the number of elements (~10^3) because we take a mean,
the distribution of each element in the tensor has mean 0 (by symmetry) and a very low variance,
allowing output tensors of all 0's to pass the tests under a small enough error of margin. The workaround to this in the future would
be to change either the kernel itself or the input distribution for the kernel inputs.

To use the KernelBench eval on this problem, you can run the following command:
```
python3 scripts/run_and_check.py ref_origin=kernelbench level=2 problem_id=23 kernel_src_path=sakana_kernels/level2_problem23/23_Conv3d_GroupNorm_Mean_kernelbench.py
```

On NVIDIA L40S, we see with our eval code.
```
========================================
[Eval] Kernel eval result: compiled=True correctness=True metadata={'hardware': 'NVIDIA L40S', 'device': 'cuda:0', 'correctness_trials': '(5 / 5)'} runtime=0.0327 runtime_stats={'mean': 0.0327, 'std': 0.00188, 'min': 0.0307, 'max': 0.0481, 'num_trials': 100, 'hardware': 'NVIDIA L40S', 'device': 'cuda:0'}
----------------------------------------
[Timing] PyTorch Reference Eager exec time: 1.26
[Timing] PyTorch Reference torch.compile time: 0.704 ms
[Timing] Custom Kernel exec time: 0.0327 ms
----------------------------------------
[Speedup] Speedup over eager: 38.53x
[Speedup] Speedup over torch.compile: 21.53x
========================================
```

```
So the actual kernel output

[Eval] Shape of output: tensor([-3.0275e-10, -9.0826e-10, 6.0551e-10, ..., 9.0826e-10,
-1.6651e-09, -9.0826e-10], device='cuda:0')
[Eval] Mean of output_new: 0.0000


The faulty kernel
[Eval] Shape of output_new: tensor([0., 0., 0., ..., 0., 0., 0.], device='cuda:0')
[Eval] Mean of output: -0.0000
```
Interestingly, the faulty kernel doesn't actually use the weights of the convolution at all, making it obvious that it is wrong -- it instead produces all 0 outputs. The actual outputs are all mean 0, std roughly 10^-9, which passes all the atol and rtol checks.


### Takeaways
We appreciate Sakana's effort in providing the kernel code and the evaluation system. This level of transparency help the community understand and reproduce, enabling future progress in this direction.

We will continue working on making the eval robust. To keep results consistent with our current arXiv, we only modify the correctness checks for robustness, but we plan on adding the following changes:
* Prevent cached solutions by clearing the cache (from the caching allocator).
* Drawing from `triton.testing.do_bench`, run more correctness tests and clear on-device caches between runs to prevent incorrect timing analysis.
* To prevent including kernels with easy solutions (e.g. all "0"'s), explicitly filter out benchmark problems with solutions that fall within some interval `[x-0.001,x+0.001]`. Thanks to folks at [METR](https://metr.org/blog/2025-02-14-measuring-automated-kernel-engineering/) for proposing.
* Avoid extra overhead during timing analysis -- i.e. be more intentional and explicit about synchronization instructions.








259 changes: 259 additions & 0 deletions sakana_kernels/Sakana_BenchMarking_(T4).ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "468fc3ApFEq5",
"outputId": "d47d95bd-b1da-49a2-8719-2c638ebfb9c2"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting ninja\n",
" Downloading ninja-1.11.1.3-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl.metadata (5.3 kB)\n",
"Downloading ninja-1.11.1.3-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (422 kB)\n",
"\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/422.9 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m422.9/422.9 kB\u001b[0m \u001b[31m24.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hInstalling collected packages: ninja\n",
"Successfully installed ninja-1.11.1.3\n"
]
}
],
"source": [
"!pip install ninja"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "2pW3WoZ3ILrf",
"outputId": "00584f1a-16ba-48d2-b47b-044b97f7f431"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Thu Feb 20 07:19:20 2025 \n",
"+-----------------------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 550.54.15 Driver Version: 550.54.15 CUDA Version: 12.4 |\n",
"|-----------------------------------------+------------------------+----------------------+\n",
"| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|=========================================+========================+======================|\n",
"| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |\n",
"| N/A 42C P8 9W / 70W | 0MiB / 15360MiB | 0% Default |\n",
"| | | N/A |\n",
"+-----------------------------------------+------------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=========================================================================================|\n",
"| No running processes found |\n",
"+-----------------------------------------------------------------------------------------+\n"
]
}
],
"source": [
"!nvidia-smi"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "MRsLNgskM9T7"
},
"outputs": [],
"source": [
"cu_code = '''\n",
"# include <torch/extension.h>\n",
"# include <cuda.h>\n",
"# include <cuda_runtime.h>\n",
"\n",
"__global__ void triangular_mm_kernel(const float* __restrict__ A,\n",
" const float* __restrict__ B,\n",
" float* __restrict__ C, const int N) {\n",
" // Use 2D block configuration for better occupancy\n",
" const int row = blockIdx.y * blockDim.y + threadIdx.y;\n",
" const int col = blockIdx.x * blockDim.x + threadIdx.x;\n",
"\n",
" if (row < N && col < N) {\n",
" if (col <= row) {\n",
" // Lower triangle computation\n",
" float sum = 0.0f;\n",
" // Process elements in chunks to improve cache utilization\n",
"# pragma unroll 8\n",
" for (int k = col; k <= row; k++) {\n",
" sum += A[row * N + k] * B[k * N + col];\n",
" }\n",
" C[row * N + col] = sum;\n",
" } else {\n",
" // Upper triangle (set to zero)\n",
" C[row * N + col] = 0.0f;\n",
" }\n",
" }\n",
"}\n",
"\n",
"at::Tensor forward(at::Tensor A, at::Tensor B) {\n",
" TORCH_CHECK(A.is_cuda(), \"A must be a CUDA tensor\");\n",
" TORCH_CHECK(B.is_cuda(), \"B must be a CUDA tensor\");\n",
" TORCH_CHECK(A.dim() == 2, \"A must be a 2D tensor\");\n",
" TORCH_CHECK(B.dim() == 2, \"B must be a 2D tensor\");\n",
" TORCH_CHECK(A.size(0) == A.size(1), \"A must be square\");\n",
" TORCH_CHECK(B.size(0) == B.size(1), \"B must be square\");\n",
" TORCH_CHECK(A.size(0) == B.size(0), \"A and B must be the same size\");\n",
"\n",
" int N = A.size(0);\n",
" auto C = torch::empty_like(A);\n",
"\n",
" // Optimize thread count based on matrix size\n",
" const int threadsPerBlock = 256; // Increased thread count per block\n",
" const int numBlocks = N;\n",
"\n",
" triangular_mm_kernel<<<numBlocks, threadsPerBlock>>>(\n",
" A.data_ptr<float>(), B.data_ptr<float>(), C.data_ptr<float>(), N);\n",
"\n",
" cudaError_t err = cudaGetLastError();\n",
" TORCH_CHECK(err == cudaSuccess, \"CUDA kernel failed: \", cudaGetErrorString(err));\n",
" return C;\n",
"}\n",
"\n",
"PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n",
" m.def(\"forward\", &forward,\n",
" \"Strided efficient triangular matrix multiplication (CUDA)\");\n",
"}\n",
"'''\n",
"\n",
"with open(\"tmp.cu\", \"w\") as f:\n",
" f.write(cu_code)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "MsJcwAnGEayx",
"outputId": "b82b4324-75c5-49fc-c0b0-cae3938891a8"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using /root/.cache/torch_extensions/py311_cu124 as PyTorch extensions root...\n",
"Creating extension directory /root/.cache/torch_extensions/py311_cu124/triangular_mm...\n",
"Detected CUDA files, patching ldflags\n",
"Emitting ninja build file /root/.cache/torch_extensions/py311_cu124/triangular_mm/build.ninja...\n",
"/usr/local/lib/python3.11/dist-packages/torch/utils/cpp_extension.py:1964: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. \n",
"If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].\n",
" warnings.warn(\n",
"Building extension module triangular_mm...\n",
"Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)\n",
"Loading extension module triangular_mm...\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Time taken: 0.017692044377326965 ms\n",
"Time taken: 27.793136596679688 ms\n",
"Speedup: 1570.9397966635026\n",
"Time taken: 0.26734623312950134 ms\n",
"Time taken: 29.115659713745117 ms\n",
"Speedup: 108.90619019734466\n",
"True\n"
]
}
],
"source": [
"import torch\n",
"from torch.utils.cpp_extension import load\n",
"from triton.testing import do_bench\n",
"\n",
"# make sure you have nvcc\n",
"cuda_fn = load(\n",
" name=\"triangular_mm\",\n",
" sources=[\"tmp.cu\"],\n",
" extra_cuda_cflags=[\"-O3\", \"--use_fast_math\"],\n",
" with_cuda=True,\n",
" verbose=True,\n",
").forward\n",
"\n",
"N = 4096\n",
"\n",
"def trilmm(a, b): return torch.matmul(a, b).tril()\n",
"\n",
"a = torch.randn(N, N, device=\"cuda\")\n",
"b = torch.randn(N, N, device=\"cuda\")\n",
"\n",
"a = torch.tril(a)\n",
"b = torch.tril(b)\n",
"\n",
"do_bench(lambda: cuda_fn(a, b).mean()) # do this once jic we need more warmup\n",
"\n",
"# Normal testing\n",
"time_new = do_bench(lambda: cuda_fn(a, b))\n",
"print(f\"Time taken: {time_new} ms\")\n",
"\n",
"time_old = do_bench(lambda: trilmm(a, b))\n",
"print(f\"Time taken: {time_old} ms\")\n",
"\n",
"print(f\"Speedup: {time_old / time_new}\")\n",
"\n",
"# Incease rep and do .mean() in case ^ is only capturing dispatches\n",
"time_new = do_bench(lambda: cuda_fn(a, b).mean(), rep=10000)\n",
"print(f\"Time taken: {time_new} ms\")\n",
"\n",
"time_old = do_bench(lambda: trilmm(a, b).mean(), rep=10000)\n",
"print(f\"Time taken: {time_old} ms\")\n",
"\n",
"print(f\"Speedup: {time_old / time_new}\") # should still see a drastic speedup\n",
"\n",
"print(torch.allclose(cuda_fn(a, b), trilmm(a, b)))\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cXfnhiSnGpoD"
},
"outputs": [],
"source": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Loading