Skip to content

Commit 6f62364

Browse files
committed
docs: Add document retracing x86-64 AVX sgemm microkernel in typst
This is mostly for fun (and verification). Not as generic as it could be and so on. Compiled using typst 0.13.1. PDF included in repo so that it is readily available to read. Experimenting with the document in typst.app or locally with instant preview is a good way to work with it.
1 parent 1c91e1c commit 6f62364

File tree

3 files changed

+313
-0
lines changed

3 files changed

+313
-0
lines changed

docs/typst/Makefile

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
2+
# Compile the Typst source into the PDF; `$<` expands to the first
# prerequisite, i.e. x86_sgemm.typ.
x86_sgemm.pdf: x86_sgemm.typ
3+
typst compile $<

docs/typst/x86_sgemm.pdf

45.2 KB
Binary file not shown.

docs/typst/x86_sgemm.typ

Lines changed: 310 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,310 @@
1+
// Copyright 2025 Ulrik Sverdrup "bluss"
2+
//
3+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4+
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5+
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6+
// option. This file may not be copied, modified, or distributed
7+
// except according to those terms.
8+
//
9+
// This document retraces the vector permutations in the x86-64 AVX sgemm microkernel,
10+
// to verify and visualize where the elements from the input buffers end up.
11+
12+
// Document metadata (author and title) for the compiled output.
#set document(
13+
// No creation date is recorded. NOTE(review): presumably so that the PDF
// checked into the repository does not change on every rebuild - confirm.
date: none,
14+
author: ("Ulrik Sverdrup", ),
15+
title: "matrixmultiply: x86-64 AVX sgemm microkernel",
16+
)
17+
18+
// Body text font and size.
#set text(font: "Fira Sans", size: 11pt, features: ())
19+
// Monospace font name; used for raw text below and for the row labels in
// `show-vectors`.
#let rawfont = "Fira Code"
20+
#show raw: set text(font: rawfont, size: 10pt)
21+
22+
// Underline links; with `evade: false` the line is not interrupted where it
// would collide with glyphs.
#show link: underline.with(evade: false)
23+
// Number the pages and put a small, right-aligned header with the project
// name and repository link on every page.
#set page(numbering: "1", header: {
24+
set align(right)
25+
set text(size: 0.8em)
26+
[matrixmultiply #link("https://github.com/bluss/matrixmultiply")]
27+
})
28+
29+
30+
/// Prepend the string `name` to the string form of every element of `arr`.
///
/// Returns a new array of strings, e.g. `tag("a", (0, 1))` is `("a0", "a1")`.
#let tag(name, arr) = arr.map(elt => name + str(elt))
34+
35+
/// Model an 8-element vector load from the buffer called `name`.
///
/// Returns the symbolic elements `name0` through `name7` as an array of strings.
#let load_ps(name) = tag(name, range(8))
38+
39+
// https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_moveldup_ps&ig_expand=4923,6050,4597
/// Model `_mm256_moveldup_ps`: each even-indexed element is duplicated into
/// the odd position that follows it, so position `i` receives the element at
/// the largest even index not above `i`.
#let moveldup_ps(x) = {
  let positions = range(x.len())
  positions.map(pos => x.at(pos - calc.rem(pos, 2)))
}
43+
44+
/// Counterpart of `moveldup_ps` for the odd-indexed elements: each
/// odd-indexed element is duplicated into the even position just before it.
#let movehdup_ps(x) = {
  let positions = range(x.len())
  positions.map(pos => x.at(if calc.even(pos) { pos + 1 } else { pos }))
}
47+
48+
/// Select a single element out of a four-element lane.
///
/// - src: array holding the elements of the lane.
/// - control: index of the element to pick; must be in the range 0 to 3.
///
/// Returns a one-element array, so that the results of several selections can
/// be joined into a vector. Panics with "invalid control" for any control
/// outside 0 to 3.
#let select4_128(src, control) = {
  // The lower bound matters: `array.slice` accepts negative indices (they
  // count from the end), so a negative control would otherwise be applied
  // silently instead of being reported as invalid.
  if 0 <= control and control <= 3 {
    src.slice(control, control + 1)
  } else {
    panic("invalid control")
  }
}
56+
57+
58+
/// _mm256_permute_ps
/// control word a, b, c, d (each 2 bits)
///
/// The same four selectors are applied to elements 0-3 and to elements 4-7
/// of `x`: output element k of each half is that half's element at index
/// a, b, c, d (for k = 0, 1, 2, 3).
// https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_permute_ps&ig_expand=4923
#let permute_ps(x, a, b, c, d) = {
  let controls = (a, b, c, d)
  // Apply all four selectors to one four-element lane and join the picks.
  let permute-lane(lane) = controls.map(ctl => select4_128(lane, ctl)).join()
  permute-lane(x.slice(0, 4)) + permute-lane(x.slice(4, 8))
}
69+
70+
/// _mm256_permute2f128_ps
/// control word a, b (each 2 bits)
///
/// Builds an 8-element vector from two four-element halves. Each control
/// picks one half: 0 = src1 elements 0-3, 1 = src1 elements 4-7,
/// 2 = src2 elements 0-3, 3 = src2 elements 4-7. `a` chooses the first half of
/// the result and `b` the second. Panics with "invalid control" otherwise.
#let permute2f128_ps(src1, src2, a, b) = {
  let pick-half(control) = {
    let which = (0, 1, 2, 3).position(code => code == control)
    if which == none {
      panic("invalid control")
    } else {
      // Codes 0 and 1 read from src1, codes 2 and 3 from src2; odd codes
      // address the upper four elements.
      let source = if which < 2 { src1 } else { src2 }
      let start = 4 * calc.rem(which, 2)
      source.slice(start, start + 4)
    }
  }
  pick-half(a) + pick-half(b)
}
89+
90+
/// _mm256_shuffle_ps
/// control word a, b, c, d (each 2 bits)
///
/// Within each four-element half, the first two output elements are selected
/// from `src1` (by `a` and `b`) and the last two from `src2` (by `c` and `d`).
// https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_shuffle_ps&ig_expand=4923,6050
#let shuffle_ps(src1, src2, a, b, c, d) = {
  // One (selector, source) pair per output element of a half.
  let picks = ((a, src1), (b, src1), (c, src2), (d, src2))
  let shuffle-lane(start) = {
    picks.map(((ctl, src)) => select4_128(src.slice(start, start + 4), ctl)).join()
  }
  shuffle-lane(0) + shuffle-lane(4)
}
102+
103+
104+
105+
/// The ten decimal digit characters as an array of one-character strings.
#let digits = "0123456789".codepoints()

