Skip to content
This repository was archived by the owner on Aug 22, 2025. It is now read-only.

Commit 9c9e4d5

Browse files
Merge pull request #158 from avik-pal/ap/vecjac
Add vecjac operators
2 parents 867ccda + cd944a2 commit 9c9e4d5

File tree

4 files changed

+197
-101
lines changed

4 files changed

+197
-101
lines changed

src/SparseDiffTools.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ export contract_color,
3030
ForwardColorJacCache,
3131
auto_jacvec,auto_jacvec!,
3232
num_jacvec,num_jacvec!,
33+
num_vecjac,num_vecjac!,
3334
num_hesvec,num_hesvec!,
3435
numauto_hesvec,numauto_hesvec!,
3536
autonum_hesvec,autonum_hesvec!,
@@ -48,15 +49,17 @@ include("coloring/greedy_star2_coloring.jl")
4849
include("coloring/matrix2graph.jl")
4950
include("differentiation/compute_jacobian_ad.jl")
5051
include("differentiation/jaches_products.jl")
52+
include("differentiation/vecjac_products.jl")
5153

5254
Base.@pure __parameterless_type(T) = Base.typename(T).wrapper
5355
parameterless_type(x) = parameterless_type(typeof(x))
5456
parameterless_type(x::Type) = __parameterless_type(x)
5557

5658
function __init__()
5759
@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
58-
export numback_hesvec, numback_hesvec!, autoback_hesvec, autoback_hesvec!
60+
export numback_hesvec, numback_hesvec!, autoback_hesvec, autoback_hesvec!, auto_vecjac, auto_vecjac!
5961

62+
include("differentiation/vecjac_products_zygote.jl")
6063
include("differentiation/jaches_products_zygote.jl")
6164
end
6265
end
Lines changed: 140 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -1,146 +1,178 @@
11
struct DeivVecTag end
22

33
# J(f(x))*v
4-
function auto_jacvec!(dy, f, x, v,
5-
cache1 = ForwardDiff.Dual{DeivVecTag}.(x, v),
6-
cache2 = ForwardDiff.Dual{DeivVecTag}.(x, v))
7-
cache1 .= Dual{DeivVecTag}.(x, v)
8-
f(cache2,cache1)
4+
function auto_jacvec!(
5+
dy,
6+
f,
7+
x,
8+
v,
9+
cache1 = Dual{DeivVecTag}.(x, reshape(v, size(x))),
10+
cache2 = similar(cache1),
11+
)
12+
cache1 .= Dual{DeivVecTag}.(x, reshape(v, size(x)))
13+
f(cache2, cache1)
914
dy .= partials.(cache2, 1)
1015
end
16+
1117
function auto_jacvec(f, x, v)
12-
fval = f(map((xi, vi) -> Dual{typeof(ForwardDiff.Tag(f,eltype(x)))}(xi, vi), x, v))
13-
map(u -> partials(u)[1], fval)
18+
vv = reshape(v, axes(x))
19+
vec(partials.(vec(f(ForwardDiff.Dual{DeivVecTag}.(x, vv))), 1))
1420
end
1521

16-
function num_jacvec!(dy,f,x,v,cache1 = similar(v),
17-
cache2 = similar(v);
18-
compute_f0 = true)
19-
compute_f0 && (f(cache1,x))
22+
function num_jacvec!(
23+
dy,
24+
f,
25+
x,
26+
v,
27+
cache1 = similar(v),
28+
cache2 = similar(v);
29+
compute_f0 = true,
30+
)
31+
vv = reshape(v, axes(x))
32+
compute_f0 && (f(cache1, x))
2033
T = eltype(x)
2134
# Should it be min? max? mean?
2235
ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x)))
23-
@. x += ϵ*v
24-
f(cache2,x)
25-
@. x -= ϵ*v
26-
@. dy = (cache2 - cache1)/ϵ
36+
@. x += ϵ * vv
37+
f(cache2, x)
38+
@. x -= ϵ * vv
39+
@. dy = (cache2 - cache1) / ϵ
2740
end
2841

29-
function num_jacvec(f,x,v,f0=nothing)
42+
function num_jacvec(f, x, v, f0 = nothing)
43+
vv = reshape(v, axes(x))
3044
f0 === nothing ? _f0 = f(x) : _f0 = f0
3145
T = eltype(x)
3246
# Should it be min? max? mean?
3347
ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(minimum(x)))
34-
(f(x.+ϵ.*v) .- _f0)./ϵ
48+
vec((f(x .+ ϵ .* vv) .- _f0) ./ ϵ)
3549
end
3650

