Skip to content

Commit d3ddda5

Browse files
committed
restore
1 parent 6f7fa21 commit d3ddda5

File tree

1 file changed

+146
-2
lines changed

1 file changed

+146
-2
lines changed

test/runtests.jl

Lines changed: 146 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,30 @@ function test_matrix_to_number(f, x; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1)
6868
@test isapproxfn((Enzyme.Forward, f), dx_fwd, dx_fd; rtol=rtol, atol=atol, kwargs...)
6969
end
7070

71+
# Aqua.test_all(Enzyme, unbound_args=false, piracies=false, deps_compat=false, stale_deps=(;:ignore=>[:EnzymeTestUtils]))
72+
# Aqua.test_all(Enzyme, unbound_args=false, piracies=false, deps_compat=false, stale_deps=(;:ignore=>[:EnzymeTestUtils]))
73+
74+
include("abi.jl")
75+
include("typetree.jl")
76+
include("passes.jl")
77+
include("optimize.jl")
78+
include("make_zero.jl")
79+
include("runtime_calls.jl")
80+
81+
include("rules.jl")
82+
include("rrules.jl")
83+
include("kwrules.jl")
84+
include("kwrrules.jl")
85+
include("internal_rules.jl")
86+
include("ruleinvalidation.jl")
87+
include("typeunstable.jl")
88+
include("absint.jl")
89+
include("array.jl")
90+
91+
@static if !Sys.iswindows()
92+
include("blas.jl")
93+
end
94+
7195
f0(x) = 1.0 + x
7296
function vrec(start, x)
7397
if start > length(x)
@@ -87,6 +111,128 @@ mutable struct MInts{A, B}
87111
q::Int
88112
end
89113

