Skip to content

Allow zero-arrays #58

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/Optimisers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module Optimisers

using Functors: functor, fmap, isleaf
using LinearAlgebra
using Base.Broadcast: broadcast_preserving_zero_d, broadcasted

include("interface.jl")

Expand Down
8 changes: 7 additions & 1 deletion src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@ function setup(rule, x; seen = Base.IdSet())
end
end

# Previous one-line definition (the "before" side of this diff hunk).
# When `iswriteable(x)` holds, `x` is overwritten in place with `x .- x̄`; otherwise a
# new value is built out-of-place by broadcasting `eltype(x)` over `x .- x̄`.
# NOTE(review): for a zero-dimensional `x` the out-of-place broadcast presumably
# collapses to a scalar instead of a 0-dim array, which this change addresses — confirm.
subtract!(x, x̄) = iswriteable(x) ? (x .= x .- x̄) : eltype(x).(x .- x̄)
function subtract!(x, x̄)
if iswriteable(x)
x .= x .- x̄
else
broadcast_preserving_zero_d(eltype(x), broadcasted(-, x, x̄))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a public function? Meaning we can expect it to be stable?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not entirely sure. It has a docstring but isn't in the manual. It's used quite a bit, e.g. to implement conj at https://github.com/JuliaLang/julia/blob/4c8c5153a566b25ef8c7b7091b5126328812d287/base/abstractarraymath.jl#L145 .

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I doubt this is the only use of it in the wild either: https://juliahub.com/ui/Search?q=broadcast_preserving_zero_d&type=code

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good enough for me!

end
end

# Leaves whose optimiser state is `nothing` are never updated: whatever gradients
# arrive, the state stays `nothing` and `x` is handed back unchanged.
# NOTE(review): the `::Zero` method presumably exists only to settle dispatch against
# other `update!` methods that accept `Zero` gradients — confirm.
function update!(::Nothing, x, ::Zero, ::Zero...)
  return (nothing, x)
end
function update!(::Nothing, x, x̄s...)
  return (nothing, x)
end
Expand Down
28 changes: 28 additions & 0 deletions test/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,3 +226,31 @@ end
@test static_loss(static_model) < 1.9
end
end

@testset "zero-dim arrays" begin
  empty!(LOG)
  @testset "$(name(rule))" for rule in RULES
    # Three zero-dimensional containers: a plain mutable array, a wrapper which is
    # not writeable, and a `Ref`.
    model = (;
      arr = fill(1.0),
      pda = PermutedDimsArray(fill(1.0), ()),
      ref = Ref(1.0),
    )
    # `pda` is only here to exercise the out-of-place path, so verify that premise:
    @test Optimisers.iswriteable(model.arr)
    @test !Optimisers.iswriteable(model.pda)

    # Squared sum of the three stored numbers.
    loss(x) = abs2(first(x.arr) + first(x.pda) + first(x.ref))

    state = Optimisers.setup(rule, model)
    for _ in 1:1000
      grad = loggradient(rule)(loss, model)[1]
      state, model = Optimisers.update(state, model, grad)
    end

    # Goal is to check that broadcasting does not accidentally make a scalar,
    # but `model.arr` is copied & mutated, so only `model.pda` is a real test:
    @test model.arr isa Array{Float64, 0}
    @test model.pda isa AbstractArray{Float64, 0}
    # The Ref is mutated; broadcast_preserving_zero_d would make an array instead.
    @test model.ref isa Ref

    if rule isa RADAM
      @test sum(model.arr) < 0.7
      @test_broken sum(model.arr) < 0.3
    else
      @test sum(model.arr) < 0.3
      @test sum(model.pda) < 0.3
    end
    @test only(model.ref) ≈ 1  # not currently regarded as trainable
  end
end
26 changes: 25 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
end

@testset "trainable subset" begin
@info "ignore these warnings about trainable, testing the old path"
@info "ignore these warnings about `trainable`, they are testing the path for old-style methods"
# Foo has an old-style tuple trainable, both elements
mf = Foo([1.0, 2.0], (a = sin, b = [3.0, 4.0], c = 5))
sf = Optimisers.setup(Descent(0.1), mf)
Expand Down Expand Up @@ -131,6 +131,30 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
@test eltype(m4[2]) == Float32
end

@testset "zero dimension" begin
  # Mutable Array{T,0}: updated in place, so the very same array comes back.
  m = fill(1.0)
  s = Optimisers.setup(Descent(0.1), m)
  s2, m2 = Optimisers.update!(s, m, fill(2.0))
  @test m2 === m
  @test only(m) ≈ 0.8  # 1.0 - 0.1 * 2.0

  # "Immutable" zero-array, takes out-of-place path:
  m3 = PermutedDimsArray(fill(1.0), ())
  @test !Optimisers.iswriteable(m3)  # note that there's Base.iswritable, it seems I can't spell
  s3 = Optimisers.setup(Descent(0.1), m3)
  s4, m4 = Optimisers.update!(s3, m3, fill(2.0))
  @test m4 !== m3
  # Both `!==` above and `only` below are also satisfied by a bare Float64, so the
  # dimensionality has to be checked explicitly, otherwise a broadcast which
  # collapses the zero-array to a scalar would go unnoticed:
  @test m4 isa AbstractArray{Float64, 0}
  @test only(m4) ≈ 0.8

  # Ref, should this be regarded as holding a parameter? At present it's not:
  m5 = Ref(1.0)
  s5 = Optimisers.setup(Descent(0.1), m5)
  g5 = gradient(m -> m[]^2, m5)[1]  # (x = 2.0,)
  s6, m6 = Optimisers.update!(s5, m5, g5)
  @test m6[] ≈ 1  # unchanged, since the Ref is not treated as trainable
end

@testset "forgotten gradient" begin
x = [1.0, 2.0]
sx = Optimisers.setup(Descent(), x)
Expand Down