Skip to content
This repository was archived by the owner on Aug 22, 2025. It is now read-only.

Commit cd944a2

Browse files
author
Avik Pal
committed
Move stuff around
1 parent 7974f57 commit cd944a2

File tree

3 files changed

+44
-44
lines changed

3 files changed

+44
-44
lines changed

src/SparseDiffTools.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ using LightGraphs: SimpleGraph
88
using Requires
99
using VertexSafeGraphs
1010
using Adapt
11-
using Zygote
1211

1312
using LinearAlgebra
1413
using SparseArrays, ArrayInterface
@@ -50,6 +49,7 @@ include("coloring/greedy_star2_coloring.jl")
5049
include("coloring/matrix2graph.jl")
5150
include("differentiation/compute_jacobian_ad.jl")
5251
include("differentiation/jaches_products.jl")
52+
include("differentiation/vecjac_products.jl")
5353

5454
Base.@pure __parameterless_type(T) = Base.typename(T).wrapper
5555
parameterless_type(x) = parameterless_type(typeof(x))
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
2+
function num_vecjac!(
3+
du,
4+
f,
5+
x,
6+
v,
7+
cache1 = similar(v),
8+
cache2 = similar(v);
9+
compute_f0 = true,
10+
)
11+
if DiffEqBase.numargs(f) != 2
12+
du .= num_jacvec(f, x, v)
13+
return du
14+
end
15+
compute_f0 && (f(cache1, x))
16+
T = eltype(x)
17+
# Should it be min? max? mean?
18+
ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x)))
19+
vv = reshape(v, size(x))
20+
for i = 1:length(x)
21+
x[i] += ϵ
22+
f(cache2, x)
23+
x[i] -= ϵ
24+
du[i] = (((cache2 .- cache1) ./ ϵ)'*vv)[1]
25+
end
26+
return du
27+
end
28+
29+
function num_vecjac(f, x, v, f0 = nothing)
30+
vv = reshape(v, axes(x))
31+
f0 === nothing ? _f0 = f(x) : _f0 = f0
32+
T = eltype(x)
33+
# Should it be min? max? mean?
34+
ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x)))
35+
du = similar(x)
36+
for i = 1:length(x)
37+
x[i] += ϵ
38+
f0 = f(x)
39+
x[i] -= ϵ
40+
du[i] = (((f0 .- _f0) ./ ϵ)'*vv)[1]
41+
end
42+
return vec(du)
43+
end

src/differentiation/vecjac_products_zygote.jl

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -8,46 +8,3 @@ function auto_vecjac(f, x, v)
88
vv, back = Zygote.pullback(f, x)
99
return vec(back(reshape(v, size(vv)))[1])
1010
end
11-
12-
function num_vecjac!(
13-
du,
14-
f,
15-
x,
16-
v,
17-
cache1 = similar(v),
18-
cache2 = similar(v);
19-
compute_f0 = true,
20-
)
21-
if DiffEqBase.numargs(f) != 2
22-
du .= num_jacvec(f, x, v)
23-
return du
24-
end
25-
compute_f0 && (f(cache1, x))
26-
T = eltype(x)
27-
# Should it be min? max? mean?
28-
ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x)))
29-
vv = reshape(v, size(x))
30-
for i = 1:length(x)
31-
x[i] += ϵ
32-
f(cache2, x)
33-
x[i] -= ϵ
34-
du[i] = (((cache2 .- cache1) ./ ϵ)'*vv)[1]
35-
end
36-
return du
37-
end
38-
39-
function num_vecjac(f, x, v, f0 = nothing)
40-
vv = reshape(v, axes(x))
41-
f0 === nothing ? _f0 = f(x) : _f0 = f0
42-
T = eltype(x)
43-
# Should it be min? max? mean?
44-
ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x)))
45-
du = similar(x)
46-
for i = 1:length(x)
47-
x[i] += ϵ
48-
f0 = f(x)
49-
x[i] -= ϵ
50-
du[i] = (((f0 .- _f0) ./ ϵ)'*vv)[1]
51-
end
52-
return vec(du)
53-
end

0 commit comments

Comments
 (0)