@@ -63,6 +63,10 @@ function solve_adjoint_sensitivities(model, states, reports_or_timesteps, G;
6363 jutul_message (" Adjoints" , " Storing sensitivities." , color = :blue )
6464 end
6565 out = store_sensitivities (parameter_model, ∇G, storage. parameter_map)
66+ s0_map = storage. state0_map
67+ if ! ismissing (s0_map)
68+ store_sensitivities! (out, storage. backward. model, ∇G, s0_map)
69+ end
6670 end
6771 if extra_output
6872 return (out, storage)
@@ -107,24 +111,27 @@ function setup_adjoint_storage(model;
107111 parameters = setup_parameters (model),
108112 n_objective = nothing ,
109113 targets = parameter_targets (model),
114+ include_state0 = false ,
110115 use_sparsity = true ,
111116 linear_solver = select_linear_solver (model, mode = :adjoint , rtol = 1e-6 ),
112117 param_obj = true ,
113118 info_level = 0 ,
114119 kwarg...
115120 )
116- # Set up the generic adjoint storage
117- storage = setup_adjoint_storage_base (
118- model, state0, parameters,
119- use_sparsity = use_sparsity,
120- linear_solver = linear_solver,
121- n_objective = n_objective,
122- info_level = info_level
123- )
124121 # Create parameter model for ∂Fₙ / ∂p
125122 parameter_model = adjoint_parameter_model (model, targets)
123+ n_prm = number_of_degrees_of_freedom (parameter_model)
126124 # Note that primary is here because the target parameters are now the primaries for the parameter_model
127125 parameter_map, = variable_mapper (parameter_model, :primary , targets = targets; kwarg... )
126+ if include_state0
127+ state0_map, = variable_mapper (model, :primary )
128+ n_state0 = number_of_degrees_of_freedom (model)
129+ state0_vec = zeros (n_state0)
130+ else
131+ state0_map = missing
132+ state0_vec = missing
133+ n_state0 = 0
134+ end
128135 # Transfer over parameters and state0 variables since many parameters are now variables
129136 state0_p = swap_variables (state0, parameters, parameter_model, variables = true )
130137 parameters_p = swap_variables (state0, parameters, parameter_model, variables = false )
@@ -137,11 +144,21 @@ function setup_adjoint_storage(model;
137144 dobj_dparam = nothing
138145 param_buf = nothing
139146 end
147+ # Set up the generic adjoint storage
148+ storage = setup_adjoint_storage_base (
149+ model, state0, parameters,
150+ use_sparsity = use_sparsity,
151+ linear_solver = linear_solver,
152+ n_objective = n_objective,
153+ info_level = info_level,
154+ )
140155 storage[:dparam ] = dobj_dparam
141156 storage[:param_buf ] = param_buf
142157 storage[:parameter ] = parameter_sim
143158 storage[:parameter_map ] = parameter_map
144- storage[:n ] = number_of_degrees_of_freedom (parameter_model)
159+ storage[:state0_map ] = state0_map
160+ storage[:dstate0 ] = state0_vec
161+ storage[:n ] = n_prm
145162
146163 return storage
147164end
@@ -151,7 +168,7 @@ function setup_adjoint_storage_base(model, state0, parameters;
151168 linear_solver = select_linear_solver (model, mode = :adjoint , rtol = 1e-8 ),
152169 n_objective = nothing ,
153170 info_level = 0
154- )
171+ )
155172 primary_model = adjoint_model_copy (model)
156173 # Standard model for: ∂Fₙᵀ / ∂xₙ
157174 forward_sim = Simulator (primary_model, state0 = deepcopy (state0), parameters = deepcopy (parameters), mode = :forward , extra_timing = nothing )
@@ -236,6 +253,8 @@ function solve_adjoint_sensitivities!(∇G, storage, states, state0, timesteps,
236253 end
237254 rescale_sensitivities! (∇G, storage. parameter. model, storage. parameter_map)
238255 @assert all (isfinite, ∇G)
256+ # Finally deal with initial state gradients
257+ update_state0_sensitivities! (storage)
239258 return ∇G
240259end
241260
@@ -720,6 +739,28 @@ function store_sensitivities!(out, model, variables, result, prm_map, ::Equation
720739 return out
721740end
722741
742+ function store_sensitivities! (out, model, variables, result, prm_map, :: Union{EquationMajorLayout, BlockMajorLayout} )
743+ scalar_valued_objective = result isa AbstractVector
744+ @assert scalar_valued_objective " Only supported for scalar objective"
745+
746+ us = get_primary_variable_ordered_entities (model)
747+ @assert length (us) == 1 " This function is not implemented for more than one entity type for primary variables"
748+ u = only (us)
749+ bz = degrees_of_freedom_per_entity (model, u)
750+ ne = count_active_entities (model. domain, u)
751+
752+ offset = 1
753+ for (k, var) in pairs (variables)
754+ m = degrees_of_freedom_per_entity (model, var)
755+ var:: ScalarVariable
756+ pos = offset: bz: (bz* (ne- 1 )+ offset)
757+ @assert length (pos) == ne " $(length (pos)) "
758+ out[k] = result[pos]
759+ offset += 1
760+ end
761+ return out
762+ end
763+
723764function extract_sensitivity_subset (r, var, n, m, offset)
724765 if var isa ScalarVariable
725766 v = r
@@ -845,3 +886,37 @@ function adjoint_transfer_canonical_order_inner!(λ, dx, model, ::BlockMajorLayo
845886 end
846887 end
847888end
889+
890+ function update_state0_sensitivities! (storage)
891+ state0_map = storage. state0_map
892+ if ! ismissing (state0_map)
893+ sim = storage. backward
894+ model = sim. model
895+ if model isa MultiModel
896+ for (k, v) in pairs (model. models)
897+ @assert matrix_layout (v. context) isa EquationMajorLayout
898+ end
899+ else
900+ @assert matrix_layout (model. context) isa EquationMajorLayout
901+ end
902+ # Assume that this gets called at the end when everything has been set
903+ # up in terms of the simulators
904+ λ = storage. lagrange
905+ ∇x = storage. dstate0
906+ @. ∇x = 0.0
907+ # order = collect(eachindex(λ))
908+ # renum = similar(order)
909+ # TODO : Finish this part and remove the assertions above
910+ # Get order to put values into canonical order
911+ # adjoint_transfer_canonical_order!(renum, order, model)
912+ # λ_renum = similar(λ)
913+ # @. λ_renum[renum] = λ
914+ # model_prm = storage.parameter.model
915+ lsys_b = sim. storage. LinearizedSystem
916+ op_b = linear_operator (lsys_b, skip_red = true )
917+ # tmp = zeros(size(∇x))
918+ sens_add_mult! (∇x, op_b, λ)
919+ # adjoint_transfer_canonical_order!(∇x, tmp, model)
920+ rescale_sensitivities! (∇x, sim. model, storage. state0_map)
921+ end
922+ end
0 commit comments