From cec694a49b39b337b19c94cd124209fc5892acd0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 15 Dec 2024 12:23:45 +0530 Subject: [PATCH 1/5] refactor: move overrides into a separate file --- src/Compiler.jl | 7 ------- src/Interpreter.jl | 12 ------------ src/Overrides.jl | 24 ++++++++++++++++++++++++ src/Reactant.jl | 2 ++ src/TracedUtils.jl | 3 ++- 5 files changed, 28 insertions(+), 20 deletions(-) create mode 100644 src/Overrides.jl diff --git a/src/Compiler.jl b/src/Compiler.jl index 5f7158d82e..cc32a90b1a 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -779,13 +779,6 @@ function compile(f, args; client=nothing, optimize=true, sync=false) return register_thunk(fname, body) end -# Compiling within a compile should return simply the original function -Reactant.@reactant_override function Reactant.Compiler.compile( - f, args; client=nothing, optimize=true, sync=false -) - return f -end - # inspired by RuntimeGeneratedFunction.jl const __thunk_body_cache = Dict{Symbol,Expr}() diff --git a/src/Interpreter.jl b/src/Interpreter.jl index 72e27c5d87..cd265c9ebc 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -479,15 +479,3 @@ function overload_autodiff( end end end - -@reactant_override @noinline function Enzyme.autodiff_deferred( - rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs} -) where {FA<:Annotation,A<:Annotation,Nargs} - return overload_autodiff(rmode, f, rt, args...) -end - -@reactant_override @noinline function Enzyme.autodiff( - rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs} -) where {FA<:Annotation,A<:Annotation,Nargs} - return overload_autodiff(rmode, f, rt, args...) -end diff --git a/src/Overrides.jl b/src/Overrides.jl new file mode 100644 index 0000000000..0df144d3a5 --- /dev/null +++ b/src/Overrides.jl @@ -0,0 +1,24 @@ +# NOTE: We are placing all the reactant_overrides here to avoid incompatibilities with +# Revise.jl. Essentially files that contain reactant_overrides cannot be revised +# correctly. Once that (https://github.com/timholy/Revise.jl/issues/646) is resolved +# we should move all the reactant_overrides to relevant files. + +# Compiling within a compile should return simply the original function +@reactant_override function Compiler.compile( + f, args; client=nothing, optimize=true, sync=false +) + return f +end + +# Enzyme overrides +@reactant_override @noinline function Enzyme.autodiff_deferred( + rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs} +) where {FA<:Annotation,A<:Annotation,Nargs} + return overload_autodiff(rmode, f, rt, args...) +end + +@reactant_override @noinline function Enzyme.autodiff( + rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs} +) where {FA<:Annotation,A<:Annotation,Nargs} + return overload_autodiff(rmode, f, rt, args...) +end diff --git a/src/Reactant.jl b/src/Reactant.jl index ba2da588d9..9830965df8 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -130,6 +130,8 @@ include("ControlFlow.jl") include("Tracing.jl") include("Compiler.jl") +include("Overrides.jl") + function Enzyme.make_zero( ::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false) )::RT where {copy_if_inactive,RT<:RArray} diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index c4a839ab71..dd802c69fe 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -13,7 +13,8 @@ using ..Reactant: WrappedTracedRArray, AnyTracedRArray, MissingTracedValue, - OrderedIdDict + OrderedIdDict, + Compiler import ..Reactant import ..Reactant.MLIR import ..ReactantPrimitive From c492390daa0003d37d4b77c9f617f6c26948aeaf Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 15 Dec 2024 17:01:50 +0530 Subject: [PATCH 2/5] fix: traced_getfield move to TracedUtils --- src/Compiler.jl | 6 +----- src/ConcreteRArray.jl | 17 ----------------- src/Interpreter.jl | 2 +- src/Reactant.jl | 23 ++++++++++++++++++++--- src/TracedUtils.jl | 8 +++++++- 5 files changed, 29 insertions(+), 27 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index cc32a90b1a..8267ec5e94 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -13,11 +13,7 @@ import ..Reactant: TracedToConcrete, append_path, TracedType - -@inline traced_getfield(@nospecialize(obj), field) = Base.getfield(obj, field) -@inline traced_getfield( - @nospecialize(obj::AbstractArray{<:Union{ConcreteRNumber,ConcreteRArray}}), field -) = Base.getindex(obj, field) +import ..TracedUtils: TracedUtils, traced_getfield function create_result(tocopy::T, path, result_stores) where {T} if !isstructtype(typeof(tocopy)) diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl index e9d9c02d7f..9a072a8071 100644 --- a/src/ConcreteRArray.jl +++ b/src/ConcreteRArray.jl @@ -1,20 +1,3 @@ -struct XLAArray{T,N} <: RArray{T,N} - # size::NTuple{N,Int} -end - -mutable struct ConcreteRArray{T,N} <: RArray{T,N} - data::XLA.AsyncBuffer - # data::XLAArray{T, N} - shape::NTuple{N,Int} -end - -const WrappedConcreteRArray{T,N} = WrappedArray{T,N,ConcreteRArray,ConcreteRArray{T,N}} -const AnyConcreteRArray{T,N} = Union{ConcreteRArray{T,N},WrappedConcreteRArray{T,N}} - -mutable struct ConcreteRNumber{T} <: RNumber{T} - data::XLA.AsyncBuffer -end - function ConcreteRNumber{T}( data::T2; client=XLA.default_backend[], idx=XLA.default_device_idx[], device=nothing ) where {T<:Number,T2<:Number} diff --git a/src/Interpreter.jl b/src/Interpreter.jl index cd265c9ebc..8472121564 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -202,7 +202,7 @@ function set_act!(inp, path, reverse, tostore; emptypath=false) end for p in path - x = traced_getfield(x, p) + x = TracedUtils.traced_getfield(x, p) end #if inp isa Enzyme.Active || !reverse diff --git a/src/Reactant.jl b/src/Reactant.jl index 9830965df8..c611e49e3e 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -114,6 +114,25 @@ mutable struct TracedRNumber{T} <: RNumber{T} end end +struct XLAArray{T,N} <: RArray{T,N} + # size::NTuple{N,Int} +end + +mutable struct ConcreteRArray{T,N} <: RArray{T,N} + data::XLA.AsyncBuffer + # data::XLAArray{T, N} + shape::NTuple{N,Int} +end + +const WrappedConcreteRArray{T,N} = WrappedArray{T,N,ConcreteRArray,ConcreteRArray{T,N}} +const AnyConcreteRArray{T,N} = Union{ConcreteRArray{T,N},WrappedConcreteRArray{T,N}} + +mutable struct ConcreteRNumber{T} <: RNumber{T} + data::XLA.AsyncBuffer +end + +const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue} + include("Ops.jl") include("TracedUtils.jl") @@ -124,8 +143,6 @@ include("ConcreteRArray.jl") include("linear_algebra.jl") -const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue} - include("ControlFlow.jl") include("Tracing.jl") include("Compiler.jl") @@ -146,7 +163,7 @@ function Enzyme.make_zero( return res end -using .Compiler: @compile, @code_hlo, @jit, traced_getfield, create_result, compile +using .Compiler: @compile, @code_hlo, @jit, create_result, compile export ConcreteRArray, ConcreteRNumber, @compile, @code_hlo, @jit, @trace const registry = Ref{MLIR.IR.DialectRegistry}() diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index dd802c69fe..c5ae212055 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -14,12 +14,18 @@ using ..Reactant: AnyTracedRArray, MissingTracedValue, OrderedIdDict, - Compiler + ConcreteRArray, + ConcreteRNumber import ..Reactant import ..Reactant.MLIR import ..ReactantPrimitive import ..Ops +@inline traced_getfield(@nospecialize(obj), field) = Base.getfield(obj, field) +@inline traced_getfield( + @nospecialize(obj::AbstractArray{<:Union{ConcreteRNumber,ConcreteRArray}}), field +) = Base.getindex(obj, field) + materialize_traced_array(x::TracedRArray) = x materialize_traced_array(x::WrappedTracedRArray) = x[axes(x)...] function materialize_traced_array( From 07a9f410464b0a9efedc8c9cff3e98e0cb737175 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 15 Dec 2024 17:06:40 +0530 Subject: [PATCH 3/5] refactor: rename to overlay --- src/Interpreter.jl | 2 +- src/{Overrides.jl => Overlay.jl} | 6 +++--- src/Reactant.jl | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) rename src/{Overrides.jl => Overlay.jl} (83%) diff --git a/src/Interpreter.jl b/src/Interpreter.jl index 8472121564..7fdfc10dec 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -23,7 +23,7 @@ import Core.Compiler: Base.Experimental.@MethodTable(REACTANT_METHOD_TABLE) -function var"@reactant_override"(__source__::LineNumberNode, __module__::Module, def) +function var"@reactant_overlay"(__source__::LineNumberNode, __module__::Module, def) return Base.Experimental.var"@overlay"( __source__, __module__, :(Reactant.REACTANT_METHOD_TABLE), def ) diff --git a/src/Overrides.jl b/src/Overlay.jl similarity index 83% rename from src/Overrides.jl rename to src/Overlay.jl index 0df144d3a5..6d4752acd9 100644 --- a/src/Overrides.jl +++ b/src/Overlay.jl @@ -4,20 +4,20 @@ # we should move all the reactant_overrides to relevant files. # Compiling within a compile should return simply the original function -@reactant_override function Compiler.compile( +@reactant_overlay function Compiler.compile( f, args; client=nothing, optimize=true, sync=false ) return f end # Enzyme overrides -@reactant_override @noinline function Enzyme.autodiff_deferred( +@reactant_overlay @noinline function Enzyme.autodiff_deferred( rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs} ) where {FA<:Annotation,A<:Annotation,Nargs} return overload_autodiff(rmode, f, rt, args...) end -@reactant_override @noinline function Enzyme.autodiff( +@reactant_overlay @noinline function Enzyme.autodiff( rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs} ) where {FA<:Annotation,A<:Annotation,Nargs} return overload_autodiff(rmode, f, rt, args...) diff --git a/src/Reactant.jl b/src/Reactant.jl index c611e49e3e..5f6de5d7bf 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -147,7 +147,7 @@ include("ControlFlow.jl") include("Tracing.jl") include("Compiler.jl") -include("Overrides.jl") +include("Overlay.jl") function Enzyme.make_zero( ::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false) From b25303886ce65260749cd3adb1b38e1fe6538cb4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 15 Dec 2024 17:41:43 +0530 Subject: [PATCH 4/5] ci: increase build time allowance --- .buildkite/pipeline.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index ec261c66a7..2d9e3710fa 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -1,7 +1,7 @@ steps: - group: ":test_tube: Tests" steps: - - label: "CUDA Julia v{{matrix.version}} -- {{matrix.group}}" + - label: ":julia: :linux: CUDA Julia v{{matrix.version}} -- {{matrix.group}}" matrix: setup: version: @@ -33,7 +33,7 @@ steps: env: REACTANT_TEST_GROUP: "{{matrix.group}}" if: build.message !~ /\[skip tests\]/ - timeout_in_minutes: 60 + timeout_in_minutes: 120 - label: ":julia: :linux: aarch64 - Julia v{{matrix.version}} -- {{matrix.group}}" matrix: @@ -70,7 +70,7 @@ steps: env: REACTANT_TEST_GROUP: "{{matrix.group}}" if: build.message !~ /\[skip tests\]/ - timeout_in_minutes: 60 + timeout_in_minutes: 120 - group: ":racehorse: Benchmarks" steps: From 361c3f422d8f41a9a9036fbbb5e682721d7eabfa Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 16 Dec 2024 20:25:47 +0530 Subject: [PATCH 5/5] Revert "fix: traced_getfield move to TracedUtils" This reverts commit f35d911715c14361b99996866a81c2719d6c13c1. --- src/Compiler.jl | 6 +++++- src/ConcreteRArray.jl | 17 +++++++++++++++++ src/Interpreter.jl | 2 +- src/Reactant.jl | 23 +++-------------------- src/TracedUtils.jl | 9 +-------- 5 files changed, 27 insertions(+), 30 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index 8267ec5e94..cc32a90b1a 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -13,7 +13,11 @@ import ..Reactant: TracedToConcrete, append_path, TracedType -import ..TracedUtils: TracedUtils, traced_getfield + +@inline traced_getfield(@nospecialize(obj), field) = Base.getfield(obj, field) +@inline traced_getfield( + @nospecialize(obj::AbstractArray{<:Union{ConcreteRNumber,ConcreteRArray}}), field +) = Base.getindex(obj, field) function create_result(tocopy::T, path, result_stores) where {T} if !isstructtype(typeof(tocopy)) diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl index 9a072a8071..e9d9c02d7f 100644 --- a/src/ConcreteRArray.jl +++ b/src/ConcreteRArray.jl @@ -1,3 +1,20 @@ +struct XLAArray{T,N} <: RArray{T,N} + # size::NTuple{N,Int} +end + +mutable struct ConcreteRArray{T,N} <: RArray{T,N} + data::XLA.AsyncBuffer + # data::XLAArray{T, N} + shape::NTuple{N,Int} +end + +const WrappedConcreteRArray{T,N} = WrappedArray{T,N,ConcreteRArray,ConcreteRArray{T,N}} +const AnyConcreteRArray{T,N} = Union{ConcreteRArray{T,N},WrappedConcreteRArray{T,N}} + +mutable struct ConcreteRNumber{T} <: RNumber{T} + data::XLA.AsyncBuffer +end + function ConcreteRNumber{T}( data::T2; client=XLA.default_backend[], idx=XLA.default_device_idx[], device=nothing ) where {T<:Number,T2<:Number} diff --git a/src/Interpreter.jl b/src/Interpreter.jl index 7fdfc10dec..4b71a13413 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -202,7 +202,7 @@ function set_act!(inp, path, reverse, tostore; emptypath=false) end for p in path - x = TracedUtils.traced_getfield(x, p) + x = traced_getfield(x, p) end #if inp isa Enzyme.Active || !reverse diff --git a/src/Reactant.jl b/src/Reactant.jl index 5f6de5d7bf..e7c8805de9 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -114,25 +114,6 @@ mutable struct TracedRNumber{T} <: RNumber{T} end end -struct XLAArray{T,N} <: RArray{T,N} - # size::NTuple{N,Int} -end - -mutable struct ConcreteRArray{T,N} <: RArray{T,N} - data::XLA.AsyncBuffer - # data::XLAArray{T, N} - shape::NTuple{N,Int} -end - -const WrappedConcreteRArray{T,N} = WrappedArray{T,N,ConcreteRArray,ConcreteRArray{T,N}} -const AnyConcreteRArray{T,N} = Union{ConcreteRArray{T,N},WrappedConcreteRArray{T,N}} - -mutable struct ConcreteRNumber{T} <: RNumber{T} - data::XLA.AsyncBuffer -end - -const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue} - include("Ops.jl") include("TracedUtils.jl") @@ -143,6 +124,8 @@ include("ConcreteRArray.jl") include("linear_algebra.jl") +const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue} + include("ControlFlow.jl") include("Tracing.jl") include("Compiler.jl") @@ -163,7 +146,7 @@ function Enzyme.make_zero( return res end -using .Compiler: @compile, @code_hlo, @jit, create_result, compile +using .Compiler: @compile, @code_hlo, @jit, traced_getfield, create_result, compile export ConcreteRArray, ConcreteRNumber, @compile, @code_hlo, @jit, @trace const registry = Ref{MLIR.IR.DialectRegistry}() diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index c5ae212055..c4a839ab71 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -13,19 +13,12 @@ using ..Reactant: WrappedTracedRArray, AnyTracedRArray, MissingTracedValue, - OrderedIdDict, - ConcreteRArray, - ConcreteRNumber + OrderedIdDict import ..Reactant import ..Reactant.MLIR import ..ReactantPrimitive import ..Ops -@inline traced_getfield(@nospecialize(obj), field) = Base.getfield(obj, field) -@inline traced_getfield( - @nospecialize(obj::AbstractArray{<:Union{ConcreteRNumber,ConcreteRArray}}), field -) = Base.getindex(obj, field) - materialize_traced_array(x::TracedRArray) = x materialize_traced_array(x::WrappedTracedRArray) = x[axes(x)...] function materialize_traced_array(