Skip to content

Commit cf01c4a

Browse files
authored
Merge pull request #96 from sintefmath/dev
State gradients and minor fixes
2 parents e354cc8 + 96aebc6 commit cf01c4a

File tree

16 files changed

+207
-30
lines changed

16 files changed

+207
-30
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Jutul"
22
uuid = "2b460a1a-8a2b-45b2-b125-b5c536396eb9"
33
authors = ["Olav Møyner <[email protected]>"]
4-
version = "0.2.37"
4+
version = "0.2.38"
55

66
[deps]
77
AlgebraicMultigrid = "2169fc97-5a83-5252-b627-83903c6c433c"

ext/JutulMakieExt/interactive_3d.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ function plot_interactive_impl(grid, states;
6868
)
6969
has_primitives = !isnothing(primitives)
7070
active_filters = []
71-
if states isa AbstractDict
71+
if states isa AbstractDict || states isa DataDomain
7272
states = [states]
7373
end
7474
if states isa AbstractVecOrMat && eltype(states)<:AbstractFloat

ext/JutulMakieExt/mesh_plots.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,11 @@ function Jutul.plot_mesh_impl!(ax, m;
6666
tri = tri[keep, :]
6767
tri, pts = remove_unused_points(tri, pts)
6868
end
69-
f = mesh!(ax, pts, tri; color = color, backlight = 1, kwarg...)
69+
if length(pts) > 0
70+
f = mesh!(ax, pts, tri; color = color, backlight = 1, kwarg...)
71+
else
72+
f = nothing
73+
end
7074
return f
7175
end
7276

src/ad/gradients.jl

Lines changed: 85 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ function solve_adjoint_sensitivities(model, states, reports_or_timesteps, G;
6363
jutul_message("Adjoints", "Storing sensitivities.", color = :blue)
6464
end
6565
out = store_sensitivities(parameter_model, ∇G, storage.parameter_map)
66+
s0_map = storage.state0_map
67+
if !ismissing(s0_map)
68+
store_sensitivities!(out, storage.backward.model, ∇G, s0_map)
69+
end
6670
end
6771
if extra_output
6872
return (out, storage)
@@ -107,24 +111,27 @@ function setup_adjoint_storage(model;
107111
parameters = setup_parameters(model),
108112
n_objective = nothing,
109113
targets = parameter_targets(model),
114+
include_state0 = false,
110115
use_sparsity = true,
111116
linear_solver = select_linear_solver(model, mode = :adjoint, rtol = 1e-6),
112117
param_obj = true,
113118
info_level = 0,
114119
kwarg...
115120
)
116-
# Set up the generic adjoint storage
117-
storage = setup_adjoint_storage_base(
118-
model, state0, parameters,
119-
use_sparsity = use_sparsity,
120-
linear_solver = linear_solver,
121-
n_objective = n_objective,
122-
info_level = info_level
123-
)
124121
# Create parameter model for ∂Fₙ / ∂p
125122
parameter_model = adjoint_parameter_model(model, targets)
123+
n_prm = number_of_degrees_of_freedom(parameter_model)
126124
# Note that primary is here because the target parameters are now the primaries for the parameter_model
127125
parameter_map, = variable_mapper(parameter_model, :primary, targets = targets; kwarg...)
126+
if include_state0
127+
state0_map, = variable_mapper(model, :primary)
128+
n_state0 = number_of_degrees_of_freedom(model)
129+
state0_vec = zeros(n_state0)
130+
else
131+
state0_map = missing
132+
state0_vec = missing
133+
n_state0 = 0
134+
end
128135
# Transfer over parameters and state0 variables since many parameters are now variables
129136
state0_p = swap_variables(state0, parameters, parameter_model, variables = true)
130137
parameters_p = swap_variables(state0, parameters, parameter_model, variables = false)
@@ -137,11 +144,21 @@ function setup_adjoint_storage(model;
137144
dobj_dparam = nothing
138145
param_buf = nothing
139146
end
147+
# Set up the generic adjoint storage
148+
storage = setup_adjoint_storage_base(
149+
model, state0, parameters,
150+
use_sparsity = use_sparsity,
151+
linear_solver = linear_solver,
152+
n_objective = n_objective,
153+
info_level = info_level,
154+
)
140155
storage[:dparam] = dobj_dparam
141156
storage[:param_buf] = param_buf
142157
storage[:parameter] = parameter_sim
143158
storage[:parameter_map] = parameter_map
144-
storage[:n] = number_of_degrees_of_freedom(parameter_model)
159+
storage[:state0_map] = state0_map
160+
storage[:dstate0] = state0_vec
161+
storage[:n] = n_prm
145162

146163
return storage
147164
end
@@ -151,7 +168,7 @@ function setup_adjoint_storage_base(model, state0, parameters;
151168
linear_solver = select_linear_solver(model, mode = :adjoint, rtol = 1e-8),
152169
n_objective = nothing,
153170
info_level = 0
154-
)
171+
)
155172
primary_model = adjoint_model_copy(model)
156173
# Standard model for: ∂Fₙᵀ / ∂xₙ
157174
forward_sim = Simulator(primary_model, state0 = deepcopy(state0), parameters = deepcopy(parameters), mode = :forward, extra_timing = nothing)
@@ -236,6 +253,8 @@ function solve_adjoint_sensitivities!(∇G, storage, states, state0, timesteps,
236253
end
237254
rescale_sensitivities!(∇G, storage.parameter.model, storage.parameter_map)
238255
@assert all(isfinite, ∇G)
256+
# Finally deal with initial state gradients
257+
update_state0_sensitivities!(storage)
239258
return ∇G
240259
end
241260

@@ -720,6 +739,28 @@ function store_sensitivities!(out, model, variables, result, prm_map, ::Equation
720739
return out
721740
end
722741

742+
function store_sensitivities!(out, model, variables, result, prm_map, ::Union{EquationMajorLayout, BlockMajorLayout})
743+
scalar_valued_objective = result isa AbstractVector
744+
@assert scalar_valued_objective "Only supported for scalar objective"
745+
746+
us = get_primary_variable_ordered_entities(model)
747+
@assert length(us) == 1 "This function is not implemented for more than one entity type for primary variables"
748+
u = only(us)
749+
bz = degrees_of_freedom_per_entity(model, u)
750+
ne = count_active_entities(model.domain, u)
751+
752+
offset = 1
753+
for (k, var) in pairs(variables)
754+
m = degrees_of_freedom_per_entity(model, var)
755+
var::ScalarVariable
756+
pos = offset:bz:(bz*(ne-1)+offset)
757+
@assert length(pos) == ne "$(length(pos))"
758+
out[k] = result[pos]
759+
offset += 1
760+
end
761+
return out
762+
end
763+
723764
function extract_sensitivity_subset(r, var, n, m, offset)
724765
if var isa ScalarVariable
725766
v = r
@@ -845,3 +886,37 @@ function adjoint_transfer_canonical_order_inner!(λ, dx, model, ::BlockMajorLayo
845886
end
846887
end
847888
end
889+
890+
function update_state0_sensitivities!(storage)
891+
state0_map = storage.state0_map
892+
if !ismissing(state0_map)
893+
sim = storage.backward
894+
model = sim.model
895+
if model isa MultiModel
896+
for (k, v) in pairs(model.models)
897+
@assert matrix_layout(v.context) isa EquationMajorLayout
898+
end
899+
else
900+
@assert matrix_layout(model.context) isa EquationMajorLayout
901+
end
902+
# Assume that this gets called at the end when everything has been set
903+
# up in terms of the simulators
904+
λ = storage.lagrange
905+
∇x = storage.dstate0
906+
@. ∇x = 0.0
907+
# order = collect(eachindex(λ))
908+
# renum = similar(order)
909+
# TODO: Finish this part and remove the assertions above
910+
# Get order to put values into canonical order
911+
# adjoint_transfer_canonical_order!(renum, order, model)
912+
# λ_renum = similar(λ)
913+
# @. λ_renum[renum] = λ
914+
# model_prm = storage.parameter.model
915+
lsys_b = sim.storage.LinearizedSystem
916+
op_b = linear_operator(lsys_b, skip_red = true)
917+
# tmp = zeros(size(∇x))
918+
sens_add_mult!(∇x, op_b, λ)
919+
# adjoint_transfer_canonical_order!(∇x, tmp, model)
920+
rescale_sensitivities!(∇x, sim.model, storage.state0_map)
921+
end
922+
end

src/core_types/core_types.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -884,6 +884,15 @@ function Base.getindex(case::JutulCase, ix::Int)
884884
return case[ix:ix]
885885
end
886886

887+
function Base.lastindex(case::JutulCase)
888+
return length(case.dt)
889+
end
890+
891+
function Base.lastindex(case::JutulCase, d::Int)
892+
d == 1 || throw(ArgumentError("JutulCase is 1D."))
893+
return Base.lastindex(case)
894+
end
895+
887896
function Base.getindex(case::JutulCase, ix)
888897
(; model, dt, forces, state0, parameters, input_data) = case
889898
f = deepcopy(forces)
@@ -1159,9 +1168,9 @@ function Base.show(io::IO, t::MIME"text/plain", options::MeshEntityTags{T}) wher
11591168
kv = "<no tags>"
11601169
else
11611170
s = map(x -> "$x $(keys(v[x]))", collect(kv))
1162-
kv = join(s, ",")
1171+
kv = join(s, ",\n\t")
11631172
end
1164-
println(io, " $k: $(kv)")
1173+
println(io, " $k:\n\t$(kv)")
11651174
end
11661175
end
11671176

src/core_types/domains.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,11 @@ function Base.show(io::IO, t::MIME"text/plain", d::DataDomain)
9999
for (k, v) in data
100100
vals, e = v
101101
if e == u
102-
sz = join(map(x -> "$x", size(vals)), "×")
102+
if vals isa AbstractVecOrMat
103+
sz = join(map(x -> "$x", size(vals)), "×")
104+
else
105+
sz = "$(typeof(vals))"
106+
end
103107
print(io, " :$k => $sz $(typeof(vals))\n")
104108
end
105109
end

src/domains.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,30 @@ count_entities(D::Union{DataDomain, DiscretizedDomain}, entity) = D.entities[ent
5656
count_active_entities(D, entity; kwarg...) = count_entities(D, entity)
5757
count_active_entities(D::DiscretizedDomain, entity; kwarg...) = count_active_entities(D, D.global_map, entity; kwarg...)
5858

59+
60+
"""
61+
number_of_cells(D::Union{DataDomain, DiscretizedDomain})
62+
63+
Get the number of cells in a `DataDomain` or `DiscretizedDomain`.
64+
"""
5965
function number_of_cells(D::Union{DataDomain, DiscretizedDomain})
6066
return count_entities(D, Cells())
6167
end
6268

69+
"""
70+
number_of_faces(D::Union{DataDomain, DiscretizedDomain})
71+
72+
Get the number of faces in a `DataDomain` or `DiscretizedDomain`.
73+
"""
6374
function number_of_faces(D::Union{DataDomain, DiscretizedDomain})
6475
return count_entities(D, Faces())
6576
end
6677

78+
"""
79+
number_of_half_faces(D::Union{DataDomain, DiscretizedDomain})
80+
81+
Get the number of half-faces in a `DataDomain` or `DiscretizedDomain`.
82+
"""
6783
function number_of_half_faces(D::Union{DataDomain, DiscretizedDomain})
6884
return 2*number_of_faces(D)
6985
end

src/ext/makie_ext.jl

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,20 @@ function plot_multimodel_interactive_impl
2323

2424
end
2525

26-
26+
"""
27+
plot_mesh(mesh)
28+
plot_mesh(mesh;
29+
cells = nothing,
30+
faces = nothing,
31+
boundaryfaces = nothing,
32+
outer = false,
33+
color = :lightblue,
34+
)
35+
36+
Plot a `mesh` with uniform colors. Optionally, indices `cells`, `faces` or
37+
`boundaryfaces` can be passed to limit the plotting to a specific selection of
38+
entities.
39+
"""
2740
function plot_mesh(arg...; kwarg...)
2841
check_plotting_availability()
2942
plot_mesh_impl(arg...; kwarg...)
@@ -33,6 +46,13 @@ function plot_mesh_impl
3346

3447
end
3548

49+
50+
"""
51+
plot_mesh!(ax, mesh)
52+
53+
Mutating version of `plot_mesh` that plots into an existing Makie `Axis`
54+
instance.
55+
"""
3656
function plot_mesh!(arg...; kwarg...)
3757
check_plotting_availability()
3858
plot_mesh_impl!(arg...; kwarg...)
@@ -42,6 +62,11 @@ function plot_mesh_impl!
4262

4363
end
4464

65+
"""
66+
plot_mesh_edges(mesh; kwarg...)
67+
68+
Plot the edges of all cells on the exterior of a mesh.
69+
"""
4570
function plot_mesh_edges(arg...; kwarg...)
4671
check_plotting_availability()
4772
plot_mesh_edges_impl(arg...; kwarg...)
@@ -51,6 +76,12 @@ function plot_mesh_edges_impl
5176

5277
end
5378

79+
"""
80+
plot_mesh_edges!(ax, mesh; kwarg...)
81+
82+
Plot the edges of all cells on the exterior of a mesh into existing Makie
83+
`Axis` `ax`.
84+
"""
5485
function plot_mesh_edges!(arg...; kwarg...)
5586
check_plotting_availability()
5687
plot_mesh_edges_impl!(arg...; kwarg...)
@@ -82,6 +113,18 @@ function plotting_check_interactive
82113

83114
end
84115

116+
"""
117+
check_plotting_availability(; throw = true, interactive = false)
118+
119+
Check if plotting through at least one `Makie` backend is available in the Julia
120+
session (after package has been loaded by for example `using GLMakie`). The
121+
argument `throw` can be used to control if this function acts as a programmatic
122+
check (`throw=false`) there the return value indicates availability, or if an
123+
error message is to be printed telling the user how to get plotting working
124+
(`throw=true`)
125+
126+
An additional check for specifically `interactive` plots can also be added.
127+
"""
85128
function check_plotting_availability(; throw = true, interactive = false)
86129
ok = true
87130
try

src/interpolation.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,11 @@ end
7474

7575
function LinearInterpolant(X::V, F::T; static = false, constant_dx = missing) where {T<:AbstractVector, V<:AbstractVector}
7676
length(X) == length(F) || throw(ArgumentError("X and F values must have equal length."))
77+
if length(X) == 1
78+
# Handle single inputs by constant extrapolation
79+
push!(X, only(X) + one(eltype(X)))
80+
push!(F, only(F))
81+
end
7782
if !issorted(X)
7883
ix = sortperm(X)
7984
X = X[ix]

src/linsolve/default.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ function MultiLinearizedSystem(subsystems, context, layout; r = nothing, dx = no
9090
for i in urng
9191
J = subsystems[i, i].jac
9292
ni, mi = size(J)
93-
@assert ni == mi
93+
@assert ni == mi "Mismatch in block size: $ni != $mi"
9494
e = eltype(J)
9595
if e <: Real
9696
bz = 1

0 commit comments

Comments
 (0)