Commit c607174: Fix mask passed to flashinfer (#3324)
Custom masks are padded to the shape `[batch_size, max_len, max_len]`.
However, flashinfer expects an unpadded, flattened mask of the shape
`[sum(q_len[i] * k_len[i] for i in range(batch_size))]`.
This change unpads the custom mask (currently only used by Gemma 3)
to that shape (assuming q_len == k_len, since we only use the custom
mask during prefill).

1 parent: 4f067c2
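The unpadding described above can be sketched as follows. This is a minimal illustration, using NumPy arrays in place of the server's real tensors; the function name `unpad_custom_mask` and its signature are hypothetical, not the actual code in this commit.

```python
import numpy as np

def unpad_custom_mask(mask: np.ndarray, input_lengths: list[int]) -> np.ndarray:
    """Flatten a padded [batch_size, max_len, max_len] mask into the flat
    [sum(q_len[i] * k_len[i] for i in range(batch_size))] layout that
    flashinfer expects.

    Assumes q_len == k_len for every sequence, i.e. the prefill case
    the commit message describes.
    """
    pieces = [
        # Take each sequence's top-left [len, len] block and flatten it.
        mask[i, :length, :length].reshape(-1)
        for i, length in enumerate(input_lengths)
    ]
    return np.concatenate(pieces)

# Example: two sequences of lengths 2 and 3, padded to max_len = 4.
padded = np.ones((2, 4, 4), dtype=bool)
flat = unpad_custom_mask(padded, [2, 3])
print(flat.shape)  # (13,) == 2*2 + 3*3
```

The key point is that the per-sequence padding rows and columns are dropped before concatenation, so the flat mask's length depends only on the true sequence lengths, not on `max_len`.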
File tree: 1 file changed, +24 −0 lines
- server/text_generation_server/layers/attention