
Conversation

etasnadi
Contributor

@etasnadi etasnadi commented Sep 18, 2025

I am adding this because the current conv2d algorithm (#15635) seems to underutilize the GPU -- the Vulkan version (#14316 & #14933) is 8-10 times faster on my device. Additionally, the Tensor Cores extension (#15813) of the previous algorithm also seems to be slower than this one.

There is another CUDA conv2d proposal that could be related: #15805.

Furthermore, this version introduces a bank-conflict reduction that has not been added to Vulkan yet. It seems to be effective on large problems. I expect that this version will be even more efficient than the Vulkan backend.

I do not support f16 yet; a future contribution might add that. Currently this algorithm is used for f32 inputs, otherwise it falls back to the previous implementation. GGML_CUDA_USE_LEGACY_CONV forces the use of the previous (probably slower) implementation.
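For reference, the selection logic amounts to something like the sketch below (a minimal illustration, not the actual code in this PR; only the GGML_CUDA_USE_LEGACY_CONV variable comes from the description above):

```cuda
#include <cstdlib>

// Use the new direct conv kernel only for f32 input and kernel, unless the
// user forces the legacy path via the environment variable.
static bool use_direct_conv_f32(bool input_is_f32, bool kernel_is_f32) {
    if (std::getenv("GGML_CUDA_USE_LEGACY_CONV") != nullptr) {
        return false;                      // fall back to the previous implementation
    }
    return input_is_f32 && kernel_is_f32;  // the new kernel currently handles f32 only
}
```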

Perf of the previous implementation on an RTX 2060:

$ GGML_CUDA_USE_LEGACY_CONV=1 ./bin/test-backend-ops -o CONV_2D -b CUDA0 perf

CONV_2D(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                      4 runs - 327359.25 us/run - 137.42 GFLOP/run - 419.79 GFLOPS
  CONV_2D(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                2992 runs -   357.38 us/run - 133.69 MFLOP/run - 374.09 GFLOPS
  CONV_2D(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                2948 runs -   362.97 us/run - 135.78 MFLOP/run - 374.08 GFLOPS
  CONV_2D(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                139264 runs -     7.53 us/run - 642.82 kFLOP/run -  85.42 GFLOPS
  CONV_2D(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                14358 runs -    92.52 us/run -  20.90 MFLOP/run - 225.87 GFLOPS
  CONV_2D(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                24576 runs -    49.88 us/run -   2.78 MFLOP/run -  55.83 GFLOPS
  CONV_2D(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 4489 runs -   386.69 us/run -  22.28 MFLOP/run -  57.61 GFLOPS
  CONV_2D(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                3468 runs -   328.26 us/run - 115.40 MFLOP/run - 351.56 GFLOPS
  CONV_2D(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 436 runs -  2535.99 us/run - 923.24 MFLOP/run - 364.05 GFLOPS
  CONV_2D(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                      220 runs -  4808.02 us/run -   1.85 GFLOP/run - 384.54 GFLOPS

Perf of the proposed implementation:

 
$ ./bin/test-backend-ops -o CONV_2D -b CUDA0 perf
CONV_2D(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     41 runs - 24946.49 us/run - 137.42 GFLOP/run -   5.51 TFLOPS
  CONV_2D(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               20944 runs -    49.06 us/run - 133.69 MFLOP/run -   2.73 TFLOPS
  CONV_2D(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               14740 runs -    69.12 us/run - 135.78 MFLOP/run -   1.96 TFLOPS
  CONV_2D(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                204800 runs -     5.04 us/run - 642.82 kFLOP/run - 127.47 GFLOPS
  CONV_2D(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                52646 runs -    20.26 us/run -  20.90 MFLOP/run -   1.03 TFLOPS
  CONV_2D(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                81920 runs -    12.37 us/run -   2.78 MFLOP/run - 225.14 GFLOPS
  CONV_2D(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                13467 runs -    85.40 us/run -  22.28 MFLOP/run - 260.88 GFLOPS
  CONV_2D(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               24276 runs -    42.68 us/run - 115.40 MFLOP/run -   2.70 TFLOPS
  CONV_2D(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                3270 runs -   311.18 us/run - 923.24 MFLOP/run -   2.97 TFLOPS
  CONV_2D(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     2200 runs -   462.72 us/run -   1.85 GFLOP/run -   4.00 TFLOPS

* Extra: reduces bank conflicts
@etasnadi etasnadi changed the title Vulkan direct conv ported to CUDA ggml-cude: Vulkan direct conv ported to CUDA Sep 18, 2025
@etasnadi etasnadi changed the title ggml-cude: Vulkan direct conv ported to CUDA ggml-cuda: Vulkan direct conv 2D ported to CUDA Sep 18, 2025
@etasnadi
Contributor Author

@Green-Sky Can you check it?

@github-actions github-actions bot added Nvidia GPU Issues specific to Nvidia GPUs ggml changes relating to the ggml tensor library for machine learning labels Sep 18, 2025
@Green-Sky
Collaborator

@Green-Sky Can you check it?

sd.cpp relies exclusively on f16 kernels.

@bssrdf
Contributor

bssrdf commented Sep 18, 2025

May I suggest using ggml_cuda_cast<float> to add support for fp16? It won't be faster, but at least @Green-Sky can test in sd.cpp.
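For illustration, the idea boils down to something like the sketch below (not code from any of the PRs; the helper name load_as_f32 is made up, and __half2float is what ggml_cuda_cast<float> would reduce to for a half input):

```cuda
#include <cstdint>
#include <cuda_fp16.h>
#include <type_traits>

// Read a kernel/input element as f32 regardless of its storage type, so the
// existing f32 math path can also consume f16 tensors (at conversion cost).
template <typename T>
__device__ __forceinline__ float load_as_f32(const T * __restrict__ p, int64_t i) {
    if constexpr (std::is_same_v<T, half>) {
        return __half2float(p[i]);   // f16 -> f32
    } else {
        return p[i];                 // already f32
    }
}
```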

@Green-Sky
Collaborator

May I suggest using ggml_cuda_cast<float> to add support for fp16? It won't be faster, but at least @Green-Sky can test in sd.cpp.

This would make things easy for me, yes.

BTW, forgot to thank you @etasnadi for working on this :)

... even though we now have 3 competing prs, more or less.

@bssrdf
Contributor

bssrdf commented Sep 18, 2025

May I suggest using ggml_cuda_cast<float> to add support for fp16? It won't be faster, but at least @Green-Sky can test in sd.cpp.

This would make things easy for me, yes.

BTW, forgot to thank you @etasnadi for working on this :)

... even though we now have 3 competing prs, more or less.

I'll close my PR. This one is way better:)

@mnehete32
Contributor

May I suggest using ggml_cuda_cast<float> to add support for fp16? It won't be faster, but at least @Green-Sky can test in sd.cpp.

This would make things easy for me, yes.
BTW, forgot to thank you @etasnadi for working on this :)
... even though we now have 3 competing prs, more or less.

I'll close my PR. This one is way better:)

Same, closing mine too.

@etasnadi
Contributor Author

etasnadi commented Sep 18, 2025 via email

Maybe you can add a parallel PR based on this for f16?

@mnehete32
Contributor

mnehete32 commented Sep 18, 2025

I'm new to CUDA, but I'd love to give this a shot @Green-Sky @etasnadi. If fp16 support isn't super urgent, I can take a crack at it in the next week or two.

Collaborator

@JohannesGaessler JohannesGaessler left a comment


Can you give me a list of what parts of the code you changed relative to the Vulkan version, if any? Some things like fastdiv and how to retrieve the SM count have equivalents in the CUDA backend. But if this is just a copy-paste of the Vulkan code I would preferably change as little as possible.

@etasnadi
Contributor Author

Can you give me a list of what parts of the code you changed relative to the Vulkan version, if any? Some things like fastdiv and how to retrieve the SM count have equivalents in the CUDA backend. But if this is just a copy-paste of the Vulkan code I would preferably change as little as possible.

Can you give any reference for doing fastdiv/sm_count() the proper ggml-cuda way? I will refactor it then.

Only the necessary things are changed compared to Vulkan, but they are significant:

  • Vulkan has specialization constants, which are missing in CUDA, so the kernel selection/initialization is different.
  • The unrolls in the kernel are also different.
  • The coopmat2 API we use in Vulkan is very different from the mma APIs found in CUDA, so that part was completely removed before porting.
  • The shmem size check is not needed, etc.
  • The core algorithm is mostly the same, but it is augmented with different shmem indexing to minimize bank conflicts (a generic illustration follows below); this is not present in the Vulkan kernel yet.

IMO it already changes as little as possible compared to Vulkan.
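To illustrate the bank-conflict point above, here is a generic sketch of the two usual shared-memory tricks (padding and index swizzling); this is not the PR's actual indexing scheme, just the general technique, and it assumes a 32x32 thread block:

```cuda
#define TILE 32

// Padding the leading dimension by one float skews the bank mapping so that
// column-wise accesses from a warp no longer hit the same bank.
__global__ void transpose_tile_demo(const float * __restrict__ src, float * __restrict__ dst) {
    __shared__ float smem[TILE][TILE + 1];   // +1 pad avoids 32-way conflicts
    const int r = threadIdx.y;
    const int c = threadIdx.x;
    smem[r][c] = src[r * TILE + c];          // row-major write: conflict-free
    __syncthreads();
    dst[c * TILE + r] = smem[c][r];          // column-major read: conflict-free thanks to the pad
}

// Alternative when padding wastes too much shmem: XOR-swizzle the column index
// so consecutive rows map the same logical column to different banks.
__device__ __forceinline__ int swizzle(int row, int col) {
    return col ^ (row % TILE);               // assumes 32 banks and TILE == 32
}
```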

@JohannesGaessler
Collaborator

JohannesGaessler commented Sep 18, 2025

For a fastdiv example, look at e.g. binbcast.cu; get the SM count via ggml_cuda_info().devices[ggml_cuda_get_device()].nsm.
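For reference, the trick behind fastdiv looks roughly like the self-contained sketch below (the Granlund-Montgomery method; the actual ggml-cuda helpers live in common.cuh and are used in binbcast.cu, and their names and packing differ):

```cuda
#include <cstdint>

struct fastdiv_consts {
    uint32_t mp;   // magic multiplier
    uint32_t L;    // ceil(log2(d))
};

// Host side: precompute the constants for a fixed divisor d >= 2.
static fastdiv_consts make_fastdiv(uint32_t d) {
    uint32_t L = 0;
    while ((1ull << L) < d) {
        ++L;
    }
    const uint32_t mp = (uint32_t) (((1ull << 32) * ((1ull << L) - d)) / d + 1);
    return { mp, L };
}

// Device side: n / d becomes a multiply-high, a subtract, an add and two shifts.
__device__ __forceinline__ uint32_t fastdiv(uint32_t n, fastdiv_consts c) {
    const uint32_t hi = __umulhi(n, c.mp);
    return (hi + ((n - hi) >> 1)) >> (c.L - 1);
}
```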

@etasnadi
Contributor Author

etasnadi commented Sep 19, 2025

@bssrdf Do you want to contribute the ggml-cuda-conformant fastdiv as a patch to my branch or in a separate PR, so that everyone gets authorship for their conv2d effort?

@bssrdf
Contributor

bssrdf commented Sep 19, 2025

@bssrdf Do you want to contribute the ggml-cuda conformant fastdiv as a patch to my branch or in a separate PR so everyone gets the authorship for conv2d for their effort?

@etasnadi, I can give it a try. I will do a patch on your branch.
@Green-Sky, you can try etasnadi#1 in sd.cpp.

@Green-Sky
Collaborator

In the f16 pr by @bssrdf , we found that sd.cpp crashes with this pr. I double checked by forcing sd.cpp to use f32 for the kernel without @bssrdf 's pr.

[INFO ] stable-diffusion.cpp:2166 - generating 2 latent images completed, taking 84.68s
[INFO ] stable-diffusion.cpp:2169 - decoding 2 latents
[INFO ] ggml_extend.hpp:1648 - vae offload params ( 94.47 MB, 140 tensors) to runtime backend (CUDA0), taking 0.01s
[DEBUG] ggml_extend.hpp:1550 - vae compute buffer size: 1928.64 MB(VRAM)
[ERROR] ggml_extend.hpp:71   - CUDA error: an illegal memory access was encountered
[ERROR] ggml_extend.hpp:71   -   current device: 0, in function ggml_backend_cuda_synchronize at /build/pqlxhx4zgf1dr2wyx5qdm2gb2b6c73sf-source/ggml/src/ggml-cuda/ggml-cuda.cu:2628
[ERROR] ggml_extend.hpp:71   -   cudaStreamSynchronize(cuda_ctx->stream())
/build/pqlxhx4zgf1dr2wyx5qdm2gb2b6c73sf-source/ggml/src/ggml-cuda/ggml-cuda.cu:88: CUDA error
#4  0x000000000057b955 in ggml_backend_cuda_synchronize (backend=<optimized out>) at ggml/src/ggml-cuda/ggml-cuda.cu:2628
#5  0x0000000000a40714 in ggml_backend_synchronize (backend=backend@entry=0x22775310) at ggml/src/ggml-backend.cpp:327
#6  0x0000000000a40a6b in ggml_backend_graph_compute (backend=0x22775310, cgraph=<optimized out>) at ggml/src/ggml-backend.cpp:353

$ result/bin/sd -m models/CyberRealistic_V9_FP16.safetensors --sampling-method dpm++2m --scheduler karras --cfg-scale 5 -W 768 -H 1024 --diffusion-fa --steps 20 -b 2 -v -n "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry" -p "a lovely cat" --vae-conv-direct --offload-to-cpu

@bssrdf
Contributor

bssrdf commented Sep 19, 2025

In the f16 pr by @bssrdf , we found that sd.cpp crashes with this pr. I double checked by forcing sd.cpp to use f32 for the kernel without @bssrdf 's pr.


@Green-Sky, it may be due to my changes. I'll investigate.

@Green-Sky
Collaborator

In the f16 pr by @bssrdf , we found that sd.cpp crashes with this pr. I double checked by forcing sd.cpp to use f32 for the kernel without @bssrdf 's pr.

@Green-Sky, it may be due to my changes. I'll investigate.

I redid the test without your changes, and the issue was the same, as I state right there.

@bssrdf
Contributor

bssrdf commented Sep 19, 2025

In the f16 pr by @bssrdf , we found that sd.cpp crashes with this pr. I double checked by forcing sd.cpp to use f32 for the kernel without @bssrdf 's pr.

@Green-Sky, it may be due to my changes. I'll investigate.

I redid the test without your changes, and the issue was the same, as I state right there.

@Green-Sky, without my change, it will fall back to using the slow direct version. Did it even fail there?

@Green-Sky
Collaborator

Green-Sky commented Sep 19, 2025

In the f16 pr by @bssrdf , we found that sd.cpp crashes with this pr. I double checked by forcing sd.cpp to use f32 for the kernel without @bssrdf 's pr.

@Green-Sky, it may be due to my changes. I'll investigate.

I redid the test without your changes, and the issue was the same, as I state right there.

@Green-Sky, without my change, it will fall back to using the slow direct version. Did it even fail there?

I patched sd.cpp to cast the kernel to f32, so it would fall back.
(ggml_cast w at the direct callsite)

-    x = ggml_conv_2d_direct(ctx, w, x, s0, s1, p0, p1, d0, d1);
+    x = ggml_conv_2d_direct(ctx, ggml_cast(ctx, w, GGML_TYPE_F32), x, s0, s1, p0, p1, d0, d1);

I guess I should have made my report more wordy (:

edit: fun side note, it seems like with the current naive fallback and f32 kernel, sd.cpp vae decode is ever so slightly faster (~35s vs ~33s)

@JohannesGaessler
Collaborator

So what is the current state of convolution? Is this PR in its current state something that the three of you can agree should be reviewed and merged?

@etasnadi
Contributor Author

etasnadi commented Sep 24, 2025

So what is the current state of convolution? Is this PR in its current state something that the three of you can agree should be reviewed and merged?

@bssrdf added the missing parts for this PR in etasnadi#1. We had an issue that seems to be fixed now. If it is ready, I will update the branch behind this PR to be merged (so it will contain commits authored by both of us).

Are you going to squash the commits into two commits to preserve @bssrdf's commit, or do you plan to create one commit with Co-authored-by when merging?

I think @mnehete32 will re-open his PR and rebase their code onto this one to add Tensor Cores support.

@JohannesGaessler
Collaborator

Pull requests in llama.cpp/ggml are always squashed to a single commit.

@bssrdf
Contributor

bssrdf commented Sep 24, 2025

So what is the current state of convolution? Is this PR in its current state something that the three of you can agree should be reviewed and merged?

My testing so far shows this PR is still lagging behind im2col+GEMM in terms of performance, by several times in some cases. The only benefit is memory savings. We still have a memory access issue to be resolved.

Down the road, I think applications using the conv_2d op, like sd.cpp, should choose either im2col or the implicit approach depending on problem size. For small activation dimensions, im2col is still the way to go.
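For illustration, that kind of dispatch could look like the sketch below (the function, the enum and the threshold are made up for this example, not sd.cpp or ggml API):

```cuda
#include <cstdint>

enum class conv2d_algo { IM2COL_GEMM, DIRECT_CONV };

// Small activations (few output pixels per batch) favor im2col + GEMM; large
// ones favor the direct/implicit kernel, which avoids materializing the
// im2col buffer. The threshold is a guess and would need tuning per device.
static conv2d_algo choose_conv2d_algo(int64_t n_batch, int64_t oh, int64_t ow) {
    const int64_t n_out = n_batch * oh * ow;
    return n_out < 64 * 64 ? conv2d_algo::IM2COL_GEMM : conv2d_algo::DIRECT_CONV;
}
```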

@etasnadi
Contributor Author

Now I think that the implicit GEMM is faster, so I suggest merging that one. Actually, when it was tested with stable-diffusion, this code wasn't activated, so the previous, less efficient implementation was always used.

When I properly merged the implicit GEMM proposed by @bssrdf into my tree, it turned out that it performs somewhat better than this algorithm, so I suggest continuing the work on that PR.

See my comment for details: etasnadi#1 (comment)

@etasnadi
Contributor Author


Regarding the suggestion to merge the implicit GEMM:

I do not see any license in the repository from which @bssrdf forked the code (https://github.com/Qwesh157/conv_op_optimization).

@JohannesGaessler, @bssrdf, I suggest making sure that we can legally use all the code pulled from that repo in llama.cpp.

@bssrdf
Contributor

bssrdf commented Sep 27, 2025

Now I think that the implicit GEMM is faster, so I suggest to merge that one. Actually when it was tested with stable-diffusion the code wasn't activated, so always the previous, less efficient implementation was used.
When I properly merged the implicit gemm proposed by @bssrdf to my tree it showed that it performs somewhat better than this alg so I suggest to continue to work on that PR.
See my comment for details: etasnadi#1 (comment)

Regarding the suggestion to merge the implicit gemm:

I do not see any license in the repository @bssrdf forked the code (https://github.com/Qwesh157/conv_op_optimization) from.

@JohannesGaessler, @bssrdf I suggest to make sure that we can legally use all the code pulled from that repo in llama.cpp.

@etasnadi, my PR may be slightly faster than yours, but it is based on https://github.com/Qwesh157/conv_op_optimization, which has no license (I am too lazy to write it from scratch). I reached out to the author and asked them to add a license, but who knows whether/when they will respond. Pending the license issue, we may continue working on your PR and improve it further (I have some ideas). Plus, your code is more in line with ggml's style. What do you think?

@etasnadi
Contributor Author


@etasnadi, my PR may be slightly faster than yours, but it is based on https://github.com/Qwesh157/conv_op_optimization which has no license (I am too lazy to write from scratch). I reached out to the author and asked for adding a license but who knows whether/when they will respond. Pending the license issue, we may continue working on your PR and improve it further ( I have some ideas). Plus your code is more in line with ggml's style. What do you think?

That's great. It depends on how much you want to work on this.

You can either wait for the approval from their side, or you can integrate their optimizations into this PR. I think the optimizations are additive, so merging the two would make sense, and since you don't use their code directly, just the ideas, you don't need any license.

@Qwesh157

Qwesh157 commented Sep 28, 2025

I can give some of my suggestions. I don't know how Vulkan implements conv2d, but the code in my repo is mainly intended to push the GPU's CUDA cores to their performance limits for as large conv2d parameters as possible (larger h, w, number of filters, etc.), rather than being optimized for a specific conv2d parameter. So you can observe that the implicit GEMM implementation is generally closer to the theoretical performance value at large sizes. If the goal is to achieve ideal performance in all aspects, I suggest adding as many tile shapes as possible. For example, in my repo I used 128x128, but you could also try 32x128, 64x128, and so on to generate enough CTAs to fully occupy the GPU.
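To make the CTA-count point concrete, a tile shape could be chosen along these lines (an illustrative sketch only; the candidate shapes come from the comment above, and the "4 CTAs per SM" threshold is a guess):

```cuda
#include <cstdint>

struct tile_shape { int bm, bn; };

// Implicit-GEMM problem dims: M = Cout, N = batch * OH * OW. Pick the largest
// block tile that still produces enough thread blocks to keep every SM busy.
static tile_shape pick_tile(int64_t M, int64_t N, int nsm) {
    const tile_shape candidates[] = { { 128, 128 }, { 64, 128 }, { 32, 128 }, { 32, 64 } };
    for (const tile_shape & t : candidates) {
        const int64_t ctas = ((M + t.bm - 1) / t.bm) * ((N + t.bn - 1) / t.bn);
        if (ctas >= 4 * (int64_t) nsm) {
            return t;
        }
    }
    return candidates[3];   // fall back to the smallest tile
}
```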

@bssrdf
Contributor

bssrdf commented Sep 28, 2025

I can give some of my suggestions. I don’t know how Vulkan implements conv2d, but the code in my repo is mainly intended to push the GPU CUDA cores to their performance limits under as large conv2d parameters as possible (larger h, w, and number of filters, etc.), rather than being optimized for a specific conv2d parameter. So you can observe that the Implicit GEMM implementation is generally closer to the theoretical performance value under large size. If the goal is to achieve ideal performance in all aspects, I suggest adding as many tile shapes as possible. For example, in my repo I used 128x128, but you could also try 32x128, 64x128, and so on to generate enough CTAs to fully occupy the GPU.

@Qwesh157, thanks for the suggestions and for offering the generous license. I agree. To achieve higher performance for all input sizes, there have to be multiple block/tile shapes. This PR has 3 tile shapes and I already see the differences over limited test cases. I am going to implement other shapes in my PR. I am also exploring other optimizations, e.g. vectorized loads, split-k, etc.

@JohannesGaessler
Collaborator

Is there now consensus that the implicit GEMM kernel is the one that should be reviewed and merged?

@Qwesh157 would you be fine with licensing your code to us with the MIT license? We already have a copy of the MIT license at the project root, so my suggestion would be to simply add a copyright notice and a link to your repository in the file containing the copied CUDA code.

@etasnadi
Contributor Author

Is there now consensus that the implicit GEMM kernel is the one that should be reviewed and merged?

@Qwesh157 would you be fine with licensing your code to us with the MIT license? We already have a copy of the MIT license at the project root, so my suggestion would be to simply add a copyright notice and a link to your repository in the file containing the copied CUDA code.

Currently, @Qwesh157's implicit implementation fork proposed by @bssrdf in #15805 is significantly faster (mostly a 20% improvement, but there is a test case where it is 100%), but that code uses optimizations missing from this PR. However, this PR also has optimizations missing from the other. I will add warp tiling and double buffering asap and you can decide based on the numbers. I expect that once I have updated the code with these optimizations, this will be the fastest. (Adding the optimizations of this PR to #15805 is also possible, but I know this code better, so I will work on this.)

With stable diffusion, the results are mixed. It was shown on one device that this PR is already marginally better than implicit conv (etasnadi#1 (comment)), but on my device the implicit conv is faster (etasnadi#1 (comment)).

@Qwesh157

Is there now consensus that the implicit GEMM kernel is the one that should be reviewed and merged?

@Qwesh157 would you be fine with licensing your code to us with the MIT license? We already have a copy of the MIT license at the project root, so my suggestion would be to simply add a copyright notice and a link to your repository in the file containing the copied CUDA code.

OK, I changed my license.

@Qwesh157

Qwesh157 commented Sep 29, 2025


Since the current implementation does not specifically take FP16 scenarios into account, my repo may not fully exploit the CUDA-core FP16 units, as certain techniques are not (yet) implemented. Examples include vectorized load/store, swizzling, the use of hfma2 instructions, larger tiles (e.g., 256×128) to increase computational intensity, multi-stage pipelining (>2 stages) on devices with larger shared memory, cp.async on post-Ampere GPUs, and iterator-based mechanisms (as in CUTLASS), among others.
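As an illustration of one of these techniques (packed half2 math via hfma2), here is a minimal sketch; it is not code from any of the PRs:

```cuda
#include <cuda_fp16.h>

// Accumulate a dot product over packed half2 values: each __hfma2 performs two
// fused multiply-adds, which is what keeps the CUDA-core FP16 units busy.
__device__ __forceinline__ half2 dot_accum_half2(const half2 * __restrict__ a,
                                                 const half2 * __restrict__ b,
                                                 int k2, half2 acc) {
    for (int i = 0; i < k2; ++i) {
        acc = __hfma2(a[i], b[i], acc);   // acc += a[i] * b[i], elementwise on 2 halves
    }
    return acc;
}
```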

@etasnadi
Contributor Author

etasnadi commented Oct 2, 2025


Since the current implementation does not specifically take FP16 scenarios into account, my repo may not fully exploit the CUDA core FP16 units, as certain techniques are not (yet) implemented. Examples include vectorized load/store, swizzling, the use of hfma2 instructions, larger tiles (e.g., 256×128) to increase computational intensity, multi-stage pipelining(>2) on devices with larger shared memory, cp.async on post-ampere arch GPU, and iterator-based mechanisms (as in CUTLASS), among others.

I added warptiling to this PR, and your kernel is 3.75% faster than this PR on average on my device. (I removed the variants from this PR for a fair comparison.)

Here are the results with warptiling (https://github.com/etasnadi/llama.cppxx.git):

The format is: test_case implicit_flops/this_pr_flops: speedup

CONV_2D(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0)   5650.0/5180.0: 1.09
CONV_2D(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0)   4740.0/3870.0: 1.22
CONV_2D(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0)   2550.0/2430.0: 1.05
CONV_2D(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0)   49.12/64.34: 0.76
CONV_2D(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0)   195.18/254.59: 0.77
CONV_2D(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0)   48.94/60.8: 0.80
CONV_2D(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0)   50.56/66.92: 0.76
CONV_2D(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0)   2110.0/1630.0: 1.29
CONV_2D(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0)   2810.0/2640.0: 1.06
CONV_2D(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0)   4310.0/3950.0: 1.09
CONV_2D(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0)   5930.0/5310.0: 1.12
CONV_2D(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0)   4760.0/3740.0: 1.27
CONV_2D(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0)   2590.0/2360.0: 1.10
CONV_2D(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0)   49.62/62.12: 0.80
CONV_2D(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0)   197.55/247.48: 0.80
CONV_2D(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0)   49.66/58.67: 0.85
CONV_2D(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0)   51.79/64.77: 0.80
CONV_2D(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0)   2130.0/1580.0: 1.35
CONV_2D(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0)   2820.0/2580.0: 1.09
CONV_2D(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0)   4380.0/3850.0: 1.14
CONV_2D(ne_input=[96,128,512,1],ne_kernel=[3,3,512,512],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0)   5770.0/5170.0: 1.12
CONV_2D(ne_input=[192,256,512,1],ne_kernel=[3,3,512,512],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0)   5970.0/5400.0: 1.11
CONV_2D(ne_input=[384,512,256,1],ne_kernel=[3,3,256,256],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0)   5980.0/5440.0: 1.10
CONV_2D(ne_input=[768,1024,128,1],ne_kernel=[3,3,128,128],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0)   5940.0/5420.0: 1.10
CONV_2D(ne_input=[96,128,512,1],ne_kernel=[1,1,512,512],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0)   5340.0/5020.0: 1.06
CONV_2D(ne_input=[96,128,4,1],ne_kernel=[1,1,4,4],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0)   23.08/24.21: 0.95
CONV_2D(ne_input=[96,128,4,1],ne_kernel=[3,3,4,512],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0)   3200.0/3070.0: 1.04
CONV_2D(ne_input=[384,512,512,1],ne_kernel=[3,3,512,512],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0)   6020.0/5490.0: 1.10
CONV_2D(ne_input=[384,512,512,1],ne_kernel=[3,3,512,256],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0)   5980.0/5450.0: 1.10
CONV_2D(ne_input=[384,512,512,1],ne_kernel=[1,1,512,256],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0)   5580.0/5290.0: 1.05
CONV_2D(ne_input=[768,1024,256,1],ne_kernel=[3,3,256,128],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0)   5980.0/5430.0: 1.10
CONV_2D(ne_input=[768,1024,256,1],ne_kernel=[1,1,256,128],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0)   5320.0/5100.0: 1.04
CONV_2D(ne_input=[768,1024,256,1],ne_kernel=[3,3,256,256],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0)   5990.0/5460.0: 1.10
CONV_2D(ne_input=[768,1024,128,1],ne_kernel=[3,3,128,3],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0)   148.53/135.51: 1.10
Mean speedup: 1.0375495514365727 ([0.7555289898386133, 1.3481012658227849])

This PR is faster on memory-bound problems and your kernel is faster when the input is compute-bound.

I checked your code to see where the difference comes from, and I realized that your tiling strategy is heavily optimized for Nvidia cards. Now I am really curious whether your optimizations would also be effective on non-Nvidia devices. @bssrdf, are you motivated enough to also add a Vulkan port of the kernel?

In summary, I suggest reopening #15805 and merging the features of this PR (variants and fastdiv) so the algorithm could be even faster!

@bssrdf
Contributor

bssrdf commented Oct 3, 2025

I added warptiling to this PR, and your kernel is 3.75% faster than this PR on average on my device. (I removed the variants from this PR for a fair comparison.)

This PR is faster on memory-bound problems and your kernel is faster when the input is compute-bound.

I checked your code to see where the difference comes from, and I realized that your tiling strategy is heavily optimized for Nvidia cards. Now I am really curious whether your optimizations would also be effective on non-Nvidia devices. @bssrdf, are you motivated enough to also add a Vulkan port of the kernel?

In summary, I suggest reopening #15805 and merging the features of this PR (variants and fastdiv) so the algorithm could be even faster!

@etasnadi, thank you for the detailed benchmarking. I am working on adding multiple blocking strategies to my PR. If it works on memory-bound cases as well, I will reopen it. BTW, fastdiv has been added. I also added vectorized loads, which also seem to give a speed bump.

As to the specific optimizations, I wonder whether double buffering helped by hiding the latency. Unfortunately, I don't know Vulkan, so I cannot help optimizing it.
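For reference, double buffering here means overlapping the global-memory load of the next tile with the computation on the current one, roughly as in the sketch below (generic, not taken from either PR; assumes a block of 256 threads and tiles of 256 floats):

```cuda
__global__ void double_buffer_demo(const float * __restrict__ src, float * __restrict__ dst, int ntiles) {
    __shared__ float buf[2][256];
    int cur = 0;
    buf[cur][threadIdx.x] = src[threadIdx.x];   // prefetch tile 0
    __syncthreads();

    float acc = 0.0f;
    for (int t = 0; t < ntiles; ++t) {
        if (t + 1 < ntiles) {
            buf[cur ^ 1][threadIdx.x] = src[(t + 1) * 256 + threadIdx.x];   // load the next tile ...
        }
        acc += buf[cur][threadIdx.x];   // ... while computing on the current one
        __syncthreads();                // ensure the next buffer is complete before switching
        cur ^= 1;
    }
    dst[threadIdx.x] = acc;
}
```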