/// Translate a1b2 to ab12: move every digit behind the non-digit characters
/// while keeping the relative order inside each group.
#let norm-name(x) = {
  let pieces = x.split("")
  let non-digits = pieces.filter(ch => not digits.contains(ch))
  let only-digits = pieces.filter(ch => digits.contains(ch))
  (non-digits + only-digits).join()
}
110+
111+
/// Multiply two arrays (a0, a1) * (b0, b1) == (a0b0, a1b1)
///
/// The symbolic product of two elements is the concatenation of their names,
/// normalized with `norm-name`. Both arrays must have the same length
/// (`zip` is called with `exact: true`).
#let mul(x, y) = {
  let pairs = x.zip(y, exact: true)
  pairs.map(pair => norm-name(pair.first() + pair.last()))
}
115+
116+
/// Pair every element of `arr` (an array of strings) with a flag telling
/// whether that element occurs more than once in `arr`.
///
/// Returns an array of `(element, duplicated)` tuples in the original order.
#let markduplicates(arr) = {
  // First pass: count how often each string occurs.
  let occurrences = (:)
  for item in arr {
    occurrences.insert(item, occurrences.at(item, default: 0) + 1)
  }
  // Second pass: attach the flag to every element.
  arr.map(item => (item, occurrences.at(item) >= 2))
}
125+
126+
127+
/// Render one or more 8-element vectors as a table, one vector per row.
///
/// - ab: a single vector (array of element names) or an array of vectors; it
///   is flattened and laid out `ncol` elements per row.
/// - name: optional caption. Without `row-label` it is used as the label in an
///   extra column at the end of the row (no row index is shown); together
///   with `row-label` it becomes a bold block above the table.
/// - row-label: optional label placed in an extra column at the end of every
///   row, followed by the row index in brackets.
/// - check-duplicates: when true and there is more than one row, elements that
///   occur more than once are drawn with a red text stroke.
///
/// Digits inside the element names are typeset as bold subscripts and colored
/// by value: 3 and 7 green, 1 and 5 red, 2 and 6 blue, all others black.
#let show-vectors(ab, name: none, row-label: none, check-duplicates: true) = {
  // Number of elements per vector; every per-row computation below uses it.
  let ncol = 8
  let vector-width = 3.5em
  let color-indices = true

  let elements = ab.flatten()
  let extra-col = 0
  let nrows = calc.div-euclid(elements.len(), ncol)

  // By default the row index is wrapped in a box; when only `name` is given
  // the name itself serves as the row label and the index is dropped.
  let row-enumerator = box
  if name != none and row-label == none {
    row-label = name
    row-enumerator = x => none
  } else if name != none {
    block(strong(name), below: 0.6em)
  }

  show sub: text.with(size: 1.3em)
  show <row-label>: it => {
    set text(font: rawfont, size: 9pt)
    strong(it.body)
  }

  show table.cell: it => {
    // Cells of the label column are left untouched.
    if it.x >= ncol {
      return it
    }
    show regex("([a-z]+[0-9]*)+"): it => {
      show regex("\d"): it => {
        let color = if not color-indices {
          black
        } else if it.text.match(regex("[37]")) != none {
          green.darken(10%)
        } else if it.text.match(regex("[15]")) != none {
          red.darken(20%)
        } else if it.text.match(regex("[26]")) != none {
          blue.darken(10%)
        } else {
          black
        }
        set text(fill: color)
        strong(sub(it))
      }
      it
    }
    it
  }

  // check and mark duplicates
  if nrows > 1 and check-duplicates {
    elements = markduplicates(elements).map(((elt, duplicated)) => {
      set text(stroke: red + 0.7pt) if duplicated
      elt
    })
  }

  // Append the label (and row index) as one extra cell after each row.
  if row-label != none {
    elements = elements.chunks(ncol).enumerate().map(
      ((i, c)) => c + ([_#row-label;#row-enumerator[[#i]]_<row-label>], )
    ).flatten()
    extra-col += 1
  }
  let t = 0.5pt
  table(
    columns: (vector-width,) * ncol + (auto, ) * extra-col,
    align: bottom + center,
    inset: (bottom: 0.5em),
    // Draw only the outer frame around the vector columns.
    stroke: (x, y) => {
      let st = (:)
      if x == 0 { st.insert("left", t) }
      if x == ncol - 1 { st.insert("right", t) }
      if y == 0 and x < ncol { st.insert("top", t)}
      if y == nrows - 1 and x < ncol { st.insert("bottom", t) }
      st
    },
    // Shade odd rows of the vector columns; the label column stays unfilled.
    fill: (x, y) => if x >= ncol { none } else if calc.odd(y) { rgb("EAF2F5") },
    ..elements,
    // Thin separators between element pairs, a stronger one between the halves.
    table.vline(x: 2, position: start, stroke: t / 4),
    table.vline(x: 4, position: start, stroke: t / 2),
    table.vline(x: 6, position: start, stroke: t / 4),
  )
}
210+
211+
212+
= x86-64 AVX/FMA sgemm microkernel: 32-bit float
213+
214+
== Loop Iteration
215+
216+
Load data from buffers `a` and `b` into vectors `aNNNN` and `bv`, `bv_lh`.
217+
#{
  // Symbolic contents of one 8-element load from each input buffer.
  let av = load_ps("a")
  let bv = load_ps("b")

  // Duplicated even / odd elements of `a`, plus a copy of each with the two
  // element pairs inside every four-element half exchanged.
  let swap-pairs(v) = permute_ps(v, 2, 3, 0, 1)
  let a0246 = moveldup_ps(av)
  let a1357 = movehdup_ps(av)
  let a2064 = swap-pairs(a0246)
  let a3175 = swap-pairs(a1357)
  // `b` with its two four-element halves exchanged.
  let bv_lh = permute2f128_ps(bv, bv, 3, 0)

  for (label, vector) in (
    (`av`, av),
    (`a0246`, a0246),
    (`a2064`, a2064),
    (`a1357`, a1357),
    (`a3175`, a3175),
    (`bv`, bv),
    (`bv_lh`, bv_lh),
  ) {
    show-vectors(vector, name: label)
  }

  [
    #show "+=": $+#h(0em)=$
    #show "*": $times$
    ```rust
    ab[0] += a0246 * bv
    ab[1] += a2064 * bv
    ab[2] += a0246 * bv_lh
    ab[3] += a2064 * bv_lh
    ab[4] += a1357 * bv
    ab[5] += a3175 * bv
    ab[6] += a1357 * bv_lh
    ab[7] += a3175 * bv_lh
    ```
  ]

  // Emit a red warning when any element occurs more than once in `rows`.
  let warn-if-duplicated(rows) = {
    let cells = rows.flatten()
    if cells.len() != cells.dedup().len() {
      highlight(fill: red, [Duplicate entries])
    }
  }

  // Same eight products, in the same order, as the listing above: first the
  // four rows built from the even elements of `a`, then the four odd ones.
  let ab = ((a0246, a2064), (a1357, a3175)).map(((plain, swapped)) => (
    mul(plain, bv),
    mul(swapped, bv),
    mul(plain, bv_lh),
    mul(swapped, bv_lh),
  )).join()

  show-vectors(ab, name: [`ab` accumulator in loop], row-label: [ab])
  warn-if-duplicated(ab)

  pagebreak()

  [
    == Finish
    De-stripe data from accumulator into final storage order.
  ]

  // Shuffle two accumulator rows with the fixed selector (0, 1, 2, 3).
  let shuffle_ab(i, j) = shuffle_ps(ab.at(i), ab.at(j), 0, 1, 2, 3)
  let ab0044 = shuffle_ab(0, 1)
  let ab2266 = shuffle_ab(1, 0)
  let ab4400 = shuffle_ab(2, 3)
  let ab6622 = shuffle_ab(3, 2)

  let ab1155 = shuffle_ab(4, 5)
  let ab3377 = shuffle_ab(5, 4)
  let ab5511 = shuffle_ab(6, 7)
  let ab7733 = shuffle_ab(7, 6)

  for (label, vector) in (
    (`ab0044`, ab0044),
    (`ab2266`, ab2266),
    (`ab4400`, ab4400),
    (`ab6622`, ab6622),
    (`ab1155`, ab1155),
    (`ab3377`, ab3377),
    (`ab5511`, ab5511),
    (`ab7733`, ab7733),
  ) {
    show-vectors(vector, name: label)
  }

  // Combine halves of the shuffled vectors: the first four result rows use
  // controls (0, 2), the last four use controls (3, 1), each over the same
  // four vector pairs.
  let pairs = (
    (ab0044, ab4400),
    (ab1155, ab5511),
    (ab2266, ab6622),
    (ab3377, ab7733),
  )
  let abfinal = ((0, 2), (3, 1)).map(((first, second)) => {
    pairs.map(((p, q)) => permute2f128_ps(p, q, first, second))
  }).join()

  show-vectors(abfinal, name: [`ab` in order], row-label: [ab])
  warn-if-duplicated(abfinal)
}

0 commit comments

Comments
 (0)