using Flux, ChainRulesCore
using LinearAlgebra: mul!
- # using FastBroadcast: @..
+ using FastBroadcast: @..
using Strided

const NoT = NoTangent()
@@ -108,32 +108,62 @@ function ChainRulesCore.rrule(::typeof(scale!), y, (scale, ds), (x, dx), (bias,
end

#####
- ##### softmax
+ ##### Conv
#####

- function PreLayer(::typeof(softmax))
-     fwd, rev = zeros(Float32, 0), zeros(Float32, 0) # not ideal, demands `model |> pre |> gpu`
-     PreLayer(softmax, nothing, fwd, rev)
+ function PreLayer(c::Conv)
+     grad = _struct_sim(c)
+     fwd, rev = similar(c.weight, 0), similar(c.weight, 0)
+     PreLayer(c, grad, fwd, rev)
end
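+ # (grad is a structurally-similar buffer for the weight/bias gradients, as the helper name
+ # _struct_sim suggests; fwd and rev are flat, resizable scratch vectors for the output and
+ # the input gradient, grown on first use.)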

- function (p::PreLayer{typeof(softmax)})(x::AbstractArray{<:Real})
-     y, dx = _pre_setup(p, x) # generic version
-     _softmaxcall!(y, p, x, dx)
+ function (p::PreLayer{<:Conv})(x::AbstractArray{<:Real})
+     y, dx = _pre_setup(p, x)
+     _convcall!(y, p, x, dx)
end

- _softmaxcall!(y, p, x, dx) = softmax!(y, x)
+ using Flux: conv_dims, conv_reshape_bias
+ using Flux.NNlib: fast_act, conv!, output_size, channels_out

- function ChainRulesCore.rrule(::typeof(_softmaxcall!), y, p, x, dx)
-     y = _softmaxcall!(y, p, x, dx)
-     function back(dy)
-         # TODO: CHECK THIS!
-         dx .= dy .* y
-         dx .= dx .- y .* sum(dx; dims=1) # could sum! into the end of rev
-         return (NoT, NoT, NoT, dx, NoT) # last one could be NotImplemented?
+ function _pre_setup(p::PreLayer{<:Conv}, x)
+     cdims = conv_dims(p.layer, x)
+     ysize = (output_size(cdims)..., channels_out(cdims), size(x)[end])
+     if prod(ysize) != length(p.fwd)
+         resize!(p.fwd, prod(ysize))
+         resize!(p.rev, length(x))
    end
-     y, back
+     y = _pre_reshape(p.fwd, ysize)
+     dx = _pre_reshape(p.rev, size(x))
+     (; y, dx)
+ end
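+ # E.g. for Conv((3, 3), 3 => 16) on a 32×32×3×8 input (stride 1, no padding) this gives
+ # ysize == (30, 30, 16, 8); the flat fwd/rev buffers are only re-grown when the total
+ # length changes, so repeated calls at a fixed batch size re-use the same memory.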
+
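+ # In-place forward pass: conv! writes into the pre-allocated output y, then the bias
+ # (if any) and the activation are applied in place, mirroring what Conv itself computes.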
+ function _convcall!(y, p, x, dx)
+     cdims = conv_dims(p.layer, x)
+     conv!(y, x, p.layer.weight, cdims)
+     if p.layer.bias isa AbstractArray
+         y .+= conv_reshape_bias(p.layer)
+     end
+     act!(y, fast_act(p.layer.σ, x))
end

+ # function ChainRulesCore.rrule(::typeof(_convcall!), y, p, x, dx)
+ #     y = _densecall!(y, p, x, dx)
+ #     function back(dy)
+ #         dy = unthunk(dy)
+ #         dy = ∇act!(y, dy, p.layer.σ)
+ #         # layer
+ #         weight = mul!(p.grad.weight, dy, x')
+ #         bias = ∇bias!(p.grad.bias, dy)
+ #         tang = Tangent{Dense}(; weight, bias)
+ #         # input
+ #         dx = mul!(dx, p.layer.weight', dy)
+ #         return (NoT, NoT, Tangent{PreLayer}(; layer = tang), dx, NoT)
+ #     end
+ #     y, back
+ # end
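+ # The commented-out pullback above is still the Dense version (it calls _densecall! and
+ # mul!). A hedged sketch of what the Conv version might look like, using NNlib's in-place
+ # gradient kernels; untested, and the bias line assumes a 2-d Conv, i.e. dy of shape
+ # (W, H, C, N) with the channel dimension third:
+ # function ChainRulesCore.rrule(::typeof(_convcall!), y, p, x, dx)
+ #     y = _convcall!(y, p, x, dx)
+ #     cdims = conv_dims(p.layer, x)
+ #     function back(dy)
+ #         dy = unthunk(dy)
+ #         dy = ∇act!(y, dy, p.layer.σ)
+ #         # layer: write weight & bias gradients into the pre-allocated buffers
+ #         weight = Flux.NNlib.∇conv_filter!(p.grad.weight, x, dy, cdims)
+ #         bias = p.layer.bias isa AbstractArray ? vec(sum!(reshape(p.grad.bias, 1, 1, :, 1), dy)) : NoT
+ #         tang = Tangent{Conv}(; weight, bias)
+ #         # input
+ #         dx = Flux.NNlib.∇conv_data!(dx, dy, p.layer.weight, cdims)
+ #         return (NoT, NoT, Tangent{PreLayer}(; layer = tang), dx, NoT)
+ #     end
+ #     y, back
+ # end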
+
+
+
#####
##### BatchNorm
#####
@@ -201,6 +231,33 @@ function ChainRulesCore.rrule(::typeof(_norm_layer_forward!), y, x, dx, μ, σ²
    y, back
end

+ #####
+ ##### softmax
+ #####
+
+ function PreLayer(::typeof(softmax))
+     fwd, rev = zeros(Float32, 0), zeros(Float32, 0) # not ideal, demands `model |> pre |> gpu`
+     PreLayer(softmax, nothing, fwd, rev)
+ end
+
+ function (p::PreLayer{typeof(softmax)})(x::AbstractArray{<:Real})
+     y, dx = _pre_setup(p, x) # generic version
+     _softmaxcall!(y, p, x, dx)
+ end
+
+ _softmaxcall!(y, p, x, dx) = softmax!(y, x)
+
+ function ChainRulesCore.rrule(::typeof(_softmaxcall!), y, p, x, dx)
+     y = _softmaxcall!(y, p, x, dx)
+     function back(dy)
+         # TODO: CHECK THIS!
+         dx .= dy .* y
+         dx .= dx .- y .* sum(dx; dims=1) # could sum! into the end of rev
+         return (NoT, NoT, NoT, dx, NoT) # last one could be NotImplemented?
+     end
+     y, back
+ end
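+ # Note: the two broadcasts above compute dx = y .* (dy .- sum(y .* dy; dims=1)), which is
+ # the standard softmax pullback. A quick (hedged) way to check against Flux's own gradient:
+ #   x, dy = randn(Float32, 5, 3), randn(Float32, 5, 3); y = softmax(x)
+ #   y .* (dy .- sum(y .* dy; dims=1)) ≈ Flux.gradient(x -> sum(dy .* softmax(x)), x)[1]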
+

#####
##### activation functions
@@ -212,8 +269,8 @@ function act!(y, act::F) where F
    # y .= σ.(y)
    # Unfortunately this hits https://github.com/JuliaLang/julia/issues/43153
    # maybe you could patch Strided.jl to avoid it? Or use another package...
-     @strided y .= σ.(y)
-     # FastBroadcast.@.. y = σ(y)
+     # @strided y .= σ.(y)
+     @.. y = σ(y)
end

# Piracy, disable @strided on CuArrays:
@@ -223,10 +280,31 @@ Strided.maybestrided(x::Flux.CuArray) = x
ChainRulesCore.rrule(::typeof(act!), y, f) = act!(y, f), dy -> (NoT, ∇act!(y, dy, f), NoT)

∇act!(y, dy, ::typeof(identity)) = dy
- ∇act!(y, dy, ::typeof(relu)) = @. y = ifelse(y>0, dy, 0f0)
- ∇act!(y, dy, ::typeof(tanh)) = @. y = (1 - y^2)
- ∇act!(y, dy, ::typeof(sigmoid)) = @. y = y * (1 - y)
+ ∇act!(y, dy, ::typeof(relu)) = @.. y = ifelse(y>0, dy, 0f0)
+ ∇act!(y, dy, ::typeof(tanh)) = @.. y = dy * (1 - y^2)
+ ∇act!(y, dy, ::typeof(sigmoid)) = @.. y = dy * y * (1 - y)
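+ # (Each of these overwrites y in place with the full pullback dy .* σ′, written in terms of
+ # the saved output y; note the dy factor, as in the relu method, so the layer rrules above
+ # can re-use the returned array directly as the new dy.)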
+
+
+ function PreLayer(::typeof(relu))
+     fwd, rev = zeros(Float32, 0), zeros(Float32, 0) # not ideal
+     PreLayer(relu, nothing, fwd, rev)
+ end
+
+ function (p::PreLayer{typeof(relu)})(x::AbstractArray{<:Real})
+     y, dx = _pre_setup(p, x) # generic version
+     _relucall!(y, p, x, dx)
+ end

+ _relucall!(y, p, x, dx) = y .= relu.(x)
+
+ function ChainRulesCore.rrule(::typeof(_relucall!), y, p, x, dx)
+     y = _relucall!(y, p, x, dx)
+     function back(dy)
+         @. dx = ifelse(y>0, dy, 0f0)
+         return (NoT, NoT, NoT, dx, NoT)
+     end
+     y, back
+ end
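+ # Hedged usage sketch, assuming the `pre` wrapper mentioned above maps PreLayer over a Chain:
+ #   model = Chain(Dense(784 => 64), relu, Dense(64 => 10), softmax) |> pre
+ #   model(rand(Float32, 784, 32)) # later calls of the same size re-use the fwd/rev buffers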

#####
##### PreLayer utils
@@ -249,10 +327,14 @@ ChainRulesCore.@non_differentiable _pre_setup(::Any, ::Any)

# Cannot use reshape(::Array), as that prevents later resize!
_pre_reshape(x::Array, size::Tuple) = Base.ReshapedArray(x, size, ())
+ # _pre_reshape(x::Array, size::Tuple) = Base.__reshape((x, Base.IndexStyle(x)), size) # what Base does, no better
# Must use reshape(::CuArray) as mul! rejects ReshapedArray
_pre_reshape(x::Flux.CuArray, size::Tuple) = reshape(x, size)
_pre_reshape(x, size::Tuple) = reshape(x, size)

+ # Base piracy! to prevent ReshapedArray from going missing
+ Base._reshape(R::Base.ReshapedArray, dims::Base.Dims) = Base.ReshapedArray(R.parent, dims, ())
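+ # (Without this, reshape on a ReshapedArray falls back to reshaping its parent Array,
+ # which recreates exactly the resize!-blocking case excluded above. This relies on an
+ # assumption about Base internals, and it is piracy, so revisit if Base changes.)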
+
∇bias!(::Bool, dx) = NoT
∇bias!(bias, dx) = sum!(bias, dx)