|
68 | 68 | jac_prototype = nothing,
|
69 | 69 | chunksize = nothing,
|
70 | 70 | dx = sparsity === nothing && jac_prototype === nothing ? nothing : copy(x)) #if dx is nothing, we will estimate dx at the cost of a function call
|
71 |
| - if sparsity === nothing && jac_prototype === nothing || !ArrayInterface.ismutable(x) |
| 71 | + |
| 72 | + if sparsity === nothing && jac_prototype === nothing |
72 | 73 | cfg = chunksize === nothing ? ForwardDiff.JacobianConfig(f, x) : ForwardDiff.JacobianConfig(f, x, ForwardDiff.Chunk(getsize(chunksize)))
|
73 | 74 | return ForwardDiff.jacobian(f, x, cfg)
|
74 | 75 | end
|
75 | 76 | if dx isa Nothing
|
76 | 77 | dx = f(x)
|
77 | 78 | end
|
78 |
| - forwarddiff_color_jacobian(f,x,ForwardColorJacCache(f,x,chunksize,dx=dx,colorvec=colorvec,sparsity=sparsity),jac_prototype) |
| 79 | + return forwarddiff_color_jacobian(f,x,ForwardColorJacCache(f,x,chunksize,dx=dx,colorvec=colorvec,sparsity=sparsity),jac_prototype) |
| 80 | +end |
| 81 | + |
| 82 | +@inline function forwarddiff_color_jacobian(J::AbstractArray{<:Number}, f, |
| 83 | + x::AbstractArray{<:Number}; |
| 84 | + colorvec = 1:length(x), |
| 85 | + sparsity = nothing, |
| 86 | + jac_prototype = nothing, |
| 87 | + chunksize = nothing, |
| 88 | + dx = similar(x, size(J, 1))) #dx kwarg can be used to avoid re-allocating dx every time |
| 89 | + if sparsity === nothing && jac_prototype === nothing |
| 90 | + cfg = chunksize === nothing ? ForwardDiff.JacobianConfig(f, x) : ForwardDiff.JacobianConfig(f, x, ForwardDiff.Chunk(getsize(chunksize))) |
| 91 | + return ForwardDiff.jacobian(f, x, cfg) |
| 92 | + end |
| 93 | + return forwarddiff_color_jacobian(J,f,x,ForwardColorJacCache(f,x,chunksize,dx=dx,colorvec=colorvec,sparsity=sparsity)) |
79 | 94 | end
|
80 | 95 |
|
81 | 96 | function forwarddiff_color_jacobian(f,x::AbstractArray{<:Number},jac_cache::ForwardColorJacCache,jac_prototype=nothing)
|
| 97 | + |
| 98 | + if jac_prototype isa Nothing ? ArrayInterface.ismutable(x) : ArrayInterface.ismutable(jac_prototype) |
| 99 | + # Whenever J is mutable, we mutate it to avoid allocations |
| 100 | + dx = jac_cache.dx |
| 101 | + vecx = vec(x) |
| 102 | + sparsity = jac_cache.sparsity |
| 103 | + |
| 104 | + J = jac_prototype isa Nothing ? (sparsity isa Nothing ? false .* vec(dx) .* vecx' : |
| 105 | + zeros(eltype(x),size(sparsity))) : zero(jac_prototype) |
| 106 | + return forwarddiff_color_jacobian(J, f, x, jac_cache) |
| 107 | + else |
| 108 | + return forwarddiff_color_jacobian_immutable(f, x, jac_cache, jac_prototype) |
| 109 | + end |
| 110 | +end |
| 111 | + |
| 112 | +# When J is mutable, this version of forwarddiff_color_jacobian will mutate J to avoid allocations |
| 113 | +function forwarddiff_color_jacobian(J::AbstractMatrix{<:Number},f,x::AbstractArray{<:Number},jac_cache::ForwardColorJacCache) |
| 114 | + t = jac_cache.t |
| 115 | + dx = jac_cache.dx |
| 116 | + p = jac_cache.p |
| 117 | + colorvec = jac_cache.colorvec |
| 118 | + sparsity = jac_cache.sparsity |
| 119 | + chunksize = jac_cache.chunksize |
| 120 | + color_i = 1 |
| 121 | + maxcolor = maximum(colorvec) |
| 122 | + |
| 123 | + vecx = vec(x) |
| 124 | + |
| 125 | + nrows,ncols = size(J) |
| 126 | + |
| 127 | + if !(sparsity isa Nothing) |
| 128 | + rows_index, cols_index = ArrayInterface.findstructralnz(sparsity) |
| 129 | + rows_index = [rows_index[i] for i in 1:length(rows_index)] |
| 130 | + cols_index = [cols_index[i] for i in 1:length(cols_index)] |
| 131 | + end |
| 132 | + |
| 133 | + for i in eachindex(p) |
| 134 | + partial_i = p[i] |
| 135 | + t = reshape(Dual{typeof(ForwardDiff.Tag(f,eltype(vecx)))}.(vecx, partial_i),size(t)) |
| 136 | + fx = f(t) |
| 137 | + if !(sparsity isa Nothing) |
| 138 | + for j in 1:chunksize |
| 139 | + dx = vec(partials.(fx, j)) |
| 140 | + pick_inds = [i for i in 1:length(rows_index) if colorvec[cols_index[i]] == color_i] |
| 141 | + rows_index_c = rows_index[pick_inds] |
| 142 | + cols_index_c = cols_index[pick_inds] |
| 143 | + if J isa SparseMatrixCSC || j > 1 |
| 144 | + # Use sparse matrix to add to J column by column except . . . |
| 145 | + Ji = sparse(rows_index_c, cols_index_c, dx[rows_index_c],nrows,ncols) |
| 146 | + else |
| 147 | + # To overwrite pre-allocated matrix J, using sparse will cause an error |
| 148 | + # so we use this step to overwrite J |
| 149 | + len_rows = length(pick_inds) |
| 150 | + unused_rows = setdiff(1:nrows,rows_index_c) |
| 151 | + perm_rows = sortperm(vcat(rows_index_c,unused_rows)) |
| 152 | + cols_index_c = vcat(cols_index_c,zeros(Int,nrows-len_rows))[perm_rows] |
| 153 | + Ji = [j==cols_index_c[i] ? dx[i] : false for i in 1:nrows, j in 1:ncols] |
| 154 | + end |
| 155 | + if j == 1 && i == 1 |
| 156 | + J .= Ji # overwrite pre-allocated matrix J |
| 157 | + else |
| 158 | + J .+= Ji |
| 159 | + end |
| 160 | + color_i += 1 |
| 161 | + (color_i > maxcolor) && return J |
| 162 | + end |
| 163 | + else |
| 164 | + for j in 1:chunksize |
| 165 | + col_index = (i-1)*chunksize + j |
| 166 | + (col_index > ncols) && return J |
| 167 | + Ji = mapreduce(i -> i==col_index ? partials.(vec(fx), j) : adapt(parameterless_type(J),zeros(eltype(J),nrows)), hcat, 1:ncols) |
| 168 | + if j == 1 && i == 1 |
| 169 | + J .= (size(Ji)!=size(J) ? reshape(Ji,size(J)) : Ji) # overwrite pre-allocated matrix |
| 170 | + else |
| 171 | + J .+= (size(Ji)!=size(J) ? reshape(Ji,size(J)) : Ji) #branch when size(dx) == (1,) => size(Ji) == (1,) while size(J) == (1,1) |
| 172 | + end |
| 173 | + end |
| 174 | + end |
| 175 | + end |
| 176 | + return J |
| 177 | +end |
| 178 | + |
| 179 | +# When J is immutable, this version of forwarddiff_color_jacobian will avoid mutating J |
| 180 | +function forwarddiff_color_jacobian_immutable(f,x::AbstractArray{<:Number},jac_cache::ForwardColorJacCache,jac_prototype=nothing) |
82 | 181 | t = jac_cache.t
|
83 | 182 | dx = jac_cache.dx
|
84 | 183 | p = jac_cache.p
|
@@ -131,7 +230,7 @@ function forwarddiff_color_jacobian(f,x::AbstractArray{<:Number},jac_cache::Forw
|
131 | 230 | end
|
132 | 231 | end
|
133 | 232 | end
|
134 |
| - J |
| 233 | + return J |
135 | 234 | end
|
136 | 235 |
|
137 | 236 | function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
|
|
0 commit comments