Update Brusselators with Multiple Shooting #1228

docs/src/examples/pde/brusselator.md (202 changes: 121 additions, 81 deletions)
# Learning Nonlinear Reaction Dynamics in the 2D Brusselator PDE Using Universal Differential Equations and Multiple Shooting

## Introduction

The Brusselator is a mathematical model used to describe oscillating chemical reactions.

The Brusselator PDE is defined on a unit square periodic domain as follows:

```math
\frac{\partial U}{\partial t} = B + U^2V - (A+1)U + \alpha \nabla^2 U + f(x, y, t)
```

```math
\frac{\partial V}{\partial t} = AU - U^2V + \alpha \nabla^2 V
```

where $A=3.4, B=1$ and the forcing term is:

```math
f(x, y, t) =
\begin{cases}
5 & \text{if } (x - 0.3)^2 + (y - 0.6)^2 \leq 0.1^2 \text{ and } t \geq 1.1 \\
0 & \text{otherwise}
\end{cases}
```

and the Laplacian operator is:

```math
\nabla^2 = \frac{\partial^2}{\partial x^2} + \frac{\partial^2}{\partial y^2}
```

These equations are solved over the time interval:

```math
t \in [0, 11.5]
```

with the initial conditions:

```math
U(x, y, 0) = 22 \cdot \left( y(1 - y) \right)^{3/2}
```

```math
V(x, y, 0) = 27 \cdot \left( x(1 - x) \right)^{3/2}
```

and the periodic boundary conditions:

```math
U(x + 1, y, t) = U(x, y, t)
```
```math
V(x, y + 1, t) = V(x, y, t)
```

## Numerical Discretization

To numerically solve this PDE, we discretize the unit square domain using $N$ grid points along each spatial dimension. The variables $U[i,j]$ and $V[i,j]$ then denote the concentrations at the grid point $(i, j)$ at a given time $t$.

We represent the spatially discretized fields as:

```math
U[i,j] = U(i \cdot \Delta x, j \cdot \Delta y), \quad V[i,j] = V(i \cdot \Delta x, j \cdot \Delta y),
```

where $\Delta x = \Delta y = \frac{1}{N}$ for a grid of size $N \times N$. To organize the simulation state efficiently, we store both $U$ and $V$ in a single 3D array:

```math
u[i,j,1] = U[i,j], \quad u[i,j,2] = V[i,j],
```

giving us a field tensor of shape $(N, N, 2)$. This structure is flexible and extends naturally to systems with additional field variables.
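
As a small illustration (a sketch added here, not part of the tutorial code), the layout can be allocated and indexed like this:

```julia
# Illustrative only: allocate a tiny (N, N, 2) state and index U and V.
N = 4
u = zeros(Float32, N, N, 2)
u[2, 3, 1] = 1.0f0  # U at grid point (2, 3)
u[2, 3, 2] = 0.5f0  # V at grid point (2, 3)
```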


For spatial derivatives, we apply a second-order central difference scheme using a three-point stencil. The Laplacian is discretized as:

```math
[\ 1,\ -2,\ 1\ ]
```

in both the $x$ and $y$ directions, forming a tridiagonal structure along each axis; applying this 1D stencil (scaled by $\frac{1}{\Delta x^2}$ or $\frac{1}{\Delta y^2}$) along each direction and summing the contributions yields the standard 5-point stencil for the 2D Laplacian. Periodic boundary conditions are incorporated by wrapping the stencil at the domain edges, effectively connecting the boundaries. The nonlinear interaction terms are computed directly at each grid point, making the implementation straightforward and local in nature.
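
As a concrete illustration of the wrapped stencil (a minimal standalone sketch; `wrap` is a hypothetical helper playing the role of the `limit` function defined later), the Laplacian at a single grid point can be computed as:

```julia
# Periodic 5-point Laplacian at grid point (i, j) of a matrix U.
wrap(a, N) = a == 0 ? N : a == N + 1 ? 1 : a

function laplacian_at(U, i, j, dx)
    N = size(U, 1)
    ip1, im1 = wrap(i + 1, N), wrap(i - 1, N)
    jp1, jm1 = wrap(j + 1, N), wrap(j - 1, N)
    return (U[im1, j] + U[ip1, j] + U[i, jm1] + U[i, jp1] - 4 * U[i, j]) / dx^2
end
```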

## Generating Training Data

The code below sets up an `ODEProblem` that can be solved to obtain training data.

```@example bruss
using ComponentArrays, Random, Plots, OrdinaryDiffEq, Statistics
using Lux, Optimization, OptimizationOptimJL, SciMLSensitivity, Zygote, OptimizationOptimisers

# Grid and Time Setup
N_GRID = 16
XYD = range(0f0, stop = 1f0, length = N_GRID)
dx = step(XYD)
T_FINAL = 11.5f0
SAVE_AT = 0.5f0
tspan = (0.0f0, T_FINAL)
t_points = collect(range(tspan[1], stop=tspan[2], step=SAVE_AT))
A, B, alpha = 3.4f0, 1.0f0, 10.0f0

# Helper Functions
limit(a, N) = a == 0 ? N : a == N+1 ? 1 : a

brusselator_f(x, y, t) = (((x - 0.3f0)^2 + (y - 0.6f0)^2) <= 0.01f0) * (t >= 1.1f0) * 5.0f0

function init_brusselator(xyd)
    println("[Init] Creating initial condition array...")
    u0 = zeros(Float32, N_GRID, N_GRID, 2)
    for I in CartesianIndices((N_GRID, N_GRID))
        x, y = xyd[I[1]], xyd[I[2]]
        u0[I, 1] = 22f0 * (y * (1f0 - y))^(3f0 / 2f0)
        u0[I, 2] = 27f0 * (x * (1f0 - x))^(3f0 / 2f0)
    end
    println("[Init] Done.")
    return u0
end
u0 = init_brusselator(XYD)

# Ground Truth PDE
function pde_truth!(du, u, p, t)
    A, B, alpha, dx = p
    αdx = alpha / dx^2
    # 5-point Laplacian with periodic wrapping, plus reaction and forcing terms.
    # (Loop body reconstructed following the finite-difference scheme above.)
    for I in CartesianIndices((N_GRID, N_GRID))
        i, j = Tuple(I)
        x, y = XYD[i], XYD[j]
        ip1, im1 = limit(i + 1, N_GRID), limit(i - 1, N_GRID)
        jp1, jm1 = limit(j + 1, N_GRID), limit(j - 1, N_GRID)
        U, V = u[i, j, 1], u[i, j, 2]
        du[i, j, 1] = αdx * (u[im1, j, 1] + u[ip1, j, 1] + u[i, jm1, 1] + u[i, jp1, 1] - 4f0 * U) +
                      B + U^2 * V - (A + 1f0) * U + brusselator_f(x, y, t)
        du[i, j, 2] = αdx * (u[im1, j, 2] + u[ip1, j, 2] + u[i, jm1, 2] + u[i, jp1, 2] - 4f0 * V) +
                      A * U - U^2 * V
    end
end

p_tuple = (A, B, alpha, dx)
sol_truth = solve(ODEProblem(pde_truth!, u0, tspan, p_tuple), FBDF(), saveat=t_points)
u_true = Array(sol_truth)
```

## Visualizing Mean Concentration Over Time

We can now use this data to train our UDE; first, we generate time-series plots of the mean concentrations of species $U$ and $V$:
```@example bruss

# Compute average concentration at each timestep
avg_U = [mean(snapshot[:, :, 1]) for snapshot in sol_truth.u]
avg_V = [mean(snapshot[:, :, 2]) for snapshot in sol_truth.u]
# Time-series of the spatial means (first plot line reconstructed to match the plot! call below)
plot(sol_truth.t, avg_U, label="Mean U", lw=2)
plot!(sol_truth.t, avg_V, label="Mean V", lw=2, linestyle=:dash)
```

With the ground truth data generated and visualized, we are now ready to construct a Universal Differential Equation (UDE) by replacing the nonlinear term $U^2V$ with a neural network. The next section outlines how we define this hybrid model and train it to recover the reaction dynamics from data.

## Universal Differential Equation (UDE) Formulation with Multiple Shooting

In the original Brusselator model, the nonlinear reaction term $U^2V$ governs key dynamic behavior. In our UDE approach, we replace this known term with a trainable neural network $\mathcal{N}_\theta(U, V)$, where $\theta$ are the learnable parameters.

The resulting system becomes:

```math
\frac{\partial U}{\partial t} = 1 + \mathcal{N}_\theta(U, V) - 4.4U + \alpha \nabla^2 U + f(x, y, t)
```

```math
\frac{\partial V}{\partial t} = 3.4U - \mathcal{N}_\theta(U, V) + \alpha \nabla^2 V
```

Here, $\mathcal{N}_\theta(U, V)$ is trained to approximate the true interaction term $U^2V$ using simulation data. This hybrid formulation allows us to recover unknown or partially known physical processes while preserving the known structural components of the PDE.

First, we define and configure the neural network used for training:

```@example bruss

model = Lux.Chain(Dense(2 => 16, tanh), Dense(16 => 1))
rng = Random.default_rng()
ps_init, st = Lux.setup(rng, model)
ps_init = ComponentArray(ps_init)  # flatten parameters for the optimizer (assumed step, elided in the diff)

# UDE right-hand side: the neural network replaces the U^2 V reaction term.
# (Loop body reconstructed to mirror pde_truth! above.)
function pde_ude!(du, u, ps_nn, t)
    αdx = alpha / dx^2
    for I in CartesianIndices((N_GRID, N_GRID))
        i, j = Tuple(I)
        x, y = XYD[i], XYD[j]
        ip1, im1 = limit(i + 1, N_GRID), limit(i - 1, N_GRID)
        jp1, jm1 = limit(j + 1, N_GRID), limit(j - 1, N_GRID)
        U, V = u[i, j, 1], u[i, j, 2]
        nn_val = first(model([U, V], ps_nn, st))[1]
        du[i, j, 1] = αdx * (u[im1, j, 1] + u[ip1, j, 1] + u[i, jm1, 1] + u[i, jp1, 1] - 4f0 * U) +
                      B + nn_val - (A + 1f0) * U + brusselator_f(x, y, t)
        du[i, j, 2] = αdx * (u[im1, j, 2] + u[ip1, j, 2] + u[i, jm1, 2] + u[i, jp1, 2] - 4f0 * V) +
                      A * U - nn_val
    end
end
prob_ude_template = ODEProblem(pde_ude!, u0, tspan, ps_init)
```
### Multiple Shooting
Traditional single-shooting training for stiff PDEs like the Brusselator often leads to instability or suboptimal learning due to long simulation horizons. Multiple shooting mitigates this by dividing the overall time span into shorter, more manageable segments. This:

* Prevents error accumulation,
* Encourages better generalization,
* Enforces continuity between segments.

First, we have to conduct the time segmentation:
```@example bruss
segment_duration = 2.5f0 # 5 steps of SAVE_AT
n_segments = floor(Int, T_FINAL / segment_duration) # floor(11.5 / 2.5) = 4 segments, covering t in [0, 10]

# Create segments based on the duration, not a fixed number
segment_times = range(tspan[1], step=segment_duration, length=n_segments + 1)
segment_spans = [(segment_times[i], segment_times[i+1]) for i in 1:n_segments]

# Save points within each segment
segment_saves = [collect(range(t[1], stop=t[2], step=SAVE_AT)) for t in segment_spans]

To efficiently compute gradients of the loss with respect to the neural network parameters, we use an adjoint sensitivity method (`GaussAdjoint`), which performs high-accuracy quadrature-based integration of the adjoint equations. This approach enables scalable and memory-efficient training for stiff PDEs by avoiding full trajectory storage while maintaining accurate gradient estimates.
function match_time_indices(t_points, segment_saves)
    return [map(ti -> findmin(abs.(t_points .- ti))[2], segment_saves[i]) for i in 1:length(segment_saves)]
end

The loss function and initial evaluation are implemented as follows:
segment_time_indices = match_time_indices(t_points, segment_saves)
```
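
With the values above (`SAVE_AT = 0.5f0` and `segment_duration = 2.5f0`), each segment contains six save points; for example, `segment_time_indices[1]` is `[1, 2, 3, 4, 5, 6]`, corresponding to the saves at $t = 0.0, 0.5, \ldots, 2.5$.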

Then, we create an individual problem for each segment:
```@example bruss
println("[Loss] Defining loss function...")
function loss_fn(ps, _)
prob = remake(prob_ude_template, p=ps)
sol = solve(prob, FBDF(), saveat=t_points)
# Failed solve
if !SciMLBase.successful_retcode(sol)
return Inf32
end
pred = Array(sol)
lval = sum(abs2, pred .- u_true) / length(u_true)
return lval
function get_segment_prob(ps, u0_seg, seg_idx)
    remake(prob_ude_template, u0 = u0_seg, tspan = segment_spans[seg_idx], p = ps)
end
```

#### Loss Function and Optimization
To train the neural network $\mathcal{N}_\theta(U, V)$ embedded in the UDE, we implement a multiple shooting loss function that segments the full simulation into smaller time intervals and enforces temporal consistency across them.

For each segment, the loss is computed as the sum of squared errors between the predicted solution and the ground truth data at the saved time points. To enforce continuity across segments, we add a penalty (weighted by $\lambda$) on the difference between the final predicted state of one segment and the true state at the boundary it shares with the next. If any segment fails to solve (due to instability or divergence), an infinite loss is returned to discard that parameter configuration during optimization. We use `FBDF()` as the solver because of the stiff nature of the Brusselator equations; other stiff solvers such as `KenCarp47()` could also be used.

Although adjoint sensitivity methods such as `GaussAdjoint` are often used in stiff problems to reduce memory load, multiple shooting naturally mitigates this need by shortening the integration window for each segment. Hence, we rely on `AutoZygote()` for automatic differentiation in our implementation.
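
For reference, explicitly requesting an adjoint sensitivity algorithm would look roughly like the sketch below (not used in this tutorial; `prob` stands for any of the segment problems defined above):

```julia
# Hypothetical alternative: select the adjoint method explicitly instead of
# relying on the default sensitivity algorithm chosen under AutoZygote().
sol = solve(prob, FBDF(), saveat = t_points,
            sensealg = GaussAdjoint(autojacvec = ZygoteVJP()))
```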

This approach improves training robustness by constraining long-term predictions and encouraging accurate short-term learning within each segment. The final optimization is carried out using the `ADAM` algorithm over all neural network parameters.
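
In symbols (notation introduced here for clarity, not in the original text), the objective minimized below is approximately:

```math
\mathcal{L}(\theta) = \sum_{i=1}^{S} \frac{1}{|T_i|} \sum_{t \in T_i} \left\| \hat{u}_i(t;\theta) - u^{\text{true}}(t) \right\|^2
+ \lambda \sum_{i=1}^{S-1} \left\| \hat{u}_i(t_i^{\text{end}};\theta) - u^{\text{true}}(t_i^{\text{end}}) \right\|^2
```

where $\hat{u}_i$ is the UDE solution on segment $i$, $T_i$ are that segment's saved time points, $S$ is the number of segments, and $t_i^{\text{end}}$ is the boundary time segment $i$ shares with segment $i+1$.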


The loss function is defined below:
```@example bruss
println("[Training] Starting optimization...")
using OptimizationOptimisers
optf = OptimizationFunction(loss_fn, AutoZygote())
λ = 10.0f0
function loss_fn_multi(ps, _)
    total_loss = 0f0
    u0_seg = copy(u0)
    for i in 1:n_segments
        prob_i = get_segment_prob(ps, u0_seg, i)
        sol_i = solve(prob_i, FBDF(), saveat=segment_saves[i])
        # Discard this parameter configuration if the stiff solve fails.
        if !SciMLBase.successful_retcode(sol_i)
            return Inf32
        end
        pred_i = Array(sol_i)
        t_idxs = segment_time_indices[i]
        true_i = u_true[:, :, :, t_idxs]
        total_loss += sum(abs2, pred_i .- true_i) / length(true_i)
        if i < n_segments
            # Continuity penalty: the segment's final predicted state should
            # match the true state at the boundary time shared with segment i+1.
            boundary_true = u_true[:, :, :, t_idxs[end]]
            total_loss += λ * sum(abs2, pred_i[:, :, :, end] .- boundary_true) / length(boundary_true)
            # Restart the next segment from the data at the boundary.
            u0_seg = boundary_true
        end
    end
    return total_loss
end
```
Once the loss function is defined, we use the ADAM optimizer to train the neural network. The optimization problem is defined using SciML's ```Optimization.jl``` tools, and gradients are computed via automatic differentiation using ```AutoZygote()``` from ```SciMLSensitivity```:
```@example bruss
optf = OptimizationFunction(loss_fn_multi, AutoZygote())
optprob = OptimizationProblem(optf, ps_init)
loss_history = Float32[]


epoch_counter = Ref(0)
callback = (ps, l) -> begin
    epoch_counter[] += 1
    push!(loss_history, l)
    println("Epoch $(epoch_counter[]): Loss = $l")
    false
end
```

Finally, to run everything:

```@example bruss
res = solve(optprob, Optimisers.Adam(0.01), callback=callback, maxiters=10000)
```

```@example bruss
res.objective
```
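
As a quick qualitative check (an illustrative addition, not part of the original tutorial), we can compare the trained network's output against the true interaction term $U^2V$ at a few sample values:

```@example bruss
# Compare the learned interaction term to the true U^2 * V at sample points.
for (U, V) in [(1f0, 1f0), (2f0, 0.5f0), (3f0, 2f0)]
    nn_val = first(model([U, V], res.u, st))[1]
    println("U=$U, V=$V  ->  NN: $nn_val, true U^2*V: $(U^2 * V)")
end
```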

```@example bruss
println("[Plot] Final U/V comparison plots...")
center = N_GRID ÷ 2
sol_final = solve(remake(prob_ude_template, p=res.u), FBDF(), saveat=t_points, abstol=1e-6, reltol=1e-6)

pred = Array(sol_final)

p1 = plot(t_points, u_true[center,center,1,:], lw=2, label="U True")
# Overlay predictions (comparison lines reconstructed to mirror the U True line above)
plot!(p1, t_points, pred[center, center, 1, :], lw=2, ls=:dash, label="U Pred")
p2 = plot(t_points, u_true[center, center, 2, :], lw=2, label="V True")
plot!(p2, t_points, pred[center, center, 2, :], lw=2, ls=:dash, label="V Pred")
plot(p1, p2, layout=(1,2), size=(900,400))
```

## Results and Conclusion

After training the Universal Differential Equation (UDE) using the multiple shooting strategy, we compared the predicted dynamics to the ground truth for both chemical species.

The low training loss across segments demonstrates that the neural network was able to accurately capture the underlying reaction dynamics. The model effectively learned the nonlinear $U^2V$ term through a segment-wise optimization process that enforces both data fidelity and inter-segment continuity. This confirms that multiple shooting not only stabilizes training but also enhances temporal consistency in learning complex spatiotemporal PDE systems.