@@ -1519,6 +1519,48 @@ def test_prefetch_to_same_temp_var(ctx_factory):
15191519    lp .auto_test_vs_ref (ref_tunit , ctx , t_unit )
15201520
15211521
1522+ def  test_sum_redn_algebraic_transforms (ctx_factory ):
1523+     from  pymbolic  import  variables 
1524+     from  loopy .symbolic  import  Reduction 
1525+ 
1526+     t_unit  =  lp .make_kernel (
1527+         "{[e,i,j,x,r]: 0<=e<N_e and 0<=i,j<35 and 0<=x,r<3}" ,
1528+         """ 
1529+         y[i] = sum([r,j], J[x, r, e]*D[r,i,j]*u[e,j]) 
1530+         """ ,
1531+         [lp .GlobalArg ("J,D,u" , dtype = np .float64 , shape = lp .auto ),
1532+          ...],
1533+     )
1534+     knl  =  t_unit .default_entrypoint 
1535+ 
1536+     knl  =  lp .split_reduction_inward (knl , "j" )
1537+     knl  =  lp .hoist_invariant_multiplicative_terms_in_sum_reduction (
1538+         knl ,
1539+         reduction_inames = "j" 
1540+     )
1541+     knl  =  lp .extract_multiplicative_terms_in_sum_reduction_as_subst (
1542+         knl ,
1543+         within = None ,
1544+         subst_name = "grad_without_jacobi_subst" ,
1545+         arguments = variables ("r i e" ),
1546+         terms_filter = lambda  x : isinstance (x , Reduction )
1547+     )
1548+ 
1549+     transformed_t_unit  =  t_unit .with_kernel (knl )
1550+     transformed_t_unit  =  lp .precompute (
1551+         transformed_t_unit ,
1552+         "grad_without_jacobi_subst" ,
1553+         sweep_inames = ["r" , "i" ],
1554+         precompute_outer_inames = frozenset ({"e" }),
1555+         temporary_address_space = lp .AddressSpace .PRIVATE )
1556+ 
1557+     x1  =  lp .get_op_map (t_unit , subgroup_size = 1 ).eval_and_sum ({"N_e" : 1 })
1558+     x2  =  lp .get_op_map (transformed_t_unit , subgroup_size = 1 ).eval_and_sum ({"N_e" : 1 })
1559+ 
1560+     assert  x1  ==  33075 
1561+     assert  x2  ==  7980   # i.e. this demonstrates a 4.14x reduction in flops 
1562+ 
1563+ 
15221564if  __name__  ==  "__main__" :
15231565    if  len (sys .argv ) >  1 :
15241566        exec (sys .argv [1 ])
0 commit comments