@@ -787,7 +787,7 @@ def exec_compute_potential_insn_fmm(self, actx: PyOpenCLArrayContext,
787787 # Execute global QBX.
788788 timing_data : Dict [str , Any ] = {}
789789 all_potentials_on_every_target = drive_dfmm (
790- self . comm , flat_strengths , wrangler , timing_data )
790+ flat_strengths , wrangler , timing_data )
791791
792792 if self .comm .Get_rank () == 0 :
793793 assert global_geo_data_device is not None
@@ -816,11 +816,9 @@ def exec_compute_potential_insn_fmm(self, actx: PyOpenCLArrayContext,
816816 return results , timing_data
817817
818818
819- def drive_dfmm (comm , src_weight_vecs , wrangler , timing_data = None ):
819+ def drive_dfmm (src_weight_vecs , wrangler , timing_data = None ):
820820 # TODO: Integrate the distributed functionality with `qbx.fmm.drive_fmm`,
821821 # similar to that in `boxtree`.
822-
823- current_rank = comm .Get_rank ()
824822 local_traversal = wrangler .traversal
825823
826824 # {{{ Distribute source weights
@@ -993,30 +991,8 @@ def drive_dfmm(comm, src_weight_vecs, wrangler, timing_data=None):
993991 non_qbx_potentials = wrangler .gather_non_qbx_potentials (non_qbx_potentials )
994992 qbx_potentials = wrangler .gather_qbx_potentials (qbx_potentials )
995993
996- if current_rank != 0 : # worker process
997- result = None
998-
999- else : # master process
1000-
1001- all_potentials_in_tree_order = wrangler .full_output_zeros (template_ary )
1002-
1003- nqbtl = wrangler .global_geo_data .non_qbx_box_target_lists
1004-
1005- for ap_i , nqp_i in zip (
1006- all_potentials_in_tree_order , non_qbx_potentials ):
1007- ap_i [nqbtl .unfiltered_from_filtered_target_indices ] = nqp_i
1008-
1009- all_potentials_in_tree_order += qbx_potentials
1010-
1011- def reorder_and_finalize_potentials (x ):
1012- # "finalize" gives host FMMs (like FMMlib) a chance to turn the
1013- # potential back into a CL array.
1014- return wrangler .finalize_potentials (
1015- x [wrangler .global_traversal .tree .sorted_target_ids ], template_ary )
1016-
1017- from pytools .obj_array import with_object_array_or_scalar
1018- result = with_object_array_or_scalar (
1019- reorder_and_finalize_potentials , all_potentials_in_tree_order )
994+ result = wrangler .reorder_and_finalize_potentials (
995+ non_qbx_potentials , qbx_potentials , template_ary )
1020996
1021997 if timing_data is not None :
1022998 timing_data .update (recorder .summarize ())
0 commit comments