Skip to content

Commit f95d4c7

Browse files
committed
Refactor reorder_and_finalize_potentials into wranglers
1 parent c85f168 commit f95d4c7

File tree

3 files changed

+55
-47
lines changed

3 files changed

+55
-47
lines changed

pytential/qbx/distributed.py

Lines changed: 4 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -787,7 +787,7 @@ def exec_compute_potential_insn_fmm(self, actx: PyOpenCLArrayContext,
787787
# Execute global QBX.
788788
timing_data: Dict[str, Any] = {}
789789
all_potentials_on_every_target = drive_dfmm(
790-
self.comm, flat_strengths, wrangler, timing_data)
790+
flat_strengths, wrangler, timing_data)
791791

792792
if self.comm.Get_rank() == 0:
793793
assert global_geo_data_device is not None
@@ -816,11 +816,9 @@ def exec_compute_potential_insn_fmm(self, actx: PyOpenCLArrayContext,
816816
return results, timing_data
817817

818818

819-
def drive_dfmm(comm, src_weight_vecs, wrangler, timing_data=None):
819+
def drive_dfmm(src_weight_vecs, wrangler, timing_data=None):
820820
# TODO: Integrate the distributed functionality with `qbx.fmm.drive_fmm`,
821821
# similar to that in `boxtree`.
822-
823-
current_rank = comm.Get_rank()
824822
local_traversal = wrangler.traversal
825823

826824
# {{{ Distribute source weights
@@ -993,30 +991,8 @@ def drive_dfmm(comm, src_weight_vecs, wrangler, timing_data=None):
993991
non_qbx_potentials = wrangler.gather_non_qbx_potentials(non_qbx_potentials)
994992
qbx_potentials = wrangler.gather_qbx_potentials(qbx_potentials)
995993

996-
if current_rank != 0: # worker process
997-
result = None
998-
999-
else: # master process
1000-
1001-
all_potentials_in_tree_order = wrangler.full_output_zeros(template_ary)
1002-
1003-
nqbtl = wrangler.global_geo_data.non_qbx_box_target_lists
1004-
1005-
for ap_i, nqp_i in zip(
1006-
all_potentials_in_tree_order, non_qbx_potentials):
1007-
ap_i[nqbtl.unfiltered_from_filtered_target_indices] = nqp_i
1008-
1009-
all_potentials_in_tree_order += qbx_potentials
1010-
1011-
def reorder_and_finalize_potentials(x):
1012-
# "finalize" gives host FMMs (like FMMlib) a chance to turn the
1013-
# potential back into a CL array.
1014-
return wrangler.finalize_potentials(
1015-
x[wrangler.global_traversal.tree.sorted_target_ids], template_ary)
1016-
1017-
from pytools.obj_array import with_object_array_or_scalar
1018-
result = with_object_array_or_scalar(
1019-
reorder_and_finalize_potentials, all_potentials_in_tree_order)
994+
result = wrangler.reorder_and_finalize_potentials(
995+
non_qbx_potentials, qbx_potentials, template_ary)
1020996

1021997
if timing_data is not None:
1022998
timing_data.update(recorder.summarize())

pytential/qbx/fmm.py

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,28 @@ def make_container():
400400

401401
# {{{ FMM top-level
402402

403+
def _reorder_and_finalize_potentials(
404+
wrangler, non_qbx_potentials, qbx_potentials, template_ary):
405+
nqbtl = wrangler.geo_data.non_qbx_box_target_lists()
406+
407+
all_potentials_in_tree_order = wrangler.full_output_zeros(template_ary)
408+
409+
for ap_i, nqp_i in zip(all_potentials_in_tree_order, non_qbx_potentials):
410+
ap_i[nqbtl.unfiltered_from_filtered_target_indices] = nqp_i
411+
412+
all_potentials_in_tree_order += qbx_potentials
413+
414+
def reorder_and_finalize_potentials(x):
415+
# "finalize" gives host FMMs (like FMMlib) a chance to turn the
416+
# potential back into a CL array.
417+
return wrangler.finalize_potentials(x[
418+
wrangler.geo_data.traversal().tree.sorted_target_ids], template_ary)
419+
420+
from pytools.obj_array import obj_array_vectorize
421+
return obj_array_vectorize(
422+
reorder_and_finalize_potentials, all_potentials_in_tree_order)
423+
424+
403425
def drive_fmm(expansion_wrangler, src_weight_vecs, timing_data=None,
404426
traversal=None):
405427
"""Top-level driver routine for the QBX fast multipole calculation.
@@ -423,8 +445,6 @@ def drive_fmm(expansion_wrangler, src_weight_vecs, timing_data=None,
423445
if traversal is None:
424446
traversal = geo_data.traversal()
425447

426-
tree = traversal.tree
427-
428448
template_ary = src_weight_vecs[0]
429449

430450
recorder = TimingRecorder()
@@ -585,23 +605,8 @@ def drive_fmm(expansion_wrangler, src_weight_vecs, timing_data=None,
585605

586606
# {{{ reorder potentials
587607

588-
nqbtl = geo_data.non_qbx_box_target_lists()
589-
590-
all_potentials_in_tree_order = wrangler.full_output_zeros(template_ary)
591-
592-
for ap_i, nqp_i in zip(all_potentials_in_tree_order, non_qbx_potentials):
593-
ap_i[nqbtl.unfiltered_from_filtered_target_indices] = nqp_i
594-
595-
all_potentials_in_tree_order += qbx_potentials
596-
597-
def reorder_and_finalize_potentials(x):
598-
# "finalize" gives host FMMs (like FMMlib) a chance to turn the
599-
# potential back into a CL array.
600-
return wrangler.finalize_potentials(x[tree.sorted_target_ids], template_ary)
601-
602-
from pytools.obj_array import obj_array_vectorize
603-
result = obj_array_vectorize(
604-
reorder_and_finalize_potentials, all_potentials_in_tree_order)
608+
result = _reorder_and_finalize_potentials(
609+
wrangler, non_qbx_potentials, qbx_potentials, template_ary)
605610

606611
# }}}
607612

pytential/qbx/fmmlib.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -752,4 +752,31 @@ def gather_qbx_potentials(self, qbx_potentials):
752752
ntargets, qbx_potentials,
753753
self.geo_data.qbx_target_mask, self.MPITags["qbx_potentials"])
754754

755+
def reorder_and_finalize_potentials(
756+
self, non_qbx_potentials, qbx_potentials, template_ary):
757+
mpi_rank = self.comm.Get_rank()
758+
759+
if mpi_rank == 0:
760+
all_potentials_in_tree_order = self.full_output_zeros(template_ary)
761+
762+
nqbtl = self.global_geo_data.non_qbx_box_target_lists
763+
764+
for ap_i, nqp_i in zip(
765+
all_potentials_in_tree_order, non_qbx_potentials):
766+
ap_i[nqbtl.unfiltered_from_filtered_target_indices] = nqp_i
767+
768+
all_potentials_in_tree_order += qbx_potentials
769+
770+
def _reorder_and_finalize_potentials(x):
771+
# "finalize" gives host FMMs (like FMMlib) a chance to turn the
772+
# potential back into a CL array.
773+
return self.finalize_potentials(
774+
x[self.global_traversal.tree.sorted_target_ids], template_ary)
775+
776+
from pytools.obj_array import with_object_array_or_scalar
777+
return with_object_array_or_scalar(
778+
_reorder_and_finalize_potentials, all_potentials_in_tree_order)
779+
else:
780+
return None
781+
755782
# vim: foldmethod=marker

0 commit comments

Comments
 (0)