@@ -68,6 +68,30 @@ function test_matrix_to_number(f, x; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1)
6868 @test isapproxfn ((Enzyme. Forward, f), dx_fwd, dx_fd; rtol= rtol, atol= atol, kwargs... )
6969end
7070
71+ # Aqua.test_all(Enzyme, unbound_args=false, piracies=false, deps_compat=false, stale_deps=(;:ignore=>[:EnzymeTestUtils]))
72+ # Aqua.test_all(Enzyme, unbound_args=false, piracies=false, deps_compat=false, stale_deps=(;:ignore=>[:EnzymeTestUtils]))
73+
74+ include (" abi.jl" )
75+ include (" typetree.jl" )
76+ include (" passes.jl" )
77+ include (" optimize.jl" )
78+ include (" make_zero.jl" )
79+ include (" runtime_calls.jl" )
80+
81+ include (" rules.jl" )
82+ include (" rrules.jl" )
83+ include (" kwrules.jl" )
84+ include (" kwrrules.jl" )
85+ include (" internal_rules.jl" )
86+ include (" ruleinvalidation.jl" )
87+ include (" typeunstable.jl" )
88+ include (" absint.jl" )
89+ include (" array.jl" )
90+
91+ @static if ! Sys. iswindows ()
92+ include (" blas.jl" )
93+ end
94+
7195f0 (x) = 1.0 + x
7296function vrec (start, x)
7397 if start > length (x)
@@ -87,6 +111,128 @@ mutable struct MInts{A, B}
87111 q:: Int
88112end
89113
114+ @testset " Internal tests" begin
115+ @static if VERSION < v " 1.11-"
116+ else
117+ @assert Enzyme. Compiler. active_reg_inner (Memory{Float64}, (), nothing ) == Enzyme. Compiler. DupState
118+ end
119+ @assert Enzyme. Compiler. active_reg_inner (Type{Array}, (), nothing ) == Enzyme. Compiler. AnyState
120+ @assert Enzyme. Compiler. active_reg_inner (Ints{<: Any , Integer}, (), nothing ) == Enzyme. Compiler. AnyState
121+ @assert Enzyme. Compiler. active_reg_inner (Ints{<: Any , Float64}, (), nothing ) == Enzyme. Compiler. DupState
122+ @assert Enzyme. Compiler. active_reg_inner (Ints{Integer, <: Any }, (), nothing ) == Enzyme. Compiler. DupState
123+ @assert Enzyme. Compiler. active_reg_inner (Ints{Integer, <: Integer }, (), nothing ) == Enzyme. Compiler. AnyState
124+ @assert Enzyme. Compiler. active_reg_inner (Ints{Integer, <: AbstractFloat }, (), nothing ) == Enzyme. Compiler. DupState
125+ @assert Enzyme. Compiler. active_reg_inner (Ints{Integer, Float64}, (), nothing ) == Enzyme. Compiler. ActiveState
126+ @assert Enzyme. Compiler. active_reg_inner (MInts{Integer, Float64}, (), nothing ) == Enzyme. Compiler. DupState
127+
128+ @assert Enzyme. Compiler. active_reg (Tuple{Float32,Float32,Int})
129+ @assert ! Enzyme. Compiler. active_reg (Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}})
130+ @assert ! Enzyme. Compiler. active_reg (Base. RefValue{Float32})
131+ @assert Enzyme. Compiler. active_reg_inner (Ptr, (), nothing ) == Enzyme. Compiler. DupState
132+ @assert Enzyme. Compiler. active_reg_inner (Base. RefValue{Float32}, (), nothing ) == Enzyme. Compiler. DupState
133+ @assert Enzyme. Compiler. active_reg_inner (Colon, (), nothing ) == Enzyme. Compiler. AnyState
134+ @assert Enzyme. Compiler. active_reg_inner (Symbol, (), nothing ) == Enzyme. Compiler. AnyState
135+ @assert Enzyme. Compiler. active_reg_inner (String, (), nothing ) == Enzyme. Compiler. AnyState
136+ @assert Enzyme. Compiler. active_reg_inner (Tuple{Any,Int64}, (), nothing ) == Enzyme. Compiler. DupState
137+ @assert Enzyme. Compiler. active_reg_inner (Tuple{S,Int64} where S, (), Base. get_world_counter ()) == Enzyme. Compiler. DupState
138+ @assert Enzyme. Compiler. active_reg_inner (Union{Float64,Nothing}, (), nothing ) == Enzyme. Compiler. DupState
139+ @assert Enzyme. Compiler. active_reg_inner (Union{Float64,Nothing}, (), nothing , #= justActive=# Val (false ), #= unionSret=# Val (true )) == Enzyme. Compiler. ActiveState
140+ @test Enzyme. Compiler. active_reg_inner (Tuple, (), nothing ) == Enzyme. Compiler. DupState
141+ @test Enzyme. Compiler. active_reg_inner (Tuple, (), nothing , #= justactive=# Val (false ), #= unionsret=# Val (false ), #= abstractismixed=# Val (true )) == Enzyme. Compiler. MixedState
142+ @test Enzyme. Compiler. active_reg_inner (Tuple{A,A} where A, (), nothing , #= justactive=# Val (false ), #= unionsret=# Val (false ), #= abstractismixed=# Val (true )) == Enzyme. Compiler. MixedState
143+
144+ # issue #1935
145+ struct Incomplete
146+ x:: Float64
147+ y
148+ Incomplete (x) = new (x)
149+ # incomplete constructor & non-bitstype field => !Base.allocatedinline(Incomplete)
150+ end
151+ @test Enzyme. Compiler. active_reg_inner (Tuple{Incomplete}, (), nothing , #= justActive=# Val (false )) == Enzyme. Compiler. MixedState
152+ @test Enzyme. Compiler. active_reg_inner (Tuple{Incomplete}, (), nothing , #= justActive=# Val (true )) == Enzyme. Compiler. ActiveState
153+
154+ thunk_a = Enzyme. Compiler. thunk (Val (0 ), Const{typeof (f0)}, Active, Tuple{Active{Float64}}, Val (API. DEM_ReverseModeCombined), Val (1 ), Val ((false , false )), Val (false ), Val (false ), DefaultABI, Val (false ), Val (false ), Val (false ))
155+ thunk_b = Enzyme. Compiler. thunk (Val (0 ), Const{typeof (f0)}, Const, Tuple{Const{Float64}}, Val (API. DEM_ReverseModeCombined), Val (1 ), Val ((false , false )), Val (false ), Val (false ), DefaultABI, Val (false ), Val (false ), Val (false ))
156+ thunk_c = Enzyme. Compiler. thunk (Val (0 ), Const{typeof (f0)}, Active{Float64}, Tuple{Active{Float64}}, Val (API. DEM_ReverseModeCombined), Val (1 ), Val ((false , false )), Val (false ), Val (false ), DefaultABI, Val (false ), Val (false ), Val (false ))
157+ thunk_d = Enzyme. Compiler. thunk (Val (0 ), Const{typeof (f0)}, Active{Float64}, Tuple{Active{Float64}}, Val (API. DEM_ReverseModeCombined), Val (1 ), Val ((false , false )), Val (false ), Val (false ), DefaultABI, Val (false ), Val (false ), Val (false ))
158+ @test thunk_a. adjoint != = thunk_b. adjoint
159+ @test thunk_c. adjoint === thunk_a. adjoint
160+ @test thunk_c. adjoint === thunk_d. adjoint
161+
162+ @test thunk_a (Const (f0), Active (2.0 ), 1.0 ) == ((1.0 ,),)
163+ @test thunk_a (Const (f0), Active (2.0 ), 2.0 ) == ((2.0 ,),)
164+ @test thunk_b (Const (f0), Const (2.0 )) === ((nothing ,),)
165+
166+ forward, pullback = Enzyme. Compiler. thunk (Val (0 ), Const{typeof (f0)}, Active, Tuple{Active{Float64}}, Val (Enzyme. API. DEM_ReverseModeGradient), Val (1 ), Val ((false , false )), Val (false ), Val (false ), DefaultABI, Val (false ), Val (false ), Val (false ))
167+
168+ @test forward (Const (f0), Active (2.0 )) == (nothing ,nothing ,nothing )
169+ @test pullback (Const (f0), Active (2.0 ), 1.0 , nothing ) == ((1.0 ,),)
170+
171+ function mul2 (x)
172+ x[1 ] * x[2 ]
173+ end
174+ d = Duplicated ([3.0 , 5.0 ], [0.0 , 0.0 ])
175+
176+ forward, pullback = Enzyme. Compiler. thunk (Val (0 ), Const{typeof (mul2)}, Active, Tuple{Duplicated{Vector{Float64}}}, Val (Enzyme. API. DEM_ReverseModeGradient), Val (1 ), Val ((false , true )), Val (false ), Val (false ), DefaultABI, Val (false ), Val (false ), Val (false ))
177+ res = forward (Const (mul2), d)
178+
179+ @static if VERSION < v " 1.11-"
180+ @test typeof (res[1 ]) == Tuple{Float64, Float64}
181+ else
182+ @test typeof (res[1 ]) == NamedTuple{(Symbol (" 1" ),Symbol (" 2" ),Symbol (" 3" )), Tuple{Any, Float64, Float64}}
183+ end
184+
185+ pullback (Const (mul2), d, 1.0 , res[1 ])
186+ @test d. dval[1 ] ≈ 5.0
187+ @test d. dval[2 ] ≈ 3.0
188+
189+ d = Duplicated ([3.0 , 5.0 ], [0.0 , 0.0 ])
190+ forward, pullback = Enzyme. Compiler. thunk (Val (0 ), Const{typeof (vrec)}, Active, Tuple{Const{Int}, Duplicated{Vector{Float64}}}, Val (Enzyme. API. DEM_ReverseModeGradient), Val (1 ), Val ((false , false , true )), Val (false ), Val (false ), DefaultABI, Val (false ), Val (false ), Val (false ))
191+ res = forward (Const (vrec), Const (Int (1 )), d)
192+ pullback (Const (vrec), Const (1 ), d, 1.0 , res[1 ])
193+ @test d. dval[1 ] ≈ 5.0
194+ @test d. dval[2 ] ≈ 3.0
195+
196+ # @test thunk_split.primal !== C_NULL
197+ # @test thunk_split.primal !== thunk_split.adjoint
198+ # @test thunk_a.adjoint !== thunk_split.adjoint
199+ #
200+ z = ([3.14 , 21.5 , 16.7 ], [0 ,1 ], [5.6 , 8.9 ])
201+ Enzyme. make_zero! (z)
202+ @test z[1 ] ≈ [0.0 , 0.0 , 0.0 ]
203+ @test z[2 ][1 ] == 0
204+ @test z[2 ][2 ] == 1
205+ @test z[3 ] ≈ [0.0 , 0.0 ]
206+
207+ z2 = ([3.14 , 21.5 , 16.7 ], [0 ,1 ], [5.6 , 8.9 ])
208+ Enzyme. make_zero! (z2)
209+ @test z2[1 ] ≈ [0.0 , 0.0 , 0.0 ]
210+ @test z2[2 ][1 ] == 0
211+ @test z2[2 ][2 ] == 1
212+ @test z2[3 ] ≈ [0.0 , 0.0 ]
213+
214+ z3 = [3.4 , " foo" ]
215+ Enzyme. make_zero! (z3)
216+ @test z3[1 ] ≈ 0.0
217+ @test z3[2 ] == " foo"
218+
219+ z4 = sin
220+ Enzyme. make_zero! (z4)
221+
222+ struct Dense
223+ n_inp:: Int
224+ b:: Vector{Float64}
225+ end
226+
227+ function Dense (n)
228+ Dense (n, rand (n))
229+ end
230+
231+ nn = Dense (4 )
232+ Enzyme. make_zero! (nn)
233+ @test nn. b ≈ [0.0 , 0.0 , 0.0 , 0.0 ]
234+ end
235+
90236@testset " Reflection" begin
91237 Enzyme. Compiler. enzyme_code_typed (Active, Tuple{Active{Float64}}) do x
92238 x ^ 2
516662 test_scalar (Base. inv, 3.0 + 4.0im )
517663end
518664
519- @testset " Base functions" begin
520665 f1 (x) = prod (ntuple (i -> i * x, 3 ))
521666 @test autodiff (Reverse, f1, Active, Active (2.0 ))[1 ][1 ] == 72
522667 @test autodiff (Forward, f1, Duplicated (2.0 , 1.0 ))[1 ] == 72
673818 f30 (x) = reverse ([x 2.0 3 x])[1 ]
674819 @test autodiff (Reverse, f30, Active, Active (2.0 ))[1 ][1 ] == 3
675820 @test autodiff (Forward, f30, Duplicated (2.0 , 1.0 ))[1 ] == 3
676- end
0 commit comments