@@ -1519,6 +1519,48 @@ def test_prefetch_to_same_temp_var(ctx_factory):
15191519 lp .auto_test_vs_ref (ref_tunit , ctx , t_unit )
15201520
15211521
1522+ def test_sum_redn_algebraic_transforms (ctx_factory ):
1523+ from pymbolic import variables
1524+ from loopy .symbolic import Reduction
1525+
1526+ t_unit = lp .make_kernel (
1527+ "{[e,i,j,x,r]: 0<=e<N_e and 0<=i,j<35 and 0<=x,r<3}" ,
1528+ """
1529+ y[i] = sum([r,j], J[x, r, e]*D[r,i,j]*u[e,j])
1530+ """ ,
1531+ [lp .GlobalArg ("J,D,u" , dtype = np .float64 , shape = lp .auto ),
1532+ ...],
1533+ )
1534+ knl = t_unit .default_entrypoint
1535+
1536+ knl = lp .split_reduction_inward (knl , "j" )
1537+ knl = lp .hoist_invariant_multiplicative_terms_in_sum_reduction (
1538+ knl ,
1539+ reduction_inames = "j"
1540+ )
1541+ knl = lp .extract_multiplicative_terms_in_sum_reduction_as_subst (
1542+ knl ,
1543+ within = None ,
1544+ subst_name = "grad_without_jacobi_subst" ,
1545+ arguments = variables ("r i e" ),
1546+ terms_filter = lambda x : isinstance (x , Reduction )
1547+ )
1548+
1549+ transformed_t_unit = t_unit .with_kernel (knl )
1550+ transformed_t_unit = lp .precompute (
1551+ transformed_t_unit ,
1552+ "grad_without_jacobi_subst" ,
1553+ sweep_inames = ["r" , "i" ],
1554+ precompute_outer_inames = frozenset ({"e" }),
1555+ temporary_address_space = lp .AddressSpace .PRIVATE )
1556+
1557+ x1 = lp .get_op_map (t_unit , subgroup_size = 1 ).eval_and_sum ({"N_e" : 1 })
1558+ x2 = lp .get_op_map (transformed_t_unit , subgroup_size = 1 ).eval_and_sum ({"N_e" : 1 })
1559+
1560+ assert x1 == 33075
1561+ assert x2 == 7980 # i.e. this demonstrates a 4.14x reduction in flops
1562+
1563+
15221564if __name__ == "__main__" :
15231565 if len (sys .argv ) > 1 :
15241566 exec (sys .argv [1 ])
0 commit comments