Skip to content

Commit feb6f94

Browse files
authored
Enable solving multi-domain problems involving codim-0 submeshes (#3478)
1 parent 704b43a commit feb6f94

40 files changed

+2241
-516
lines changed

demos/saddle_point_pc/saddle_point_systems.py.rst

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ Finally, at each mesh size, we print out the number of cells in the
180180
mesh and the number of iterations the solver took to converge ::
181181

182182
#
183-
print(w.function_space().mesh().num_cells(), solver.snes.ksp.getIterationNumber())
183+
print(w.function_space().mesh().unique().num_cells(), solver.snes.ksp.getIterationNumber())
184184

185185
The resulting convergence is unimpressive:
186186

@@ -282,7 +282,7 @@ applying the action of blocks, so we can use a block matrix format. ::
282282
for n in range(8):
283283
solver, w = build_problem(n, parameters, block_matrix=True)
284284
solver.solve()
285-
print(w.function_space().mesh().num_cells(), solver.snes.ksp.getIterationNumber())
285+
print(w.function_space().mesh().unique().num_cells(), solver.snes.ksp.getIterationNumber())
286286

287287
The resulting convergence is algorithmically good, however, the larger
288288
problems still take a long time.
@@ -367,7 +367,7 @@ Let's see what happens. ::
367367
for n in range(8):
368368
solver, w = build_problem(n, parameters, block_matrix=True)
369369
solver.solve()
370-
print(w.function_space().mesh().num_cells(), solver.snes.ksp.getIterationNumber())
370+
print(w.function_space().mesh().unique().num_cells(), solver.snes.ksp.getIterationNumber())
371371

372372
This is much better, the problem takes much less time to solve and
373373
when observing the iteration counts for inverting :math:`S` we can see
@@ -422,7 +422,7 @@ and so we no longer need a flexible Krylov method. ::
422422
for n in range(8):
423423
solver, w = build_problem(n, parameters, block_matrix=True)
424424
solver.solve()
425-
print(w.function_space().mesh().num_cells(), solver.snes.ksp.getIterationNumber())
425+
print(w.function_space().mesh().unique().num_cells(), solver.snes.ksp.getIterationNumber())
426426

427427
This results in the following GMRES iteration counts
428428

@@ -487,7 +487,7 @@ variable. We can provide it as an :class:`~.AuxiliaryOperatorPC` via a python pr
487487
for n in range(8):
488488
solver, w = build_problem(n, parameters, aP=None, block_matrix=False)
489489
solver.solve()
490-
print(w.function_space().mesh().num_cells(), solver.snes.ksp.getIterationNumber())
490+
print(w.function_space().mesh().unique().num_cells(), solver.snes.ksp.getIterationNumber())
491491

492492
This actually results in slightly worse convergence than the diagonal
493493
approximation we used above.
@@ -571,7 +571,7 @@ Let's see what the iteration count looks like now. ::
571571
for n in range(8):
572572
solver, w = build_problem(n, parameters, aP=riesz, block_matrix=True)
573573
solver.solve()
574-
print(w.function_space().mesh().num_cells(), solver.snes.ksp.getIterationNumber())
574+
print(w.function_space().mesh().unique().num_cells(), solver.snes.ksp.getIterationNumber())
575575

576576
============== ==================
577577
Mesh elements GMRES iterations

firedrake/assemble.py

Lines changed: 188 additions & 55 deletions
Large diffs are not rendered by default.

firedrake/bcs.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,6 @@ def hermite_stride(bcnodes):
162162
# take intersection of facet nodes, and add it to bcnodes
163163
# i, j, k can also be strings.
164164
bcnodes1 = []
165-
if len(s) > 1 and not isinstance(self._function_space.finat_element, (finat.Lagrange, finat.GaussLobattoLegendre)):
166-
raise TypeError("Currently, edge conditions have only been tested with CG Lagrange elements")
167165
for ss in s:
168166
# intersection of facets
169167
# Edge conditions have only been tested with Lagrange elements.

firedrake/checkpointing.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,8 @@ def save_mesh(self, mesh, distribution_name=None, permutation_name=None):
566566
:kwarg distribution_name: the name under which distribution is saved; if `None`, auto-generated name will be used.
567567
:kwarg permutation_name: the name under which permutation is saved; if `None`, auto-generated name will be used.
568568
"""
569+
# TODO: Add general MeshSequence support.
570+
mesh = mesh.unique()
569571
# Handle extruded mesh
570572
tmesh = mesh.topology
571573
if mesh.extruded:
@@ -835,6 +837,8 @@ def get_timestepping_history(self, mesh, name):
835837
@PETSc.Log.EventDecorator("SaveFunctionSpace")
836838
def _save_function_space(self, V):
837839
mesh = V.mesh()
840+
# TODO: Add general MeshSequence support.
841+
mesh = mesh.unique()
838842
if isinstance(V.topological, impl.MixedFunctionSpace):
839843
V_name = self._generate_function_space_name(V)
840844
base_path = self._path_to_mixed_function_space(mesh.name, V_name)
@@ -910,10 +914,12 @@ def save_function(self, f, idx=None, name=None, timestepping_info={}):
910914
each index.
911915
"""
912916
V = f.function_space()
913-
mesh = V.mesh()
914917
if name:
915918
g = Function(V, val=f.dat, name=name)
916919
return self.save_function(g, idx=idx, timestepping_info=timestepping_info)
920+
mesh = V.mesh()
921+
# TODO: Add general MeshSequence support.
922+
mesh = mesh.unique()
917923
# -- Save function space --
918924
self._save_function_space(V)
919925
# -- Save function --
@@ -1224,6 +1230,8 @@ def _load_mesh_topology(self, tmesh_name, reorder, distribution_parameters):
12241230

12251231
@PETSc.Log.EventDecorator("LoadFunctionSpace")
12261232
def _load_function_space(self, mesh, name):
1233+
# TODO: Add general MeshSequence support.
1234+
mesh = mesh.unique()
12271235
mesh_key = self._generate_mesh_key_from_names(mesh.name,
12281236
mesh.topology._distribution_name,
12291237
mesh.topology._permutation_name)
@@ -1299,6 +1307,8 @@ def load_function(self, mesh, name, idx=None):
12991307
be loaded with idx only when it was saved with idx.
13001308
:returns: the loaded :class:`~.Function`.
13011309
"""
1310+
# TODO: Add general MeshSequence support.
1311+
mesh = mesh.unique()
13021312
tmesh = mesh.topology
13031313
if name in self._get_mixed_function_name_mixed_function_space_name_map(mesh.name):
13041314
V_name = self._get_mixed_function_name_mixed_function_space_name_map(mesh.name)[name]

firedrake/dmhooks.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343

4444
import firedrake
4545
from firedrake.petsc import PETSc
46+
from firedrake.mesh import MeshSequenceGeometry
4647

4748

4849
@PETSc.Log.EventDecorator()
@@ -53,8 +54,11 @@ def get_function_space(dm):
5354
:raises RuntimeError: if no function space was found.
5455
"""
5556
info = dm.getAttr("__fs_info__")
56-
meshref, element, indices, (name, names), boundary_sets = info
57-
mesh = meshref()
57+
meshref_tuple, element, indices, (name, names), boundary_sets = info
58+
if len(meshref_tuple) == 1:
59+
mesh = meshref_tuple[0]()
60+
else:
61+
mesh = MeshSequenceGeometry([meshref() for meshref in meshref_tuple])
5862
if mesh is None:
5963
raise RuntimeError("Somehow your mesh was collected, this should never happen")
6064
V = firedrake.FunctionSpace(mesh, element, name=name)
@@ -80,8 +84,6 @@ def set_function_space(dm, V):
8084
This stores the information necessary to make a function space given a DM.
8185
8286
"""
83-
mesh = V.mesh()
84-
8587
indices = []
8688
names = []
8789
while V.parent is not None:
@@ -92,11 +94,12 @@ def set_function_space(dm, V):
9294
assert V.index is None
9395
indices.append(V.component)
9496
V = V.parent
97+
mesh = V.mesh()
9598
if len(V) > 1:
9699
names = tuple(V_.name for V_ in V)
97100
element = V.ufl_element()
98101
boundary_sets = tuple(V_.boundary_set for V_ in V)
99-
info = (weakref.ref(mesh), element, tuple(reversed(indices)), (V.name, names), boundary_sets)
102+
info = (tuple(weakref.ref(m) for m in mesh), element, tuple(reversed(indices)), (V.name, names), boundary_sets)
100103
dm.setAttr("__fs_info__", info)
101104

102105

@@ -414,7 +417,9 @@ def coarsen(dm, comm):
414417
"""
415418
from firedrake.mg.utils import get_level
416419
V = get_function_space(dm)
417-
hierarchy, level = get_level(V.mesh())
420+
# TODO: Think harder.
421+
m, = set(m_ for m_ in V.mesh())
422+
hierarchy, level = get_level(m)
418423
if level < 1:
419424
raise RuntimeError("Cannot coarsen coarsest DM")
420425
coarsen = get_ctx_coarsener(dm)

firedrake/ensemble/ensemble_functionspace.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ class EnsembleFunctionSpaceBase:
9292
- Dual ensemble objects: :class:`EnsembleDualSpace` and :class:`~firedrake.ensemble.ensemble_function.EnsembleCofunction`.
9393
"""
9494
def __init__(self, local_spaces: Collection, ensemble: Ensemble):
95-
meshes = set(V.mesh() for V in local_spaces)
95+
meshes = set(V.mesh().unique() for V in local_spaces)
9696
nlocal_meshes = len(meshes)
9797
max_local_meshes = ensemble.ensemble_comm.allreduce(nlocal_meshes, MPI.MAX)
9898
if max_local_meshes > 1:

firedrake/function.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -593,16 +593,20 @@ def _at(self, arg, *args, **kwargs):
593593

594594
tolerance = kwargs.get('tolerance', None)
595595
mesh = self.function_space().mesh()
596+
if len(set(mesh)) == 1:
597+
mesh_unique = mesh.unique()
598+
else:
599+
raise NotImplementedError("Not implemented for general mixed meshes")
596600
if tolerance is None:
597-
tolerance = mesh.tolerance
601+
tolerance = mesh_unique.tolerance
598602
else:
599-
mesh.tolerance = tolerance
603+
mesh_unique.tolerance = tolerance
600604

601605
# Handle f._at(0.3)
602606
if not arg.shape:
603607
arg = arg.reshape(-1)
604608

605-
if mesh.variable_layers:
609+
if mesh_unique.variable_layers:
606610
raise NotImplementedError("Point evaluation not implemented for variable layers")
607611

608612
# Validate geometric dimension
@@ -778,7 +782,7 @@ def evaluate(self, function: Function) -> np.ndarray | Tuple[np.ndarray, ...]:
778782
if function.function_space().ufl_element().family() == "Real":
779783
return function.dat.data_ro
780784

781-
function_mesh = function.function_space().mesh()
785+
function_mesh = function.function_space().mesh().unique()
782786
if function_mesh is not self.mesh:
783787
raise ValueError("Function mesh must be the same Mesh object as the PointEvaluator mesh.")
784788
if coord_changed := function_mesh.coordinates.dat.dat_version != self.mesh._saved_coordinate_dat_version:

firedrake/functionspace.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
API is functional, rather than object-based, to allow for simple
55
backwards-compatibility, argument checking, and dispatch.
66
"""
7+
import itertools
78
import ufl
89
import finat.ufl
910

@@ -253,6 +254,8 @@ def MixedFunctionSpace(spaces, name=None, mesh=None):
253254
:class:`finat.ufl.mixedelement.MixedElement`, ignored otherwise.
254255
255256
"""
257+
from firedrake.mesh import MeshSequenceGeometry
258+
256259
if isinstance(spaces, finat.ufl.FiniteElementBase):
257260
# Build the spaces if we got a mixed element
258261
assert type(spaces) is finat.ufl.MixedElement and mesh is not None
@@ -267,22 +270,15 @@ def rec(eles):
267270
sub_elements.append(ele)
268271
rec(spaces.sub_elements)
269272
spaces = [FunctionSpace(mesh, element) for element in sub_elements]
270-
271-
# Check that function spaces are on the same mesh
272-
meshes = [space.mesh() for space in spaces]
273-
for i in range(1, len(meshes)):
274-
if meshes[i] is not meshes[0]:
275-
raise ValueError("All function spaces must be defined on the same mesh!")
276-
273+
# Flatten MeshSequences.
274+
meshes = list(itertools.chain(*[space.mesh() for space in spaces]))
277275
try:
278276
cls, = set(type(s) for s in spaces)
279277
except ValueError:
280278
# Neither primal nor dual
281279
# We had not implemented something in between, so let's make it primal
282280
cls = impl.WithGeometry
283281

284-
# Select mesh
285-
mesh = meshes[0]
286282
# Get topological spaces
287283
spaces = tuple(s.topological for s in flatten(spaces))
288284
# Error checking
@@ -296,10 +292,9 @@ def rec(eles):
296292
else:
297293
raise ValueError("Can't make mixed space with %s" % type(space))
298294

299-
new = impl.MixedFunctionSpace(spaces, name=name)
300-
if mesh is not mesh.topology:
301-
new = cls.create(new, mesh)
302-
return new
295+
mixed_mesh_geometry = MeshSequenceGeometry(meshes)
296+
new = impl.MixedFunctionSpace(spaces, mixed_mesh_geometry.topology, name=name)
297+
return cls.create(new, mixed_mesh_geometry)
303298

304299

305300
@PETSc.Log.EventDecorator("CreateFunctionSpace")

0 commit comments

Comments
 (0)