Skip to content

Commit e2d7377

Browse files
committed
Revert "fix: traced_getfield move to TracedUtils"
This reverts commit f35d911.
1 parent d9c167f commit e2d7377

File tree

5 files changed

+29
-31
lines changed

5 files changed

+29
-31
lines changed

src/Compiler.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,11 @@ import ..Reactant:
1313
TracedToConcrete,
1414
append_path,
1515
TracedType
16-
import ..TracedUtils: TracedUtils, traced_getfield
16+
17+
@inline traced_getfield(@nospecialize(obj), field) = Base.getfield(obj, field)
18+
@inline traced_getfield(
19+
@nospecialize(obj::AbstractArray{<:Union{ConcreteRNumber,ConcreteRArray}}), field
20+
) = Base.getindex(obj, field)
1721

1822
function create_result(tocopy::T, path, result_stores) where {T}
1923
if !isstructtype(typeof(tocopy))

src/ConcreteRArray.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,20 @@
1+
struct XLAArray{T,N} <: RArray{T,N}
2+
# size::NTuple{N,Int}
3+
end
4+
5+
mutable struct ConcreteRArray{T,N} <: RArray{T,N}
6+
data::XLA.AsyncBuffer
7+
# data::XLAArray{T, N}
8+
shape::NTuple{N,Int}
9+
end
10+
11+
const WrappedConcreteRArray{T,N} = WrappedArray{T,N,ConcreteRArray,ConcreteRArray{T,N}}
12+
const AnyConcreteRArray{T,N} = Union{ConcreteRArray{T,N},WrappedConcreteRArray{T,N}}
13+
14+
mutable struct ConcreteRNumber{T} <: RNumber{T}
15+
data::XLA.AsyncBuffer
16+
end
17+
118
function ConcreteRNumber{T}(
219
data::T2; client=XLA.default_backend[], idx=XLA.default_device_idx[], device=nothing
320
) where {T<:Number,T2<:Number}

src/Interpreter.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ function set_act!(inp, path, reverse, tostore; emptypath=false)
202202
end
203203

204204
for p in path
205-
x = TracedUtils.traced_getfield(x, p)
205+
x = traced_getfield(x, p)
206206
end
207207

208208
#if inp isa Enzyme.Active || !reverse

src/Reactant.jl

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -114,25 +114,6 @@ mutable struct TracedRNumber{T} <: RNumber{T}
114114
end
115115
end
116116

117-
struct XLAArray{T,N} <: RArray{T,N}
118-
# size::NTuple{N,Int}
119-
end
120-
121-
mutable struct ConcreteRArray{T,N} <: RArray{T,N}
122-
data::XLA.AsyncBuffer
123-
# data::XLAArray{T, N}
124-
shape::NTuple{N,Int}
125-
end
126-
127-
const WrappedConcreteRArray{T,N} = WrappedArray{T,N,ConcreteRArray,ConcreteRArray{T,N}}
128-
const AnyConcreteRArray{T,N} = Union{ConcreteRArray{T,N},WrappedConcreteRArray{T,N}}
129-
130-
mutable struct ConcreteRNumber{T} <: RNumber{T}
131-
data::XLA.AsyncBuffer
132-
end
133-
134-
const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue}
135-
136117
include("Ops.jl")
137118
include("TracedUtils.jl")
138119

@@ -143,6 +124,8 @@ include("ConcreteRArray.jl")
143124

144125
include("linear_algebra.jl")
145126

127+
const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue}
128+
146129
include("ControlFlow.jl")
147130
include("Tracing.jl")
148131
include("Compiler.jl")
@@ -163,7 +146,7 @@ function Enzyme.make_zero(
163146
return res
164147
end
165148

166-
using .Compiler: @compile, @code_hlo, @jit, create_result, compile
149+
using .Compiler: @compile, @code_hlo, @jit, traced_getfield, create_result, compile
167150
export ConcreteRArray, ConcreteRNumber, @compile, @code_hlo, @jit, @trace
168151

169152
const registry = Ref{MLIR.IR.DialectRegistry}()

src/TracedUtils.jl

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,12 @@ using ..Reactant:
1414
AnyTracedRArray,
1515
MissingTracedValue,
1616
OrderedIdDict,
17-
ConcreteRArray,
18-
ConcreteRNumber
17+
Compiler
1918
import ..Reactant
2019
import ..Reactant.MLIR
2120
import ..ReactantPrimitive
2221
import ..Ops
2322

24-
@inline traced_getfield(@nospecialize(obj), field) = Base.getfield(obj, field)
25-
@inline traced_getfield(
26-
@nospecialize(obj::AbstractArray{<:Union{ConcreteRNumber,ConcreteRArray}}), field
27-
) = Base.getindex(obj, field)
28-
2923
materialize_traced_array(x::TracedRArray) = x
3024
materialize_traced_array(x::WrappedTracedRArray) = x[axes(x)...]
3125
function materialize_traced_array(
@@ -330,7 +324,7 @@ end
330324

331325
function push_val!(ad_inputs, x, path)
332326
for p in path
333-
x = traced_getfield(x, p)
327+
x = Compiler.traced_getfield(x, p)
334328
end
335329
x = x.mlir_data
336330
return push!(ad_inputs, x)
@@ -350,7 +344,7 @@ end
350344

351345
function set!(x, path, tostore; emptypath=false)
352346
for p in path
353-
x = traced_getfield(x, p)
347+
x = Compiler.traced_getfield(x, p)
354348
end
355349

356350
x.mlir_data = tostore

0 commit comments

Comments
 (0)