@@ -6,14 +6,20 @@ struct LocalPerspectiveAD{T, N, A<:AbstractArray{T,N}, I} <: AbstractArray{T,N}
66end
77
88function LocalPerspectiveAD (a:: A , index:: I_t ) where {A<: AbstractArray , I_t<: Integer }
9- LocalPerspectiveAD {eltype(a), ndims(A), A, I_t} (index, a)
9+ return LocalPerspectiveAD {eltype(a), ndims(A), A, I_t} (index, a)
1010end
1111
1212struct LocalStateAD{T, I, E} # Data type, index, entity tag
1313 index:: I
1414 data:: T
1515end
1616
17+ struct MultiModelLocalStateAD{T, I, E} # Data type, index, entity tag
18+ symbol:: Symbol
19+ index:: I
20+ data:: T
21+ end
22+
1723Base. keys (x:: LocalStateAD ) = keys (getfield (x, :data ))
1824
1925struct ValueStateAD{T} # Data type
2228
2329Base. keys (x:: ValueStateAD ) = keys (getfield (x, :data ))
2430
31+ function convert_to_immutable_storage (x:: ValueStateAD )
32+ data = getfield (x, :data )
33+ data = convert_to_immutable_storage (data)
34+ return ValueStateAD (data)
35+ end
36+
2537const StateType = Union{NamedTuple,AbstractDict,JutulStorage}
2638
2739as_value (x:: StateType ) = ValueStateAD (x)
@@ -31,7 +43,6 @@ export local_ad
3143@inline local_ad (v, :: Nothing ) = as_value (v)
3244@inline local_ad (v, i) = v
3345
34-
3546@inline function new_entity_index (state:: LocalStateAD{T, I, E} , index:: I ) where {T, I, E}
3647 return LocalStateAD {T, I, E} (index, getfield (state, :data ))
3748end
4051 return x
4152end
4253
43-
4454@inline local_entity (a:: LocalPerspectiveAD ) = a. index
4555
4656@inline function value_or_ad (A:: LocalPerspectiveAD{T} , v:: T , entity) where T
8090end
8191@inline Base. haskey (state:: ValueStateAD , f:: Symbol ) = haskey (getfield (state, :data ), f)
8292
83-
8493# Match in type - pass index on
8594@inline next_level_local_ad (x:: AbstractArray{T} , :: Type{T} , index) where T = local_ad (x, index)
8695
@@ -140,10 +149,26 @@ end
140149 return next_level_local_ad (val, E, index)
141150end
142151
152+ @inline function Base. getproperty (state:: MultiModelLocalStateAD{T, I, E} , f:: Symbol ) where {T, I, E}
153+ index = getfield (state, :index )
154+ inner_state = getfield (state, :data )
155+ val = getproperty (inner_state, f)
156+ sym = getfield (state, :symbol )
157+ if sym == f
158+ return next_level_local_ad (val, E, index)
159+ else
160+ return as_value (val)
161+ end
162+ end
163+
143164@inline function Base. getindex (state:: LocalStateAD , s:: Symbol )
144165 Base. getproperty (state, s)
145166end
146167
168+ @inline function Base. getindex (state:: MultiModelLocalStateAD , s:: Symbol )
169+ Base. getproperty (state, s)
170+ end
171+
147172@inline function Base. getproperty (state:: ValueStateAD{T} , f:: Symbol ) where {T}
148173 inner_state = getfield (state, :data )
149174 val = getproperty (inner_state, f)
@@ -163,10 +188,18 @@ Create local_ad for state for index I of AD tag of type ad_tag
163188 local_state_ad (state, index, ad_tag)
164189end
165190
191+ @inline function local_ad (state, index, ad_tag, symbol)
192+ local_state_ad (state, index, ad_tag, symbol)
193+ end
194+
166195@inline function local_state_ad (state:: T , index:: I , ad_tag: :∂T ) where {T, I<: Integer , ∂T}
167196 return LocalStateAD {T, I, ad_tag} (index, state)
168197end
169198
199+ @inline function local_state_ad (state:: T , index:: I , ad_tag: :∂T , symbol:: Symbol ) where {T, I<: Integer , ∂T}
200+ return MultiModelLocalStateAD {T, I, ad_tag} (symbol, index, state)
201+ end
202+
170203function Base. show (io:: IO , t:: MIME"text/plain" , x:: LocalStateAD{T, I, E} ) where {T, I, E}
171204 print (io, " Local state for $(unpack_tag (E)) -> $(getfield (x, :index )) with fields $(keys (getfield (x, :data ))) " )
172205end
0 commit comments