Skip to content

Commit a05ba91

Browse files
committed
refactor: move overrides into a separate file
1 parent 65e9976 commit a05ba91

File tree

4 files changed

+26
-19
lines changed

4 files changed

+26
-19
lines changed

src/Compiler.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -779,13 +779,6 @@ function compile(f, args; client=nothing, optimize=true, sync=false)
779779
return register_thunk(fname, body)
780780
end
781781

782-
# Compiling within a compile should return simply the original function
783-
Reactant.@reactant_override function Reactant.Compiler.compile(
784-
f, args; client=nothing, optimize=true, sync=false
785-
)
786-
return f
787-
end
788-
789782
# inspired by RuntimeGeneratedFunction.jl
790783
const __thunk_body_cache = Dict{Symbol,Expr}()
791784

src/Interpreter.jl

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -479,15 +479,3 @@ function overload_autodiff(
479479
end
480480
end
481481
end
482-
483-
@reactant_override @noinline function Enzyme.autodiff_deferred(
484-
rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs}
485-
) where {FA<:Annotation,A<:Annotation,Nargs}
486-
return overload_autodiff(rmode, f, rt, args...)
487-
end
488-
489-
@reactant_override @noinline function Enzyme.autodiff(
490-
rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs}
491-
) where {FA<:Annotation,A<:Annotation,Nargs}
492-
return overload_autodiff(rmode, f, rt, args...)
493-
end

src/Overrides.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# NOTE: We are placing all the reactant_overrides here to avoid incompatibilities with
2+
# Revise.jl. Essentially files that contain reactant_overrides cannot be revised
3+
# correctly. Once that (https://github.com/timholy/Revise.jl/issues/646) is resolved
4+
# we should move all the reactant_overrides to relevant files.
5+
6+
# Compiling within a compile should return simply the original function
7+
@reactant_override function Compiler.compile(
8+
f, args; client=nothing, optimize=true, sync=false
9+
)
10+
return f
11+
end
12+
13+
# Enzyme overrides
14+
@reactant_override @noinline function Enzyme.autodiff_deferred(
15+
rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs}
16+
) where {FA<:Annotation,A<:Annotation,Nargs}
17+
return overload_autodiff(rmode, f, rt, args...)
18+
end
19+
20+
@reactant_override @noinline function Enzyme.autodiff(
21+
rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs}
22+
) where {FA<:Annotation,A<:Annotation,Nargs}
23+
return overload_autodiff(rmode, f, rt, args...)
24+
end

src/Reactant.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,8 @@ include("ControlFlow.jl")
130130
include("Tracing.jl")
131131
include("Compiler.jl")
132132

133+
include("Overrides.jl")
134+
133135
function Enzyme.make_zero(
134136
::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false)
135137
)::RT where {copy_if_inactive,RT<:RArray}

0 commit comments

Comments
 (0)