Status: Open
Labels: bug (Something isn't working)
Description
Package Version
0.6.49
Julia Version
1.8.1
OS / Environment
Ubuntu Linux (Pop!_OS)
Describe the bug
Calling Zygote.gradient on a function that applies exp to a Diagonal matrix throws a DimensionMismatch error from dot.
Steps to Reproduce
Here is code to reproduce the error:
using Zygote
using LinearAlgebra
d = rand(3)
D = Diagonal(d)
D.diag == diag(D)
# output: true
function exp_diag_function(D::Diagonal)
return Diagonal(exp.(diag(D)))
end
function exp_diag_field_name(D::Diagonal)
return Diagonal(exp.(D.diag))
end
∇_diag_function(x) = Zygote.gradient(x -> tr(exp_diag_function(x * D)), x)
∇_diag_function(1.0)
# output: (1.6395871633504406,)
∇_diag_field_name(x) = Zygote.gradient(x -> tr(exp_diag_field_name(x * D)), x)
∇_diag_field_name(1.0)
# output: ERROR: DimensionMismatch: x and y are of different lengths!
∇_built_in(x) = Zygote.gradient(x -> tr(exp(x * D)), x)
∇_built_in(1.0)
# output: ERROR: DimensionMismatch: x and y are of different lengths!
Expected Results
There should be no error. Somehow, accessing the diag field directly (D.diag) breaks the gradient, whereas calling the diag(D) method works.
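For reference, the value the failing calls should return can be computed without Zygote. Because D is diagonal, tr(exp(x * D)) equals sum(exp.(x .* d)), so its derivative with respect to the scalar x is sum(d .* exp.(x .* d)). The sketch below checks this closed form against a central finite difference. The helper name expected_gradient is hypothetical (not part of Zygote or LinearAlgebra), and fixed diagonal entries are used instead of rand(3) so the check is reproducible.

```julia
using LinearAlgebra

# Hypothetical helper, not part of Zygote or LinearAlgebra.
# For diagonal D with entries d: tr(exp(x * D)) == sum(exp.(x .* d)),
# so d/dx tr(exp(x * D)) == sum(d .* exp.(x .* d)).
expected_gradient(x, d) = sum(d .* exp.(x .* d))

d = [0.1, 0.5, 0.9]   # fixed entries instead of rand(3), for reproducibility
D = Diagonal(d)
x = 1.0

# The scalar objective from the report, evaluated with a dense matrix exponential
f(x) = tr(exp(Matrix(x * D)))

# Central finite difference as an independent check of the closed form
h = 1e-6
fd = (f(x + h) - f(x - h)) / (2h)
@assert isapprox(expected_gradient(x, d), fd; rtol = 1e-6)

println(expected_gradient(x, d))
```

Assuming the pullbacks were correct, ∇_diag_field_name(1.0)[1] and ∇_built_in(1.0)[1] should both be approximately expected_gradient(1.0, D.diag), the same value that ∇_diag_function(1.0) already returns.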
Observed Results
The same DimensionMismatch error is raised with the built-in exp function. Here is the stack trace:
ERROR: DimensionMismatch: x and y are of different lengths!
Stacktrace:
[1] dot
@ ~/packages/julias/julia-1.8/share/julia/stdlib/v1.8/LinearAlgebra/src/generic.jl:866 [inlined]
[2] dot
@ ~/packages/julias/julia-1.8/share/julia/stdlib/v1.8/LinearAlgebra/src/generic.jl:856 [inlined]
[3] #1483
@ ~/.julia/packages/ChainRules/2ql0h/src/rulesets/Base/arraymath.jl:108 [inlined]
[4] unthunk
@ ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_types/thunks.jl:204 [inlined]
[5] wrap_chainrules_output
@ ~/.julia/packages/Zygote/dABKa/src/compiler/chainrules.jl:105 [inlined]
[6] map
@ ./tuple.jl:223 [inlined]
[7] wrap_chainrules_output
@ ~/.julia/packages/Zygote/dABKa/src/compiler/chainrules.jl:106 [inlined]
[8] ZBack
@ ~/.julia/packages/Zygote/dABKa/src/compiler/chainrules.jl:206 [inlined]
[9] Pullback
@ ~/projects/Pico.jl/diagonal.jl:36 [inlined]
[10] (::typeof(∂(#59)))(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
[11] (::Zygote.var"#60#61"{typeof(∂(#59))})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:45
[12] gradient(f::Function, args::Float64)
@ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:97
[13] ∇_built_in(x::Float64)
@ Main ~/projects/Pico.jl/diagonal.jl:36
[14] top-level scope
@ ~/projects/Pico.jl/diagonal.jl:38
Relevant log output
No response