114+
@testset "Internal tests" begin
115+
@static if VERSION < v"1.11-"
116+
else
117+
@assert Enzyme.Compiler.active_reg_inner(Memory{Float64}, (), nothing) == Enzyme.Compiler.DupState
118+
end
119+
@assert Enzyme.Compiler.active_reg_inner(Type{Array}, (), nothing) == Enzyme.Compiler.AnyState
120+
@assert Enzyme.Compiler.active_reg_inner(Ints{<:Any, Integer}, (), nothing) == Enzyme.Compiler.AnyState
121+
@assert Enzyme.Compiler.active_reg_inner(Ints{<:Any, Float64}, (), nothing) == Enzyme.Compiler.DupState
122+
@assert Enzyme.Compiler.active_reg_inner(Ints{Integer, <:Any}, (), nothing) == Enzyme.Compiler.DupState
123+
@assert Enzyme.Compiler.active_reg_inner(Ints{Integer, <:Integer}, (), nothing) == Enzyme.Compiler.AnyState
124+
@assert Enzyme.Compiler.active_reg_inner(Ints{Integer, <:AbstractFloat}, (), nothing) == Enzyme.Compiler.DupState
125+
@assert Enzyme.Compiler.active_reg_inner(Ints{Integer, Float64}, (), nothing) == Enzyme.Compiler.ActiveState
126+
@assert Enzyme.Compiler.active_reg_inner(MInts{Integer, Float64}, (), nothing) == Enzyme.Compiler.DupState
127+
128+
@assert Enzyme.Compiler.active_reg(Tuple{Float32,Float32,Int})
129+
@assert !Enzyme.Compiler.active_reg(Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}})
130+
@assert !Enzyme.Compiler.active_reg(Base.RefValue{Float32})
131+
@assert Enzyme.Compiler.active_reg_inner(Ptr, (), nothing) == Enzyme.Compiler.DupState
132+
@assert Enzyme.Compiler.active_reg_inner(Base.RefValue{Float32}, (), nothing) == Enzyme.Compiler.DupState
133+
@assert Enzyme.Compiler.active_reg_inner(Colon, (), nothing) == Enzyme.Compiler.AnyState
134+
@assert Enzyme.Compiler.active_reg_inner(Symbol, (), nothing) == Enzyme.Compiler.AnyState
135+
@assert Enzyme.Compiler.active_reg_inner(String, (), nothing) == Enzyme.Compiler.AnyState
136+
@assert Enzyme.Compiler.active_reg_inner(Tuple{Any,Int64}, (), nothing) == Enzyme.Compiler.DupState
137+
@assert Enzyme.Compiler.active_reg_inner(Tuple{S,Int64} where S, (), Base.get_world_counter()) == Enzyme.Compiler.DupState
138+
@assert Enzyme.Compiler.active_reg_inner(Union{Float64,Nothing}, (), nothing) == Enzyme.Compiler.DupState
139+
@assert Enzyme.Compiler.active_reg_inner(Union{Float64,Nothing}, (), nothing, #=justActive=#Val(false), #=unionSret=#Val(true)) == Enzyme.Compiler.ActiveState
140+
@test Enzyme.Compiler.active_reg_inner(Tuple, (), nothing) == Enzyme.Compiler.DupState
141+
@test Enzyme.Compiler.active_reg_inner(Tuple, (), nothing, #=justactive=#Val(false), #=unionsret=#Val(false), #=abstractismixed=#Val(true)) == Enzyme.Compiler.MixedState
142+
@test Enzyme.Compiler.active_reg_inner(Tuple{A,A} where A, (), nothing, #=justactive=#Val(false), #=unionsret=#Val(false), #=abstractismixed=#Val(true)) == Enzyme.Compiler.MixedState
143+
144+
# issue #1935
145+
struct Incomplete
146+
x::Float64
147+
y
148+
Incomplete(x) = new(x)
149+
# incomplete constructor & non-bitstype field => !Base.allocatedinline(Incomplete)
150+
end
151+
@test Enzyme.Compiler.active_reg_inner(Tuple{Incomplete}, (), nothing, #=justActive=#Val(false)) == Enzyme.Compiler.MixedState
152+
@test Enzyme.Compiler.active_reg_inner(Tuple{Incomplete}, (), nothing, #=justActive=#Val(true)) == Enzyme.Compiler.ActiveState
153+
154+
thunk_a = Enzyme.Compiler.thunk(Val(0), Const{typeof(f0)}, Active, Tuple{Active{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI, Val(false), Val(false), Val(false))
155+
thunk_b = Enzyme.Compiler.thunk(Val(0), Const{typeof(f0)}, Const, Tuple{Const{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI, Val(false), Val(false), Val(false))
156+
thunk_c = Enzyme.Compiler.thunk(Val(0), Const{typeof(f0)}, Active{Float64}, Tuple{Active{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI, Val(false), Val(false), Val(false))
157+
thunk_d = Enzyme.Compiler.thunk(Val(0), Const{typeof(f0)}, Active{Float64}, Tuple{Active{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI, Val(false), Val(false), Val(false))
158+
@test thunk_a.adjoint !== thunk_b.adjoint
159+
@test thunk_c.adjoint === thunk_a.adjoint
160+
@test thunk_c.adjoint === thunk_d.adjoint
161+
162+
@test thunk_a(Const(f0), Active(2.0), 1.0) == ((1.0,),)
163+
@test thunk_a(Const(f0), Active(2.0), 2.0) == ((2.0,),)
164+
@test thunk_b(Const(f0), Const(2.0)) === ((nothing,),)
165+
166+
forward, pullback = Enzyme.Compiler.thunk(Val(0), Const{typeof(f0)}, Active, Tuple{Active{Float64}}, Val(Enzyme.API.DEM_ReverseModeGradient), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI, Val(false), Val(false), Val(false))
167+
168+
@test forward(Const(f0), Active(2.0)) == (nothing,nothing,nothing)
169+
@test pullback(Const(f0), Active(2.0), 1.0, nothing) == ((1.0,),)
170+
171+
function mul2(x)
172+
x[1] * x[2]
173+
end
174+
d = Duplicated([3.0, 5.0], [0.0, 0.0])
175+
176+
forward, pullback = Enzyme.Compiler.thunk(Val(0), Const{typeof(mul2)}, Active, Tuple{Duplicated{Vector{Float64}}}, Val(Enzyme.API.DEM_ReverseModeGradient), Val(1), Val((false, true)), Val(false), Val(false), DefaultABI, Val(false), Val(false), Val(false))
177+
res = forward(Const(mul2), d)
178+
179+
@static if VERSION < v"1.11-"
180+
@test typeof(res[1]) == Tuple{Float64, Float64}
181+
else
182+
@test typeof(res[1]) == NamedTuple{(Symbol("1"),Symbol("2"),Symbol("3")), Tuple{Any, Float64, Float64}}
183+
end
184+
185+
pullback(Const(mul2), d, 1.0, res[1])
186+
@test d.dval[1] 5.0
187+
@test d.dval[2] 3.0
188+
189+
d = Duplicated([3.0, 5.0], [0.0, 0.0])
190+
forward, pullback = Enzyme.Compiler.thunk(Val(0), Const{typeof(vrec)}, Active, Tuple{Const{Int}, Duplicated{Vector{Float64}}}, Val(Enzyme.API.DEM_ReverseModeGradient), Val(1), Val((false, false, true)), Val(false), Val(false), DefaultABI, Val(false), Val(false), Val(false))
191+
res = forward(Const(vrec), Const(Int(1)), d)
192+
pullback(Const(vrec), Const(1), d, 1.0, res[1])
193+
@test d.dval[1] 5.0
194+
@test d.dval[2] 3.0
195+
196+
# @test thunk_split.primal !== C_NULL
197+
# @test thunk_split.primal !== thunk_split.adjoint
198+
# @test thunk_a.adjoint !== thunk_split.adjoint
199+
#
200+
z = ([3.14, 21.5, 16.7], [0,1], [5.6, 8.9])
201+
Enzyme.make_zero!(z)
202+
@test z[1] [0.0, 0.0, 0.0]
203+
@test z[2][1] == 0
204+
@test z[2][2] == 1
205+
@test z[3] [0.0, 0.0]
206+
207+
z2 = ([3.14, 21.5, 16.7], [0,1], [5.6, 8.9])
208+
Enzyme.make_zero!(z2)
209+
@test z2[1] [0.0, 0.0, 0.0]
210+
@test z2[2][1] == 0
211+
@test z2[2][2] == 1
212+
@test z2[3] [0.0, 0.0]
213+
214+
z3 = [3.4, "foo"]
215+
Enzyme.make_zero!(z3)
216+
@test z3[1] 0.0
217+
@test z3[2] == "foo"
218+
219+
z4 = sin
220+
Enzyme.make_zero!(z4)
221+
222+
struct Dense
223+
n_inp::Int
224+
b::Vector{Float64}
225+
end
226+
227+
function Dense(n)
228+
Dense(n, rand(n))
229+
end
230+
231+
nn = Dense(4)
232+
Enzyme.make_zero!(nn)
233+
@test nn.b [0.0, 0.0, 0.0, 0.0]
234+
end
235+
90236
@testset "Reflection" begin
91237
Enzyme.Compiler.enzyme_code_typed(Active, Tuple{Active{Float64}}) do x
92238
x ^ 2
@@ -516,7 +662,6 @@ end
516662
test_scalar(Base.inv, 3.0 + 4.0im)
517663
end
518664

519-
@testset "Base functions" begin
520665
f1(x) = prod(ntuple(i -> i * x, 3))
521666
@test autodiff(Reverse, f1, Active, Active(2.0))[1][1] == 72
522667
@test autodiff(Forward, f1, Duplicated(2.0, 1.0))[1] == 72
@@ -673,4 +818,3 @@ end
673818
f30(x) = reverse([x 2.0 3x])[1]
674819
@test autodiff(Reverse, f30, Active, Active(2.0))[1][1] == 3
675820
@test autodiff(Forward, f30, Duplicated(2.0, 1.0))[1] == 3
676-
end

0 commit comments

Comments
 (0)