@@ -930,8 +930,8 @@ def callable():
930930 return callable
931931 else :
932932 loops = []
933-
934- if access == op2 .INC :
933+ # Initialise to zero if needed
934+ if access is op2 .INC :
935935 loops .append (tensor .zero )
936936
937937 # Arguments in the operand are allowed to be from a MixedFunctionSpace
@@ -957,7 +957,7 @@ def callable():
957957 for indices , sub_expr in expressions .items ():
958958 sub_tensor = tensor [indices [0 ]] if rank == 1 else tensor
959959 loops .extend (_interpolator (sub_tensor , sub_expr , subset , access , bcs = bcs ))
960-
960+ # Apply bcs
961961 if bcs and rank == 1 :
962962 loops .extend (partial (bc .apply , f ) for bc in bcs )
963963
@@ -1038,32 +1038,36 @@ def _interpolator(tensor, expr, subset, access, bcs=None):
10381038 parameters = {}
10391039 parameters ['scalar_type' ] = utils .ScalarType
10401040
1041- callables = ()
1041+ copyin = ()
1042+ copyout = ()
10421043
10431044 # For the matfree adjoint 1-form and the 0-form, the cellwise kernel will add multiple
10441045 # contributions from the facet DOFs of the dual argument.
10451046 # The incoming Cofunction needs to be weighted by the reciprocal of the DOF multiplicity.
10461047 needs_weight = isinstance (dual_arg , ufl .Cofunction ) and not to_element .is_dg ()
10471048 if needs_weight :
1048- # Compute the reciprocal of the DOF multiplicity
1049+ # Create a buffer for the weighted Cofunction
10491050 W = dual_arg .function_space ()
1051+ v = firedrake .Function (W )
1052+ expr = expr ._ufl_expr_reconstruct_ (operand , v = v )
1053+ copyin += (partial (dual_arg .dat .copy , v .dat ),)
1054+
1055+ # Compute the reciprocal of the DOF multiplicity
1056+ wdat = W .make_dat ()
1057+ m_ = get_interp_node_map (source_mesh , target_mesh , W )
10501058 wsize = W .finat_element .space_dimension () * W .block_size
10511059 kernel_code = f"""
10521060 void multiplicity(PetscScalar *restrict w) {{
10531061 for (PetscInt i=0; i<{ wsize } ; i++) w[i] += 1;
10541062 }}"""
1055- kernel = op2 .Kernel (kernel_code , "multiplicity" , requires_zeroed_output_arguments = False )
1056- weight = firedrake .Function (W )
1057- m_ = get_interp_node_map (source_mesh , target_mesh , W )
1058- op2 .par_loop (kernel , cell_set , weight .dat (op2 .INC , m_ ))
1059- with weight .dat .vec as w :
1063+ kernel = op2 .Kernel (kernel_code , "multiplicity" )
1064+ op2 .par_loop (kernel , cell_set , wdat (op2 .INC , m_ ))
1065+ with wdat .vec as w :
10601066 w .reciprocal ()
10611067
1062- # Create a buffer for the weighted Cofunction and a callable to apply the weight
1063- v = firedrake .Function (W )
1064- expr = expr ._ufl_expr_reconstruct_ (operand , v = v )
1065- with weight .dat .vec_ro as w , dual_arg .dat .vec_ro as x , v .dat .vec_wo as y :
1066- callables += (partial (y .pointwiseMult , x , w ),)
1068+ # Create a callable to apply the weight
1069+ with wdat .vec_ro as w , v .dat .vec as y :
1070+ copyin += (partial (y .pointwiseMult , y , w ),)
10671071
10681072 # We need to pass both the ufl element and the finat element
10691073 # because the finat elements might not have the right mapping
@@ -1079,7 +1083,7 @@ def _interpolator(tensor, expr, subset, access, bcs=None):
10791083 coefficient_numbers = kernel .coefficient_numbers
10801084 needs_external_coords = kernel .needs_external_coords
10811085 name = kernel .name
1082- kernel = op2 .Kernel (ast , name , requires_zeroed_output_arguments = True ,
1086+ kernel = op2 .Kernel (ast , name , requires_zeroed_output_arguments = ( access is not op2 . INC ) ,
10831087 flop_count = kernel .flop_count , events = (kernel .event ,))
10841088
10851089 parloop_args = [kernel , cell_set ]
@@ -1092,17 +1096,12 @@ def _interpolator(tensor, expr, subset, access, bcs=None):
10921096 output = tensor
10931097 tensor = op2 .Dat (tensor .dataset )
10941098 if access is not op2 .WRITE :
1095- copyin = (partial (output .copy , tensor ), )
1096- else :
1097- copyin = ()
1098- copyout = (partial (tensor .copy , output ), )
1099- else :
1100- copyin = ()
1101- copyout = ()
1099+ copyin += (partial (output .copy , tensor ), )
1100+ copyout += (partial (tensor .copy , output ), )
11021101 if isinstance (tensor , op2 .Global ):
11031102 parloop_args .append (tensor (access ))
11041103 elif isinstance (tensor , op2 .Dat ):
1105- V_dest = arguments [- 1 ].function_space () if isinstance ( dual_arg , ufl . Cofunction ) else V
1104+ V_dest = arguments [- 1 ].function_space ()
11061105 m_ = get_interp_node_map (source_mesh , target_mesh , V_dest )
11071106 parloop_args .append (tensor (access , m_ ))
11081107 else :
@@ -1162,11 +1161,10 @@ def _interpolator(tensor, expr, subset, access, bcs=None):
11621161 parloop_args .append (target_ref_coords .dat (op2 .READ , m_ ))
11631162
11641163 parloop = op2 .ParLoop (* parloop_args )
1165- parloop_compute_callable = parloop .compute
11661164 if isinstance (tensor , op2 .Mat ):
1167- return parloop_compute_callable , tensor .assemble
1165+ return parloop , tensor .assemble
11681166 else :
1169- return copyin + callables + ( parloop_compute_callable , ) + copyout
1167+ return copyin + ( parloop , ) + copyout
11701168
11711169
11721170def get_interp_node_map (source_mesh , target_mesh , fs ):
0 commit comments