Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/core.yml
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,10 @@ jobs:
--branch $(python3 ./firedrake-repo/scripts/firedrake-configure --show-petsc-version) \
https://gitlab.com/petsc/petsc.git
else
git clone --depth 1 https://gitlab.com/petsc/petsc.git
git clone https://gitlab.com/petsc/petsc.git
fi
cd petsc
git checkout ksagiyam/preserve_global_point_number
python3 ../firedrake-repo/scripts/firedrake-configure \
--arch ${{ matrix.arch }} --show-petsc-configure-options | \
xargs -L1 ./configure --with-make-np=8 --download-slepc
Expand Down
42 changes: 18 additions & 24 deletions firedrake/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,6 @@ def __init__(self, filename, mode, comm=COMM_WORLD):
self.commkey = self._comm.py2f()
assert self.commkey != MPI.COMM_NULL.py2f()
self._function_spaces = {}
self._function_load_utils = {}
if mode in [PETSc.Viewer.FileMode.WRITE, PETSc.Viewer.FileMode.W, "w"]:
version = CheckpointFile.latest_version
self.set_attr_byte_string("/", "dmplex_storage_version", version)
Expand Down Expand Up @@ -1058,11 +1057,8 @@ def load_mesh(self, name=DEFAULT_MESH_NAME, reorder=None, distribution_parameter
cell = base_tmesh.ufl_cell()
element = finat.ufl.VectorElement("DP" if cell.is_simplex else "DQ", cell, 0, dim=2)
_ = self._load_function_space_topology(base_tmesh, element)
base_tmesh_key = self._generate_mesh_key_from_names(base_tmesh.name,
base_tmesh._distribution_name,
base_tmesh._permutation_name)
sd_key = self._get_shared_data_key_for_checkpointing(base_tmesh, element)
_, _, lsf = self._function_load_utils[base_tmesh_key + sd_key]
_, _, lsf = self._shared_data_cache(base_tmesh)[sd_key]
nroots, _, _ = lsf.getGraph()
layers_a = np.empty(nroots, dtype=utils.IntType)
layers_a_iset = PETSc.IS().createGeneral(layers_a, comm=self._comm)
Expand Down Expand Up @@ -1114,18 +1110,15 @@ def load_mesh(self, name=DEFAULT_MESH_NAME, reorder=None, distribution_parameter
coordinates = self._load_function_topology(tmesh, coord_element, coord_name)
mesh = make_mesh_from_coordinates(coordinates, name)
# Load plex coordinates for a complete representation of plex.
tmesh.topology_dm.coordinatesLoad(self.viewer, tmesh.sfXC)
tmesh.topology_dm.coordinatesLoad(self.viewer)
# Load cell_orientations for immersed meshes.
path = self._path_to_mesh_immersed(tmesh.name, name)
if path in self.h5pyfile:
cell = tmesh.ufl_cell()
element = finat.ufl.FiniteElement("DP" if cell.is_simplex else "DQ", cell, 0)
cell_orientations_tV = self._load_function_space_topology(tmesh, element)
tmesh_key = self._generate_mesh_key_from_names(tmesh.name,
tmesh._distribution_name,
tmesh._permutation_name)
sd_key = self._get_shared_data_key_for_checkpointing(tmesh, element)
_, _, lsf = self._function_load_utils[tmesh_key + sd_key]
_, _, lsf = self._shared_data_cache(tmesh)[sd_key]
nroots, _, _ = lsf.getGraph()
cell_orientations_a = np.empty(nroots, dtype=utils.IntType)
cell_orientations_a_iset = PETSc.IS().createGeneral(cell_orientations_a, comm=self._comm)
Expand Down Expand Up @@ -1192,9 +1185,9 @@ def _load_mesh_topology(self, tmesh_name, reorder, distribution_parameters):
format = ViewerHDF5.Format.HDF5_PETSC
self.viewer.pushFormat(format=format)
plex.distributionSetName(distribution_name)
sfXB = plex.topologyLoad(self.viewer)
plex.topologyLoad(self.viewer)
plex.distributionSetName(None)
plex.labelsLoad(self.viewer, sfXB)
plex.labelsLoad(self.viewer)
self.viewer.popFormat()
# These labels are distribution dependent.
# We should be able to save/load labels selectively.
Expand Down Expand Up @@ -1223,7 +1216,7 @@ def _load_mesh_topology(self, tmesh_name, reorder, distribution_parameters):
# -- Construct Mesh (Topology) --
# Use public API so pass user comm (self.comm)
tmesh = MeshTopology(plex, name=plex.getName(), reorder=reorder,
distribution_parameters=distribution_parameters, sfXB=sfXB, perm_is=perm_is,
distribution_parameters=distribution_parameters, perm_is=perm_is,
distribution_name=distribution_name, permutation_name=permutation_name,
comm=self.comm)
return tmesh
Expand Down Expand Up @@ -1270,11 +1263,8 @@ def _load_function_space(self, mesh, name):
def _load_function_space_topology(self, tmesh, element):
if element.family() == "Real":
return impl.RealFunctionSpace(tmesh, element, "unused_name")
tmesh_key = self._generate_mesh_key_from_names(tmesh.name,
tmesh._distribution_name,
tmesh._permutation_name)
sd_key = self._get_shared_data_key_for_checkpointing(tmesh, element)
if tmesh_key + sd_key not in self._function_load_utils:
if sd_key not in self._shared_data_cache(tmesh):
topology_dm = tmesh.topology_dm
dm = PETSc.DMShell().create(comm=tmesh._comm)
dm.setName(self._get_dm_name_for_checkpointing(tmesh, element))
Expand All @@ -1283,9 +1273,8 @@ def _load_function_space_topology(self, tmesh, element):
section.setPermutation(tmesh._dm_renumbering)
dm.setSection(section)
base_tmesh = tmesh._base_mesh if isinstance(tmesh, ExtrudedMeshTopology) else tmesh
sfXC = base_tmesh.sfXC
topology_dm.setName(tmesh.name)
gsf, lsf = topology_dm.sectionLoad(self.viewer, dm, sfXC)
gsf, lsf = topology_dm.sectionLoad(self.viewer, dm)
topology_dm.setName(base_tmesh.name)
nodes_per_entity, real_tensorproduct, block_size = sd_key
# Don't cache if the section has been expanded by block_size
Expand All @@ -1294,7 +1283,7 @@ def _load_function_space_topology(self, tmesh, element):
if dm.getSection() is not cached_section:
# The same section has already been cached.
dm.setSection(cached_section)
self._function_load_utils[tmesh_key + sd_key] = (dm, gsf, lsf)
self._shared_data_cache(tmesh)[sd_key] = (dm, gsf, lsf)
return impl.FunctionSpace(tmesh, element)

@PETSc.Log.EventDecorator("LoadFunction")
Expand Down Expand Up @@ -1380,10 +1369,7 @@ def _load_function_topology(self, tmesh, element, tf_name, idx=None):
with tf.dat.vec_wo as vec:
vec.setName(tf_name)
sd_key = self._get_shared_data_key_for_checkpointing(tmesh, element)
tmesh_key = self._generate_mesh_key_from_names(tmesh.name,
tmesh._distribution_name,
tmesh._permutation_name)
dm, sf, _ = self._function_load_utils[tmesh_key + sd_key]
dm, sf, _ = self._shared_data_cache(tmesh)[sd_key]
base_tmesh_name = topology_dm.getName()
topology_dm.setName(tmesh.name)
topology_dm.globalVectorLoad(self.viewer, dm, sf, vec)
Expand Down Expand Up @@ -1461,6 +1447,14 @@ def _get_dm_name_for_checkpointing(self, tmesh, ufl_element):
sd_key = self._get_shared_data_key_for_checkpointing(tmesh, ufl_element)
return self._generate_dm_name(*sd_key)

def _shared_data_cache(self, tmesh):
# Cache gsf/lsf that push forward the on-disk DoF vector to the in-memory global/local vectors.
# Cache on mesh, not on self, so that they can be used across multiple CheckpointFile instances
# (at the cost of longer life).
if not hasattr(tmesh, "_shared_data_cache"):
raise RuntimeError(f"_shared_data_cache not on {tmesh}")
return tmesh._shared_data_cache["checkpointfile_" + self.filename]

def _path_to_topologies(self):
return "topologies"

Expand Down
32 changes: 6 additions & 26 deletions firedrake/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ class AbstractMeshTopology(object, metaclass=abc.ABCMeta):
"""A representation of an abstract mesh topology without a concrete
PETSc DM implementation"""

def __init__(self, topology_dm, name, reorder, sfXB, perm_is, distribution_name, permutation_name, comm, submesh_parent=None):
def __init__(self, topology_dm, name, reorder, perm_is, distribution_name, permutation_name, comm, submesh_parent=None):
"""Initialise a mesh topology.

Parameters
Expand All @@ -501,11 +501,6 @@ def __init__(self, topology_dm, name, reorder, sfXB, perm_is, distribution_name,
Name of the mesh topology.
reorder : bool
Whether to reorder the mesh entities.
sfXB : PETSc.PetscSF
`PETSc.SF` that pushes forward the global point number
slab ``[0, NX)`` to input (naive) plex (only significant when
the mesh topology is loaded from file and only passed from inside
`~.CheckpointFile`).
perm_is : PETSc.IS
`PETSc.IS` that is used as ``_dm_renumbering``; only
makes sense if we know the exact parallel distribution of ``plex``
Expand All @@ -526,10 +521,6 @@ def __init__(self, topology_dm, name, reorder, sfXB, perm_is, distribution_name,
topology_dm.setFromOptions()
self.topology_dm = topology_dm
r"The PETSc DM representation of the mesh topology."
self.sfBC = None
r"The PETSc SF that pushes the input (naive) plex to current (good) plex."
self.sfXB = sfXB
r"The PETSc SF that pushes the global point number slab [0, NX) to input (naive) plex."
self.submesh_parent = submesh_parent
# User comm
self.user_comm = comm
Expand All @@ -540,8 +531,6 @@ def __init__(self, topology_dm, name, reorder, sfXB, perm_is, distribution_name,
self._grown_halos = False
if self.comm.size > 1:
self._add_overlap()
if self.sfXB is not None:
self.sfXC = sfXB.compose(self.sfBC) if self.sfBC else self.sfXB
dmcommon.label_facets(self.topology_dm)
dmcommon.complete_facet_labels(self.topology_dm)
# TODO: Allow users to set distribution name if they want to save
Expand Down Expand Up @@ -1067,7 +1056,6 @@ def __init__(
name,
reorder,
distribution_parameters,
sfXB=None,
perm_is=None,
distribution_name=None,
permutation_name=None,
Expand All @@ -1086,11 +1074,6 @@ def __init__(
Whether to reorder the mesh entities.
distribution_parameters : dict
Options controlling mesh distribution; see `Mesh` for details.
sfXB : PETSc.PetscSF
`PETSc.SF` that pushes forward the global point number
slab ``[0, NX)`` to input (naive) plex (only significant when
the mesh topology is loaded from file and only passed from inside
`~.CheckpointFile`).
perm_is : PETSc.IS
`PETSc.IS` that is used as ``_dm_renumbering``; only
makes sense if we know the exact parallel distribution of ``plex``
Expand Down Expand Up @@ -1121,7 +1104,7 @@ def __init__(
# Disable auto distribution and reordering before setFromOptions is called.
plex.distributeSetDefault(False)
plex.reorderSetDefault(PETSc.DMPlex.ReorderDefaultFlag.FALSE)
super().__init__(plex, name, reorder, sfXB, perm_is, distribution_name, permutation_name, comm, submesh_parent=submesh_parent)
super().__init__(plex, name, reorder, perm_is, distribution_name, permutation_name, comm, submesh_parent=submesh_parent)

def _distribute(self):
# Distribute/redistribute the dm to all ranks
Expand All @@ -1132,9 +1115,8 @@ def _distribute(self):
# refine this mesh in parallel. Later, when we actually use
# it, we grow the halo.
original_name = plex.getName()
sfBC = plex.distribute(overlap=0)
_ = plex.distribute(overlap=0)
plex.setName(original_name)
self.sfBC = sfBC
# plex carries a new dm after distribute, which
# does not inherit partitioner from the old dm.
# It probably makes sense as chaco does not work
Expand All @@ -1150,17 +1132,15 @@ def _add_overlap(self):
elif overlap_type in [DistributedMeshOverlapType.FACET, DistributedMeshOverlapType.RIDGE]:
dmcommon.set_adjacency_callback(self.topology_dm, overlap_type)
original_name = self.topology_dm.getName()
sfBC = self.topology_dm.distributeOverlap(overlap)
_ = self.topology_dm.distributeOverlap(overlap)
self.topology_dm.setName(original_name)
self.sfBC = self.sfBC.compose(sfBC) if self.sfBC else sfBC
dmcommon.clear_adjacency_callback(self.topology_dm)
self._grown_halos = True
elif overlap_type == DistributedMeshOverlapType.VERTEX:
# Default is FEM (vertex star) adjacency.
original_name = self.topology_dm.getName()
sfBC = self.topology_dm.distributeOverlap(overlap)
_ = self.topology_dm.distributeOverlap(overlap)
self.topology_dm.setName(original_name)
self.sfBC = self.sfBC.compose(sfBC) if self.sfBC else sfBC
self._grown_halos = True
else:
raise ValueError("Unknown overlap type %r" % overlap_type)
Expand Down Expand Up @@ -2025,7 +2005,7 @@ def __init__(self, swarm, parentmesh, name, reorder, input_ordering_swarm=None,
"overlap_type": (DistributedMeshOverlapType.NONE, 0)}
self.input_ordering_swarm = input_ordering_swarm
self._parent_mesh = parentmesh
super().__init__(swarm, name, reorder, None, perm_is, distribution_name, permutation_name, parentmesh.comm)
super().__init__(swarm, name, reorder, perm_is, distribution_name, permutation_name, parentmesh.comm)

def _distribute(self):
pass
Expand Down
Loading