1 change: 1 addition & 0 deletions src/NeuralPDE.jl
@@ -17,6 +17,7 @@ using RuntimeGeneratedFunctions
using SciMLBase
using Statistics
using ArrayInterface
using LinearAlgebra
import Optim
using DomainSets
using Symbolics
200 changes: 200 additions & 0 deletions src/adaptive_losses.jl
@@ -257,3 +257,203 @@ function generate_adaptive_loss_function(pinnrep::PINNRepresentation,
nothing
end
end

"""
Inverse Dirichlet Adaptive Loss
```julia
function InverseDirichletAdaptiveLoss(reweight_every;
weight_change_inertia = 0.5,
pde_loss_weights = 1,
bc_loss_weights = 1,
additional_loss_weights = 1)
```
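
A way of adaptively reweighting the components of the loss function so that the standard
deviations of the back-propagated gradients of the individual loss terms are balanced.
In the notation of the code below (``\alpha`` denotes `weight_change_inertia` and the
maximum runs over all PDE and boundary-condition loss terms ``L_j``), the update is roughly

```math
\hat{\lambda}_k = \frac{\max_j \operatorname{std}(\nabla_\theta L_j)}{\operatorname{std}(\nabla_\theta L_k)},
\qquad
\lambda_k \leftarrow \alpha\,\lambda_k + (1 - \alpha)\,\hat{\lambda}_k .
```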

## References

Inverse Dirichlet weighting enables reliable training of physics informed neural networks.
Suryanarayana Maddu, Dominik Sturm, Christian L Müller, and Ivo F Sbalzarini
https://iopscience.iop.org/article/10.1088/2632-2153/ac3712/pdf
with code reference:
https://github.com/mosaic-group/inverse-dirichlet-pinn
"""
mutable struct InverseDirichletAdaptiveLoss{T <: Real} <: AbstractAdaptiveLoss
reweight_every::Int64
weight_change_inertia::T
pde_loss_weights::Vector{T}
bc_loss_weights::Vector{T}
additional_loss_weights::Vector{T}
SciMLBase.@add_kwonly function InverseDirichletAdaptiveLoss{T}(reweight_every;
weight_change_inertia = 0.5,
pde_loss_weights = 1,
bc_loss_weights = 1,
additional_loss_weights = 1) where {
T <:
Real
}
new(convert(Int64, reweight_every), convert(T, weight_change_inertia),
vectorify(pde_loss_weights, T), vectorify(bc_loss_weights, T),
vectorify(additional_loss_weights, T))
end
end
# default to Float64
SciMLBase.@add_kwonly function InverseDirichletAdaptiveLoss(reweight_every;
weight_change_inertia = 0.5,
pde_loss_weights = 1,
bc_loss_weights = 1,
additional_loss_weights = 1)
InverseDirichletAdaptiveLoss{Float64}(reweight_every;
weight_change_inertia = weight_change_inertia,
pde_loss_weights = pde_loss_weights,
bc_loss_weights = bc_loss_weights,
additional_loss_weights = additional_loss_weights)
end

function generate_adaptive_loss_function(pinnrep::PINNRepresentation,
adaloss::InverseDirichletAdaptiveLoss,
pde_loss_functions, bc_loss_functions)
weight_change_inertia = adaloss.weight_change_inertia
iteration = pinnrep.iteration
adaloss_T = eltype(adaloss.pde_loss_weights)

function run_loss_inverse_dirichlet_gradients_adaptive_loss(θ, pde_losses, bc_losses)
if iteration[1] % adaloss.reweight_every == 0
# the paper assumes a single PDE loss term; here we take the maximum of the gradient standard deviations across all terms as the reference scale gamma
pde_stds = [std(Zygote.gradient(pde_loss_function, θ)[1])
for pde_loss_function in pde_loss_functions]
pde_std_max = maximum(pde_stds)
bc_stds = [std(Zygote.gradient(bc_loss_function, θ)[1])
for bc_loss_function in bc_loss_functions]
bc_std_max = maximum(bc_stds)
gamma = max(pde_std_max, bc_std_max)

nonzero_divisor_eps = adaloss_T == Float64 ? Float64(1e-11) :
convert(adaloss_T, 1e-7)

# inverse-Dirichlet proposal: reference gradient std over each term's own gradient std,
# blended with the previous weights through the inertia term
bc_loss_weights_proposed = gamma ./ (bc_stds .+ nonzero_divisor_eps)
adaloss.bc_loss_weights .= weight_change_inertia .* adaloss.bc_loss_weights .+
(1 .- weight_change_inertia) .* bc_loss_weights_proposed

pde_loss_weights_proposed = gamma ./ (pde_stds .+ nonzero_divisor_eps)
adaloss.pde_loss_weights .= weight_change_inertia .* adaloss.pde_loss_weights .+
(1 .- weight_change_inertia) .* pde_loss_weights_proposed

logscalar(pinnrep.logger, gamma,
"adaptive_loss/gamma", iteration[1])
logvector(pinnrep.logger, pde_stds,
"adaptive_loss/pde_stds", iteration[1])
logvector(pinnrep.logger, bc_stds,
"adaptive_loss/bc_stds", iteration[1])
logvector(pinnrep.logger, adaloss.bc_loss_weights,
"adaptive_loss/bc_loss_weights", iteration[1])
logvector(pinnrep.logger, adaloss.pde_loss_weights,
"adaptive_loss/pde_loss_weights", iteration[1])
end
nothing
end
end

"""
Neural Tangent Kernel Adaptive Loss
```julia
NeuralTangentKernelAdaptiveLoss(reweight_every;
pde_max_optimiser = Flux.ADAM(1e-4),
bc_max_optimiser = Flux.ADAM(0.5),
pde_loss_weights = 1,
bc_loss_weights = 1,
additional_loss_weights = 1)
```

A way of adaptively reweighting the components of the loss function so that the PDE and
boundary-condition terms are balanced according to an estimate of the trace of their
neural tangent kernels (NTK) at the current parameters (infinite-width assumption).
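
In the implementation below, ``\operatorname{tr}(K_k)`` of each group of loss terms is
approximated by the sum of the squared parameter-gradient norms
``\lVert \nabla_\theta L_k \rVert^2`` of its (mean) losses, and the weights are set to
(a summary of the code, not a formula quoted from the paper)

```math
\lambda_{\mathrm{bc}} = \frac{\operatorname{tr}(K_{\mathrm{pde}}) + \operatorname{tr}(K_{\mathrm{bc}})}{\operatorname{tr}(K_{\mathrm{bc}})},
\qquad
\lambda_{\mathrm{pde}} = \frac{\operatorname{tr}(K_{\mathrm{pde}}) + \operatorname{tr}(K_{\mathrm{bc}})}{\operatorname{tr}(K_{\mathrm{pde}})} .
```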

## References

When and Why PINNs Fail to Train: A Neural Tangent Kernel Perspective.
Sifan Wang, Xinling Yu, Paris Perdikaris
https://arxiv.org/pdf/2007.14527.pdf
"""
mutable struct NeuralTangentKernelAdaptiveLoss{T <: Real,
PDE_OPT <: Flux.Optimise.AbstractOptimiser,
BC_OPT <: Flux.Optimise.AbstractOptimiser} <: AbstractAdaptiveLoss
reweight_every::Int64
pde_max_optimiser::PDE_OPT
bc_max_optimiser::BC_OPT
pde_loss_weights::Vector{T}
bc_loss_weights::Vector{T}
additional_loss_weights::Vector{T}
SciMLBase.@add_kwonly function NeuralTangentKernelAdaptiveLoss{T,
PDE_OPT, BC_OPT}(reweight_every;
pde_max_optimiser = Flux.ADAM(1e-4),
bc_max_optimiser = Flux.ADAM(0.5),
pde_loss_weights = 1,
bc_loss_weights = 1,
additional_loss_weights = 1) where {
T <:
Real,
PDE_OPT <:
Flux.Optimise.AbstractOptimiser,
BC_OPT <:
Flux.Optimise.AbstractOptimiser
}
new(convert(Int64, reweight_every), convert(PDE_OPT, pde_max_optimiser),
convert(BC_OPT, bc_max_optimiser),
vectorify(pde_loss_weights, T), vectorify(bc_loss_weights, T),
vectorify(additional_loss_weights, T))
end
end

SciMLBase.@add_kwonly function NeuralTangentKernelAdaptiveLoss(reweight_every;
pde_max_optimiser = Flux.ADAM(1e-4),
bc_max_optimiser = Flux.ADAM(0.5),
pde_loss_weights = 1,
bc_loss_weights = 1,
additional_loss_weights = 1)
NeuralTangentKernelAdaptiveLoss{Float64, typeof(pde_max_optimiser),
typeof(bc_max_optimiser)}(reweight_every;
pde_max_optimiser = pde_max_optimiser,
bc_max_optimiser = bc_max_optimiser,
pde_loss_weights = pde_loss_weights,
bc_loss_weights = bc_loss_weights,
additional_loss_weights = additional_loss_weights)
end

function generate_adaptive_loss_function(pinnrep::PINNRepresentation,
adaloss::NeuralTangentKernelAdaptiveLoss,
pde_loss_functions, bc_loss_functions)

pde_max_optimiser = adaloss.pde_max_optimiser
bc_max_optimiser = adaloss.bc_max_optimiser
iteration = pinnrep.iteration
adaloss_T = eltype(adaloss.pde_loss_weights)

function run_loss_ntk_adaptive_loss(θ, pde_losses, bc_losses)
if iteration[1] % adaloss.reweight_every == 0
# estimate tr(K) of each loss term via the squared norm of its parameter gradient
bc_gradients = [vec(Zygote.gradient(bc_loss_function, θ)[1])
for bc_loss_function in bc_loss_functions]
bc_trace = sum(dot(g, g) for g in bc_gradients)
pde_gradients = [vec(Zygote.gradient(pde_loss_function, θ)[1])
for pde_loss_function in pde_loss_functions]
pde_trace = sum(dot(g, g) for g in pde_gradients)

# NTK weighting: scale each group by the total trace over its own trace
adaloss.bc_loss_weights .= (bc_trace + pde_trace) ./ bc_trace
adaloss.pde_loss_weights .= (bc_trace + pde_trace) ./ pde_trace

logvector(pinnrep.logger, adaloss.pde_loss_weights,
"adaptive_loss/pde_loss_weights", iteration[1])
logvector(pinnrep.logger, adaloss.bc_loss_weights,
"adaptive_loss/bc_loss_weights",
iteration[1])
end
nothing
end
end
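
For orientation (not part of this diff): a minimal sketch of how either new loss would be used, assuming the usual `PhysicsInformedNN(chain, strategy; adaptive_loss = ...)` keyword interface; the 1D Poisson setup, network size, and optimiser settings below are illustrative only.

```julia
using NeuralPDE, Lux, ModelingToolkit, Optimization, OptimizationOptimisers
import ModelingToolkit: Interval

# illustrative 1D Poisson problem; any PDESystem works the same way
@parameters x
@variables u(..)
Dxx = Differential(x)^2
eq = Dxx(u(x)) ~ -sin(pi * x)
bcs = [u(0.0) ~ 0.0, u(1.0) ~ 0.0]
domains = [x ∈ Interval(0.0, 1.0)]
@named pde_system = PDESystem(eq, bcs, domains, [x], [u(x)])

chain = Lux.Chain(Lux.Dense(1, 16, Lux.σ), Lux.Dense(16, 1))

# either of the new (unexported) adaptive losses can be handed to the discretizer
adaptive_loss = NeuralPDE.InverseDirichletAdaptiveLoss(50; weight_change_inertia = 0.9)
# adaptive_loss = NeuralPDE.NeuralTangentKernelAdaptiveLoss(50)

discretization = PhysicsInformedNN(chain, GridTraining(0.05);
                                   adaptive_loss = adaptive_loss)
prob = discretize(pde_system, discretization)
res = Optimization.solve(prob, OptimizationOptimisers.Adam(3e-4); maxiters = 2000)
```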
14 changes: 12 additions & 2 deletions test/adaptive_loss_tests.jl
@@ -8,9 +8,14 @@ import Lux
nonadaptive_loss = NeuralPDE.NonAdaptiveLoss(pde_loss_weights = 1, bc_loss_weights = 1)
gradnormadaptive_loss = NeuralPDE.GradientScaleAdaptiveLoss(100, pde_loss_weights = 1e3,
bc_loss_weights = 1)
adaptive_loss = NeuralPDE.MiniMaxAdaptiveLoss(100; pde_loss_weights = 1,
minimaxadaptive_loss = NeuralPDE.MiniMaxAdaptiveLoss(100; pde_loss_weights = 1,
bc_loss_weights = 1)
adaptive_losses = [nonadaptive_loss, gradnormadaptive_loss, adaptive_loss]
inversedirichletadaptive_loss = NeuralPDE.InverseDirichletAdaptiveLoss(100, pde_loss_weights = 1e3,
bc_loss_weights = 1)
neuraltangentkerneladaptive_loss = NeuralPDE.NeuralTangentKernelAdaptiveLoss(100; pde_loss_weights = 1,
bc_loss_weights = 1)
adaptive_losses = [nonadaptive_loss, gradnormadaptive_loss, minimaxadaptive_loss,
inversedirichletadaptive_loss, neuraltangentkerneladaptive_loss]
maxiters = 4000
seed = 60

@@ -89,11 +94,16 @@ error_results_no_logs = map(test_2d_poisson_equation_adaptive_loss_no_logs_run_s
@show error_results_no_logs[1][:total_diff_rel]
@show error_results_no_logs[2][:total_diff_rel]
@show error_results_no_logs[3][:total_diff_rel]
@show error_results_no_logs[4][:total_diff_rel]
@show error_results_no_logs[5][:total_diff_rel]

# accuracy tests; these pass for this specific seed but might not for others
# note that this doesn't test whether the adaptive losses outperform the nonadaptive loss; that is not guaranteed and depends on the seed/architecture/hyperparameters/PDE
@test error_results_no_logs[1][:total_diff_rel] < 0.4
@test error_results_no_logs[2][:total_diff_rel] < 0.4
@test error_results_no_logs[3][:total_diff_rel] < 0.4
@test error_results_no_logs[4][:total_diff_rel] < 0.4
@test error_results_no_logs[5][:total_diff_rel] < 0.4

#plots_diffs[1][:plot]
#plots_diffs[2][:plot]