
Weird issue when using `exp` with `Diagonal` matrices #1321

@aarontrowbridge

Description


Package Version

0.6.49

Julia Version

1.8.1

OS / Environment

Ubuntu Linux (Pop!_OS)

Describe the bug

Calling `Zygote.gradient` on a function that applies `exp` to a `Diagonal` matrix throws a `DimensionMismatch` error from `dot`.

Steps to Reproduce

Code to reproduce the error:

```julia
using Zygote
using LinearAlgebra

d = rand(3)

D = Diagonal(d)

D.diag == diag(D)
# output: true

# Rebuild the diagonal using the `diag` method
function exp_diag_function(D::Diagonal)
    return Diagonal(exp.(diag(D)))
end

# Rebuild the diagonal by accessing the `diag` field directly
function exp_diag_field_name(D::Diagonal)
    return Diagonal(exp.(D.diag))
end

∇_diag_function(x) = Zygote.gradient(x -> tr(exp_diag_function(x * D)), x)

∇_diag_function(1.0)
# output: (1.6395871633504406,)

∇_diag_field_name(x) = Zygote.gradient(x -> tr(exp_diag_field_name(x * D)), x)

∇_diag_field_name(1.0)
# output: ERROR: DimensionMismatch: x and y are of different lengths!

# Same computation using the built-in exp(::Diagonal)
∇_built_in(x) = Zygote.gradient(x -> tr(exp(x * D)), x)

∇_built_in(1.0)
# output: ERROR: DimensionMismatch: x and y are of different lengths!
```

Expected Results

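All three gradient calls should succeed and return the same value, since they differentiate the same function. Instead, reading the diagonal through the field (`D.diag`) breaks the gradient while the `diag(D)` method works, and the built-in `exp` on a `Diagonal` fails in the same way as the field-access version.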
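For reference, the value all three versions should agree on can be computed without Zygote. Because `x * D` is diagonal, `tr(exp(x * D))` equals `sum(exp.(x .* d))`, so its derivative with respect to `x` is `sum(d .* exp.(x .* d))`. The sketch below uses only the standard library; the helpers `analytic_grad` and `fd_grad` are hypothetical names introduced here for illustration. It compares the closed form against a central finite difference taken on a dense copy of the matrix, which avoids the `Diagonal` code path involved in the report.

```julia
using LinearAlgebra

# Hypothetical helper: closed-form gradient of x -> tr(exp(x * Diagonal(d))).
# exp of a diagonal matrix exponentiates each diagonal entry, so the trace is
# sum(exp.(x .* d)) and its derivative is sum(d .* exp.(x .* d)).
analytic_grad(x, d) = sum(d .* exp.(x .* d))

# Hypothetical helper: central finite difference, an AD-free cross-check.
fd_grad(f, x; h = 1e-6) = (f(x + h) - f(x - h)) / (2h)

d = rand(3)
D = Diagonal(d)

# Convert to a dense Matrix so the general matrix exponential is used
# instead of the Diagonal method that appears in the report.
f(x) = tr(exp(Matrix(x * D)))

g_exact = analytic_grad(1.0, d)
g_fd = fd_grad(f, 1.0)
println("closed form: ", g_exact, "  finite difference: ", g_fd)
```

A correct AD result for any of the three versions should match `g_exact` up to floating-point error.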

Observed Results

Both the field-access version and the built-in `exp` throw a `DimensionMismatch` from `LinearAlgebra.dot`. Here is the stack trace for the `∇_built_in(1.0)` call:

```
ERROR: DimensionMismatch: x and y are of different lengths!
Stacktrace:
  [1] dot
    @ ~/packages/julias/julia-1.8/share/julia/stdlib/v1.8/LinearAlgebra/src/generic.jl:866 [inlined]
  [2] dot
    @ ~/packages/julias/julia-1.8/share/julia/stdlib/v1.8/LinearAlgebra/src/generic.jl:856 [inlined]
  [3] #1483
    @ ~/.julia/packages/ChainRules/2ql0h/src/rulesets/Base/arraymath.jl:108 [inlined]
  [4] unthunk
    @ ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_types/thunks.jl:204 [inlined]
  [5] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/dABKa/src/compiler/chainrules.jl:105 [inlined]
  [6] map
    @ ./tuple.jl:223 [inlined]
  [7] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/dABKa/src/compiler/chainrules.jl:106 [inlined]
  [8] ZBack
    @ ~/.julia/packages/Zygote/dABKa/src/compiler/chainrules.jl:206 [inlined]
  [9] Pullback
    @ ~/projects/Pico.jl/diagonal.jl:36 [inlined]
 [10] (::typeof((#59)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
 [11] (::Zygote.var"#60#61"{typeof((#59))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:45
 [12] gradient(f::Function, args::Float64)
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:97
 [13] ∇_built_in(x::Float64)
    @ Main ~/projects/Pico.jl/diagonal.jl:36
 [14] top-level scope
    @ ~/projects/Pico.jl/diagonal.jl:38
```

Relevant log output

No response


Labels: bug (Something isn't working)