37-
function num_hesvec!(dy,f,x,v,
38-
cache1 = similar(v),
39-
cache2 = similar(v),
40-
cache3 = similar(v))
41-
cache = FiniteDiff.GradientCache(v[1],cache1,Val{:central})
42-
g = let f=f,cache=cache
43-
(dx,x) -> FiniteDiff.finite_difference_gradient!(dx,f,x,cache)
51+
function num_hesvec!(
52+
dy,
53+
f,
54+
x,
55+
v,
56+
cache1 = similar(v),
57+
cache2 = similar(v),
58+
cache3 = similar(v),
59+
)
60+
cache = FiniteDiff.GradientCache(v[1], cache1, Val{:central})
61+
g = let f = f, cache = cache
62+
(dx, x) -> FiniteDiff.finite_difference_gradient!(dx, f, x, cache)
4463
end
4564
T = eltype(x)
4665
# Should it be min? max? mean?
4766
ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x)))
48-
@. x += ϵ*v
49-
g(cache2,x)
50-
@. x -= 2ϵ*v
51-
g(cache3,x)
52-
@. dy = (cache2 - cache3)/(2ϵ)
67+
@. x += ϵ * v
68+
g(cache2, x)
69+
@. x -= 2ϵ * v
70+
g(cache3, x)
71+
@. dy = (cache2 - cache3) / (2ϵ)
5372
end
5473

55-
function num_hesvec(f,x,v)
56-
g = (x) -> FiniteDiff.finite_difference_gradient(f,x)
74+
function num_hesvec(f, x, v)
75+
g = (x) -> FiniteDiff.finite_difference_gradient(f, x)
5776
T = eltype(x)
5877
# Should it be min? max? mean?
5978
ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x)))
60-
x += ϵ*v
79+
x += ϵ * v
6180
gxp = g(x)
62-
x -= 2ϵ*v
81+
x -= 2ϵ * v
6382
gxm = g(x)
64-
(gxp - gxm)/(2ϵ)
83+
(gxp - gxm) / (2ϵ)
6584
end
6685

67-
function numauto_hesvec!(dy,f,x,v,
68-
cache = ForwardDiff.GradientConfig(f,v),
69-
cache1 = similar(v),
70-
cache2 = similar(v))
71-
g = let f=f,x=x,cache=cache
72-
g = (dx,x) -> ForwardDiff.gradient!(dx,f,x,cache)
86+
function numauto_hesvec!(
87+
dy,
88+
f,
89+
x,
90+
v,
91+
cache = ForwardDiff.GradientConfig(f, v),
92+
cache1 = similar(v),
93+
cache2 = similar(v),
94+
)
95+
g = let f = f, x = x, cache = cache
96+
g = (dx, x) -> ForwardDiff.gradient!(dx, f, x, cache)
7397
end
7498
T = eltype(x)
7599
# Should it be min? max? mean?
76100
ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x)))
77-
@. x += ϵ*v
78-
g(cache1,x)
79-
@. x -= 2ϵ*v
80-
g(cache2,x)
81-
@. dy = (cache1 - cache2)/(2ϵ)
101+
@. x += ϵ * v
102+
g(cache1, x)
103+
@. x -= 2ϵ * v
104+
g(cache2, x)
105+
@. dy = (cache1 - cache2) / (2ϵ)
82106
end
83107

84-
function numauto_hesvec(f,x,v)
85-
g = (x) -> ForwardDiff.gradient(f,x)
108+
function numauto_hesvec(f, x, v)
109+
g = (x) -> ForwardDiff.gradient(f, x)
86110
T = eltype(x)
87111
# Should it be min? max? mean?
88112
ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x)))
89-
x += ϵ*v
113+
x += ϵ * v
90114
gxp = g(x)
91-
x -= 2ϵ*v
115+
x -= 2ϵ * v
92116
gxm = g(x)
93-
(gxp - gxm)/(2ϵ)
117+
(gxp - gxm) / (2ϵ)
94118
end
95119

96-
function autonum_hesvec!(dy,f,x,v,
97-
cache1 = ForwardDiff.Dual{DeivVecTag}.(x, v),
98-
cache2 = ForwardDiff.Dual{DeivVecTag}.(x, v))
99-
cache = FiniteDiff.GradientCache(v[1],cache1,Val{:central})
100-
g = (dx,x) -> FiniteDiff.finite_difference_gradient!(dx,f,x,cache)
120+
function autonum_hesvec!(
121+
dy,
122+
f,
123+
x,
124+
v,
125+
cache1 = ForwardDiff.Dual{DeivVecTag}.(x, v),
126+
cache2 = ForwardDiff.Dual{DeivVecTag}.(x, v),
127+
)
128+
cache = FiniteDiff.GradientCache(v[1], cache1, Val{:central})
129+
g = (dx, x) -> FiniteDiff.finite_difference_gradient!(dx, f, x, cache)
101130
cache1 .= Dual{DeivVecTag}.(x, v)
102-
g(cache2,cache1)
131+
g(cache2, cache1)
103132
dy .= partials.(cache2, 1)
104133
end
105134

