From bc2c3f14952f6fd20e55f358755e18f246ccb581 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 19 Feb 2022 19:24:03 -0500 Subject: [PATCH 1/3] test zero-arrays --- test/rules.jl | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/test/rules.jl b/test/rules.jl index ffb4ca65..1a0e83d8 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -226,3 +226,28 @@ end @test static_loss(static_model) < 1.9 end end + +@testset "zero-dim arrays" begin + empty!(LOG) + @testset "$(name(o))" for o in RULES + m = (a = fill(1.0), b = SArray{Tuple{}}(fill(1.0)), c = PermutedDimsArray(fill(1.0), ())) + s = Optimisers.setup(o, m) + for _ in 1:10^3 + g = loggradient(o)(x -> abs2(first(x.a) + first(x.b) + first(x.c)), m)[1] + s, m = Optimisers.update(s, m, g) + end + # The main point here is that broadcasting should not accidentally make a scalar, + # but `m.a` is mutated, and `m.b .+ 1` is an array, so only `m.c` is a real test. + @test m.a isa Array{Float64, 0} + @test m.b isa AbstractArray{Float64, 0} + @test_broken m.c isa AbstractArray{Float64, 0} + if o isa RADAM + @test sum(m.a) < 0.7 + @test_broken sum(m.a) < 0.3 + else + @test sum(m.a) < 0.3 + @test sum(m.b) < 0.3 + @test sum(m.c) < 0.3 + end + end +end From 3045aef02a39e02fd3812e961ff6195434e00119 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 23 Feb 2022 00:13:31 -0500 Subject: [PATCH 2/3] explicit broadcast_preserving_zero_d, ignore StaticArrays --- src/Optimisers.jl | 1 + src/interface.jl | 8 +++++++- test/rules.jl | 27 +++++++++++++++------------ test/runtests.jl | 26 +++++++++++++++++++++++++- 4 files changed, 48 insertions(+), 14 deletions(-) diff --git a/src/Optimisers.jl b/src/Optimisers.jl index 417b90d4..616f0ea2 100644 --- a/src/Optimisers.jl +++ b/src/Optimisers.jl @@ -2,6 +2,7 @@ module Optimisers using Functors: functor, fmap, isleaf using LinearAlgebra +using Base.Broadcast: broadcast_preserving_zero_d, broadcasted include("interface.jl") diff --git a/src/interface.jl b/src/interface.jl index 235c2e94..d4b0b542 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -21,7 +21,13 @@ function setup(rule, x; seen = Base.IdSet()) end end -subtract!(x, x̄) = iswriteable(x) ? (x .= x .- x̄) : eltype(x).(x .- x̄) +function subtract!(x, x̄) + if iswriteable(x) + x .= x .- x̄ + else + broadcast_preserving_zero_d(eltype(x), broadcasted(-, x, x̄)) + end +end update!(::Nothing, x, ::Zero, ::Zero...) = nothing, x update!(::Nothing, x, x̄s...) = nothing, x diff --git a/test/rules.jl b/test/rules.jl index 1a0e83d8..450fccb8 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -230,24 +230,27 @@ end @testset "zero-dim arrays" begin empty!(LOG) @testset "$(name(o))" for o in RULES - m = (a = fill(1.0), b = SArray{Tuple{}}(fill(1.0)), c = PermutedDimsArray(fill(1.0), ())) + m = (; arr = fill(1.0), pda = PermutedDimsArray(fill(1.0), ()), ref = Ref(1.0)) + # The point of PermutedDimsArray here is to test the out-of-place path, so check: + @test Optimisers.iswriteable(m.arr) + @test !Optimisers.iswriteable(m.pda) s = Optimisers.setup(o, m) for _ in 1:10^3 - g = loggradient(o)(x -> abs2(first(x.a) + first(x.b) + first(x.c)), m)[1] + g = loggradient(o)(x -> abs2(first(x.arr) + first(x.pda) + first(x.ref)), m)[1] s, m = Optimisers.update(s, m, g) end - # The main point here is that broadcasting should not accidentally make a scalar, - # but `m.a` is mutated, and `m.b .+ 1` is an array, so only `m.c` is a real test. - @test m.a isa Array{Float64, 0} - @test m.b isa AbstractArray{Float64, 0} - @test_broken m.c isa AbstractArray{Float64, 0} + # Goal is to check that broadcasting does not accidentally make a scalar, + # but `m.arr` iscopied & mutated, so only `m.pda` is a real test: + @test m.arr isa Array{Float64, 0} + @test m.pda isa AbstractArray{Float64, 0} + @test m.ref isa Ref # because it's mutated, broadcast_preserving_zero_d would make an array if o isa RADAM - @test sum(m.a) < 0.7 - @test_broken sum(m.a) < 0.3 + @test sum(m.arr) < 0.7 + @test_broken sum(m.arr) < 0.3 else - @test sum(m.a) < 0.3 - @test sum(m.b) < 0.3 - @test sum(m.c) < 0.3 + @test sum(m.arr) < 0.3 + @test sum(m.pda) < 0.3 end + @test_broken only(m.ref) < 0.3 # not currently regarded as trainable end end diff --git a/test/runtests.jl b/test/runtests.jl index d47bce08..3b336884 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -80,7 +80,7 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,) end @testset "trainable subset" begin - @info "ignore these warnings about trainable, testing the old path" + @info "ignore these warnings about `trainable`, they are testing the path for old-style methods" # Foo has an old-style tuple trainable, both elements mf = Foo([1.0, 2.0], (a = sin, b = [3.0, 4.0], c = 5)) sf = Optimisers.setup(Descent(0.1), mf) @@ -131,6 +131,30 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,) @test eltype(m4[2]) == Float32 end + @testset "zero dimension" begin + # Mutable Array{T,0} + m = fill(1.0) + s = Optimisers.setup(Descent(0.1), m) + s2, m2 = Optimisers.update!(s, m, fill(2.0)) + @test m2 === m + @test only(m) ≈ 0.8 + + # "Immutable" zero-array, takes out-of-place path: + m3 = PermutedDimsArray(fill(1.0), ()) + @test !Optimisers.iswriteable(m3) # note that there's Base.iswritable, it seems I can't spell + s3 = Optimisers.setup(Descent(0.1), m3) + s4, m4 = Optimisers.update!(s3, m3, fill(2.0)) + @test m4 !== m3 + @test only(m4) ≈ 0.8 + + # Ref, should this be regarded as holding a parameter? At present it's not: + m5 = Ref(1.0) + s5 = Optimisers.setup(Descent(0.1), m5) + g5 = gradient(m -> m[]^2, m5)[1] # (x = 2.0,) + s6, m6 = Optimisers.update!(s5, m5, g5) + @test_broken m6[] ≈ 0.8 + end + @testset "forgotten gradient" begin x = [1.0, 2.0] sx = Optimisers.setup(Descent(), x) From 62cc4578e667151fe6096523b51e7f9b6c5fb6c3 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 23 Feb 2022 21:16:03 -0500 Subject: [PATCH 3/3] don't mark Ref tests broken --- test/rules.jl | 2 +- test/runtests.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/rules.jl b/test/rules.jl index 450fccb8..63aef5bb 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -251,6 +251,6 @@ end @test sum(m.arr) < 0.3 @test sum(m.pda) < 0.3 end - @test_broken only(m.ref) < 0.3 # not currently regarded as trainable + @test only(m.ref) ≈ 1 # not currently regarded as trainable end end diff --git a/test/runtests.jl b/test/runtests.jl index 3b336884..e01ca411 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -152,7 +152,7 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,) s5 = Optimisers.setup(Descent(0.1), m5) g5 = gradient(m -> m[]^2, m5)[1] # (x = 2.0,) s6, m6 = Optimisers.update!(s5, m5, g5) - @test_broken m6[] ≈ 0.8 + @test m6[] ≈ 1 end @testset "forgotten gradient" begin