Skip to content

Commit 8713fc3

Browse files
committed
test c vector extensions
1 parent 56c64d9 commit 8713fc3

File tree

1 file changed

+89
-0
lines changed

1 file changed

+89
-0
lines changed

test/test_target.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -663,6 +663,95 @@ def test_glibc_bessel_functions(dtype):
663663
rtol=1e-6, atol=1e-6)
664664

665665

666+
def test_c_vector_extensions():
667+
knl = lp.make_kernel(
668+
"{[i, j1, j2, j3]: 0<=i<10 and 0<=j1,j2,j3<4}",
669+
"""
670+
<> temp1[j1] = x[i, j1]
671+
<> temp2[j2] = 2*temp1[j2] + 1 {inames=i:j2}
672+
y[i, j3] = temp2[j3]
673+
""",
674+
[lp.GlobalArg("x, y", shape=lp.auto, dtype=float)],
675+
seq_dependencies=True,
676+
target=lp.CVectorExtensionsTarget())
677+
678+
knl = lp.tag_inames(knl, "j2:vec, j1:ilp, j3:ilp")
679+
knl = lp.tag_array_axes(knl, "temp1,temp2", "vec")
680+
681+
print(lp.generate_code_v2(knl).device_code())
682+
683+
684+
def test_omp_simd_tag():
685+
knl = lp.make_kernel(
686+
"{[i]: 0<=i<16}",
687+
"""
688+
y[i] = 2 * x[i]
689+
""")
690+
691+
knl = lp.add_dtypes(knl, {"x": "float64"})
692+
knl = lp.split_iname(knl, "i", 4)
693+
knl = lp.tag_inames(knl, {"i_inner": lp.OpenMPSIMDTag()})
694+
695+
code_str = lp.generate_code_v2(knl).device_code()
696+
697+
assert any(line.strip() == "#pragma omp simd"
698+
for line in code_str.split("\n"))
699+
700+
701+
def test_vec_tag_with_omp_simd_fallback():
702+
knl = lp.make_kernel(
703+
"{[i, j1, j2, j3]: 0<=i<10 and 0<=j1,j2,j3<4}",
704+
"""
705+
<> temp1[j1] = x[i, j1]
706+
<> temp2[j2] = 2*temp1[j2] + 1 {inames=i:j2}
707+
y[i, j3] = temp2[j3]
708+
""",
709+
[lp.GlobalArg("x, y", shape=lp.auto, dtype=float)],
710+
seq_dependencies=True,
711+
target=lp.ExecutableCVectorExtensionsTarget())
712+
713+
knl = lp.tag_inames(knl, {"j1": lp.VectorizeTag(lp.OpenMPSIMDTag()),
714+
"j2": lp.VectorizeTag(lp.OpenMPSIMDTag()),
715+
"j3": lp.VectorizeTag(lp.OpenMPSIMDTag())})
716+
knl = lp.tag_array_axes(knl, "temp1,temp2", "vec")
717+
718+
code_str = lp.generate_code_v2(knl).device_code()
719+
720+
assert len([line
721+
for line in code_str.split("\n")
722+
if line.strip() == "#pragma omp simd"]) == 2
723+
724+
x = np.random.rand(10, 4)
725+
_, (out,) = knl(x=x)
726+
np.testing.assert_allclose(out, 2*x+1)
727+
728+
729+
def test_vec_extensions_with_multiple_loopy_body_insns():
730+
knl = lp.make_kernel(
731+
"{[n]: 0<=n<N}",
732+
"""
733+
for n
734+
... nop {id=expr_start}
735+
<> tmp = 2.0
736+
dat0[n, 0] = tmp {id=expr_insn}
737+
... nop {id=statement0}
738+
end
739+
""",
740+
seq_dependencies=True,
741+
target=lp.ExecutableCVectorExtensionsTarget())
742+
743+
knl = lp.add_dtypes(knl, {"dat0": "float64"})
744+
knl = lp.split_iname(knl, "n", 4, slabs=(1, 1),
745+
inner_iname="n_batch")
746+
knl = lp.privatize_temporaries_with_inames(knl, "n_batch")
747+
knl = lp.tag_array_axes(knl, "tmp", "vec")
748+
knl = lp.tag_inames(knl, {
749+
"n_batch": lp.VectorizeTag(lp.OpenMPSIMDTag())})
750+
751+
_, (out,) = knl(N=100)
752+
np.testing.assert_allclose(out, 2*np.ones((100, 1)))
753+
754+
666755
if __name__ == "__main__":
667756
if len(sys.argv) > 1:
668757
exec(sys.argv[1])

0 commit comments

Comments
 (0)