106-
function autonum_hesvec(f,x,v)
107-
g = (x) -> FiniteDiff.finite_difference_gradient(f,x)
135+
function autonum_hesvec(f, x, v)
136+
g = (x) -> FiniteDiff.finite_difference_gradient(f, x)
108137
partials.(g(Dual{DeivVecTag}.(x, v)), 1)
109138
end
110139

111-
function num_hesvecgrad!(dy,g,x,v,
112-
cache2 = similar(v),
113-
cache3 = similar(v))
140+
function num_hesvecgrad!(dy, g, x, v, cache2 = similar(v), cache3 = similar(v))
114141
T = eltype(x)
115142
# Should it be min? max? mean?
116143
ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x)))
117-
@. x += ϵ*v
118-
g(cache2,x)
119-
@. x -= 2ϵ*v
120-
g(cache3,x)
121-
@. dy = (cache2 - cache3)/(2ϵ)
144+
@. x += ϵ * v
145+
g(cache2, x)
146+
@. x -= 2ϵ * v
147+
g(cache3, x)
148+
@. dy = (cache2 - cache3) / (2ϵ)
122149
end
123150

124-
function num_hesvecgrad(g,x,v)
151+
function num_hesvecgrad(g, x, v)
125152
T = eltype(x)
126153
# Should it be min? max? mean?
127154
ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x)))
128-
x += ϵ*v
155+
x += ϵ * v
129156
gxp = g(x)
130-
x -= 2ϵ*v
157+
x -= 2ϵ * v
131158
gxm = g(x)
132-
(gxp - gxm)/(2ϵ)
159+
(gxp - gxm) / (2ϵ)
133160
end
134161

135-
function auto_hesvecgrad!(dy,g,x,v,
136-
cache2 = ForwardDiff.Dual{DeivVecTag}.(x, v),
137-
cache3 = ForwardDiff.Dual{DeivVecTag}.(x, v))
162+
function auto_hesvecgrad!(
163+
dy,
164+
g,
165+
x,
166+
v,
167+
cache2 = ForwardDiff.Dual{DeivVecTag}.(x, v),
168+
cache3 = ForwardDiff.Dual{DeivVecTag}.(x, v),
169+
)
138170
cache2 .= Dual{DeivVecTag}.(x, v)
139-
g(cache3,cache2)
171+
g(cache3, cache2)
140172
dy .= partials.(cache3, 1)
141173
end
142174

143-
function auto_hesvecgrad(g,x,v)
175+
function auto_hesvecgrad(g, x, v)
144176
partials.(g(Dual{DeivVecTag}.(x, v)), 1)
145177
end
146178

@@ -154,26 +186,28 @@ struct JacVec{F,T1,T2,xType}
154186
autodiff::Bool
155187
end
156188

157-
function JacVec(f,x::AbstractArray;autodiff=true)
189+
function JacVec(f, x::AbstractArray; autodiff = true)
158190
if autodiff
159191
cache1 = ForwardDiff.Dual{DeivVecTag}.(x, x)
160192
cache2 = ForwardDiff.Dual{DeivVecTag}.(x, x)
161193
else
162194
cache1 = similar(x)
163195
cache2 = similar(x)
164196
end
165-
JacVec(f,cache1,cache2,x,autodiff)
197+
JacVec(f, cache1, cache2, x, autodiff)
166198
end
167199

168-
Base.size(L::JacVec) = (length(L.cache1),length(L.cache1))
169-
Base.size(L::JacVec,i::Int) = length(L.cache1)
170-
Base.:*(L::JacVec,v::AbstractVector) = L.autodiff ? auto_jacvec(_x->L.f(_x),L.x,v) : num_jacvec(_x->L.f(_x),L.x,v)
200+
Base.size(L::JacVec) = (length(L.cache1), length(L.cache1))
201+
Base.size(L::JacVec, i::Int) = length(L.cache1)
202+
Base.:*(L::JacVec, v::AbstractVector) =
203+
L.autodiff ? auto_jacvec(_x -> L.f(_x), L.x, v) :
204+
num_jacvec(_x -> L.f(_x), L.x, v)
171205

172-
function LinearAlgebra.mul!(dy::AbstractVector,L::JacVec,v::AbstractVector)
206+
function LinearAlgebra.mul!(dy::AbstractVector, L::JacVec, v::AbstractVector)
173207
if L.autodiff
174-
auto_jacvec!(dy,(_y,_x)->L.f(_y,_x),L.x,v,L.cache1,L.cache2)
208+
auto_jacvec!(dy, (_y, _x) -> L.f(_y, _x), L.x, v, L.cache1, L.cache2)
175209
else
176-
num_jacvec!(dy,(_y,_x)->L.f(_y,_x),L.x,v,L.cache1,L.cache2)
210+
num_jacvec!(dy, (_y, _x) -> L.f(_y, _x), L.x, v, L.cache1, L.cache2)
177211
end
178212
end
179213

