@@ -377,58 +377,60 @@ function fit_mle(::Type{<:Dirichlet}, P::AbstractMatrix{Float64},
377377end 
378378
## Differentiation
# Forward-mode rule for constructing a `Dirichlet` from its parameter vector `alpha`.
#
# Returns the constructed distribution `d` together with a structural `Tangent` whose
# `alpha`, `alpha0` and `lmnB` entries are the directional derivatives of the
# corresponding fields of `d` along the input perturbation `Δalpha`.
# `check_args` is forwarded unchanged to the constructor.
#
# NOTE(review): the `::Tuple{Any,Any}` annotation on the tangent tuple is presumably there
# to avoid dispatch ambiguities with generic `frule` methods; confirm before removing it.
function ChainRulesCore.frule(
    (_, Δalpha)::Tuple{Any,Any},
    ::Type{DT},
    alpha::AbstractVector{T};
    check_args::Bool=true,
) where {T<:Real,DT<:Union{Dirichlet{T},Dirichlet}}
    d = DT(alpha; check_args=check_args)

    # Tangent of `alpha0` (presumably `sum(alpha)`; verify against the constructor).
    ∂alpha0 = sum(Δalpha)

    # Tangent of `lmnB`: Σᵢ Δalphaᵢ * (ψ(alphaᵢ) - ψ(alpha0)). The broadcast is kept lazy
    # so that `sum` reduces it without materializing an intermediate array.
    digamma_alpha0 = SpecialFunctions.digamma(d.alpha0)
    ∂lmnB = sum(
        Broadcast.instantiate(
            Broadcast.broadcasted(Δalpha, alpha) do Δalphai, alphai
                Δalphai * (SpecialFunctions.digamma(alphai) - digamma_alpha0)
            end,
        ),
    )

    Δd = ChainRulesCore.Tangent{typeof(d)}(; alpha=Δalpha, alpha0=∂alpha0, lmnB=∂lmnB)
    return d, Δd
end
392390
# Reverse-mode rule for constructing a `Dirichlet` from its parameter vector `alpha`.
#
# Returns the constructed distribution `d` and a pullback. The pullback takes a
# (possibly thunked) cotangent of `d` with fields `alpha`, `alpha0` and `lmnB` and returns
# `(NoTangent(), Δalpha)`, where `Δalpha` accumulates the contribution of all three fields
# to the cotangent of `alpha`. `check_args` is forwarded unchanged to the constructor.
function ChainRulesCore.rrule(
    ::Type{DT}, alpha::AbstractVector{T}; check_args::Bool=true
) where {T<:Real,DT<:Union{Dirichlet{T},Dirichlet}}
    d = DT(alpha; check_args=check_args)

    # Computed once here instead of on every pullback call; it depends only on the primal.
    digamma_alpha0 = SpecialFunctions.digamma(d.alpha0)

    function Dirichlet_pullback(_Δd)
        Δd = ChainRulesCore.unthunk(_Δd)
        # Per element: direct `alpha` cotangent, plus the `alpha0` cotangent (added to every
        # element), plus the `lmnB` cotangent scaled by ψ(alphaᵢ) - ψ(alpha0).
        Δalpha =
            Δd.alpha .+ Δd.alpha0 .+
            Δd.lmnB .* (SpecialFunctions.digamma.(alpha) .- digamma_alpha0)
        return ChainRulesCore.NoTangent(), Δalpha
    end

    return d, Dirichlet_pullback
end
403401
# Forward-mode rule for `_logpdf(d::Dirichlet, x)`.
#
# `Δd` is the tangent of the distribution (its `alpha` and `lmnB` fields are read) and
# `Δx` is the tangent of the evaluation point. Returns the primal value `Ω` and its
# directional derivative `ΔΩ`. When `Ω` is not finite, `ΔΩ` is `NaN` of the same type.
function ChainRulesCore.frule(
    (_, Δd, Δx)::Tuple{Any,Any,Any},
    ::typeof(_logpdf),
    d::Dirichlet,
    x::AbstractVector{<:Real},
)
    Ω = _logpdf(d, x)

    # Σᵢ [ xlogy(Δalphaᵢ, xᵢ) + (alphaᵢ - 1) * Δxᵢ / xᵢ ], reduced lazily so that no
    # intermediate array is materialized.
    ∂alpha = sum(
        Broadcast.instantiate(
            Broadcast.broadcasted(Δd.alpha, Δx, d.alpha, x) do Δalphai, Δxi, alphai, xi
                xlogy(Δalphai, xi) + (alphai - 1) * Δxi / xi
            end,
        ),
    )
    ∂lmnB = -Δd.lmnB

    ΔΩ_finite = ∂alpha + ∂lmnB
    # A non-finite primal yields a NaN derivative; `oftype` keeps the result type stable.
    ΔΩ = isfinite(Ω) ? ΔΩ_finite : oftype(ΔΩ_finite, NaN)
    return Ω, ΔΩ
end
415414
# Reverse-mode rule for `_logpdf(d::Dirichlet, x)`.
#
# Returns the primal value `Ω` and a pullback. Given a (possibly thunked) cotangent of
# `Ω`, the pullback returns `(NoTangent(), Δd, Δx)`, where `Δd` is a structural `Tangent`
# carrying the `alpha` and `lmnB` cotangents (no `alpha0` entry is supplied) and `Δx` is
# the cotangent of `x`. When `Ω` is not finite, every returned cotangent entry is `NaN`.
function ChainRulesCore.rrule(
    ::typeof(_logpdf), d::T, x::AbstractVector{<:Real}
) where {T<:Dirichlet}
    Ω = _logpdf(d, x)
    # Captured as plain values so the pullback does not re-evaluate them on each call.
    isfinite_Ω = isfinite(Ω)
    alpha = d.alpha

    function _logpdf_Dirichlet_pullback(_ΔΩ)
        ΔΩ = ChainRulesCore.unthunk(_ΔΩ)
        # Element-wise helpers take the finiteness flag and return NaN when it is false.
        ∂alpha = _logpdf_Dirichlet_∂alphai.(x, ΔΩ, isfinite_Ω)
        ∂lmnB = isfinite_Ω ? -float(ΔΩ) : oftype(float(ΔΩ), NaN)
        Δd = ChainRulesCore.Tangent{T}(; alpha=∂alpha, lmnB=∂lmnB)
        Δx = _logpdf_Dirichlet_Δxi.(ΔΩ, alpha, x, isfinite_Ω)
        return ChainRulesCore.NoTangent(), Δd, Δx
    end

    return Ω, _logpdf_Dirichlet_pullback
end
# Element-wise cotangent of `alpha[i]` for the `_logpdf(::Dirichlet, x)` pullback:
# `xlogy(ΔΩi, xi)`, or `NaN` of the same type when `isfinite_Ω` is `false`.
#
# The flag is named `isfinite_Ω` rather than `isfinite` so that it does not shadow
# `Base.isfinite` inside the function body.
function _logpdf_Dirichlet_∂alphai(xi, ΔΩi, isfinite_Ω::Bool)
    ∂alphai = xlogy(ΔΩi, xi)
    return isfinite_Ω ? ∂alphai : oftype(∂alphai, NaN)
end
# Element-wise cotangent of `x[i]` for the `_logpdf(::Dirichlet, x)` pullback:
# `ΔΩi * (alphai - 1) / xi`, or `NaN` of the same type when `isfinite_Ω` is `false`.
#
# The flag is named `isfinite_Ω` rather than `isfinite` so that it does not shadow
# `Base.isfinite` inside the function body.
function _logpdf_Dirichlet_Δxi(ΔΩi, alphai, xi, isfinite_Ω::Bool)
    Δxi = ΔΩi * (alphai - 1) / xi
    return isfinite_Ω ? Δxi : oftype(Δxi, NaN)
end
0 commit comments