@@ -186,28 +220,29 @@ struct HesVec{F,T1,T2,xType}
186220
autodiff::Bool
187221
end
188222

189-
function HesVec(f,x::AbstractArray;autodiff=true)
223+
function HesVec(f, x::AbstractArray; autodiff = true)
190224
if autodiff
191-
cache1 = ForwardDiff.GradientConfig(f,x)
225+
cache1 = ForwardDiff.GradientConfig(f, x)
192226
cache2 = similar(x)
193227
cache3 = similar(x)
194228
else
195229
cache1 = similar(x)
196230
cache2 = similar(x)
197231
cache3 = similar(x)
198232
end
199-
HesVec(f,cache1,cache2,cache3,x,autodiff)
233+
HesVec(f, cache1, cache2, cache3, x, autodiff)
200234
end
201235

202-
Base.size(L::HesVec) = (length(L.cache2),length(L.cache2))
203-
Base.size(L::HesVec,i::Int) = length(L.cache2)
204-
Base.:*(L::HesVec,v::AbstractVector) = L.autodiff ? numauto_hesvec(L.f,L.x,v) : num_hesvec(L.f,L.x,v)
236+
Base.size(L::HesVec) = (length(L.cache2), length(L.cache2))
237+
Base.size(L::HesVec, i::Int) = length(L.cache2)
238+
Base.:*(L::HesVec, v::AbstractVector) =
239+
L.autodiff ? numauto_hesvec(L.f, L.x, v) : num_hesvec(L.f, L.x, v)
205240

206-
function LinearAlgebra.mul!(dy::AbstractVector,L::HesVec,v::AbstractVector)
241+
function LinearAlgebra.mul!(dy::AbstractVector, L::HesVec, v::AbstractVector)
207242
if L.autodiff
208-
numauto_hesvec!(dy,L.f,L.x,v,L.cache1,L.cache2,L.cache3)
243+
numauto_hesvec!(dy, L.f, L.x, v, L.cache1, L.cache2, L.cache3)
209244
else
210-
num_hesvec!(dy,L.f,L.x,v,L.cache1,L.cache2,L.cache3)
245+
num_hesvec!(dy, L.f, L.x, v, L.cache1, L.cache2, L.cache3)
211246
end
212247
end
213248

@@ -219,25 +254,30 @@ struct HesVecGrad{G,T1,T2,uType}
219254
autodiff::Bool
220255
end
221256

222-
function HesVecGrad(g,x::AbstractArray;autodiff=false)
257+
function HesVecGrad(g, x::AbstractArray; autodiff = false)
223258
if autodiff
224259
cache1 = ForwardDiff.Dual{DeivVecTag}.(x, x)
225260
cache2 = ForwardDiff.Dual{DeivVecTag}.(x, x)
226261
else
227262
cache1 = similar(x)
228263
cache2 = similar(x)
229264
end
230-
HesVecGrad(g,cache1,cache2,x,autodiff)
265+
HesVecGrad(g, cache1, cache2, x, autodiff)
231266
end
232267

233-
Base.size(L::HesVecGrad) = (length(L.cache2),length(L.cache2))
234-
Base.size(L::HesVecGrad,i::Int) = length(L.cache2)
235-
Base.:*(L::HesVecGrad,v::AbstractVector) = L.autodiff ? auto_hesvecgrad(L.g,L.x,v) : num_hesvecgrad(L.g,L.x,v)
268+
Base.size(L::HesVecGrad) = (length(L.cache2), length(L.cache2))
269+
Base.size(L::HesVecGrad, i::Int) = length(L.cache2)
270+
Base.:*(L::HesVecGrad, v::AbstractVector) =
271+
L.autodiff ? auto_hesvecgrad(L.g, L.x, v) : num_hesvecgrad(L.g, L.x, v)
236272

237-
function LinearAlgebra.mul!(dy::AbstractVector,L::HesVecGrad,v::AbstractVector)
273+
function LinearAlgebra.mul!(
274+
dy::AbstractVector,
275+
L::HesVecGrad,
276+
v::AbstractVector,
277+
)
238278
if L.autodiff
239-
auto_hesvecgrad!(dy,L.g,L.x,v,L.cache1,L.cache2)
279+
auto_hesvecgrad!(dy, L.g, L.x, v, L.cache1, L.cache2)
240280
else
241-
num_hesvecgrad!(dy,L.g,L.x,v,L.cache1,L.cache2)
281+
num_hesvecgrad!(dy, L.g, L.x, v, L.cache1, L.cache2)
242282
end
243283
end

0 commit comments

Comments
 (0)