@@ -2,16 +2,10 @@ module ReactantNNlibExt
22
33using NNlib
44using GPUArraysCore: @allowscalar
5- using Reactant:
6- Reactant,
7- Ops,
8- TracedRArray,
9- AnyTracedRArray,
10- materialize_traced_array,
11- MLIR,
12- TracedRNumber,
13- get_mlir_data,
14- set_mlir_data!
5+ using Reactant: Reactant, Ops, TracedRArray, AnyTracedRArray, MLIR, TracedRNumber
6+
7+ using Reactant. TracedUtils: materialize_traced_array, get_mlir_data, set_mlir_data!
8+
159using ReactantCore: @trace
1610using LinearAlgebra: LinearAlgebra, triu
1711
@@ -238,9 +232,9 @@ function NNlib.batched_mul!(
238232 if size (x, 3 ) != size (y, 3 )
239233 B = max (size (x, 3 ), size (y, 3 ))
240234 if size (x, 3 ) == 1
241- x = Reactant. broadcast_to_size (x, (size (x, 1 ), size (x, 2 ), B))
235+ x = Reactant. TracedUtils . broadcast_to_size (x, (size (x, 1 ), size (x, 2 ), B))
242236 elseif size (y, 3 ) == 1
243- y = Reactant. broadcast_to_size (y, (size (y, 1 ), size (y, 2 ), B))
237+ y = Reactant. TracedUtils . broadcast_to_size (y, (size (y, 1 ), size (y, 2 ), B))
244238 end
245239 end
246240
@@ -250,9 +244,9 @@ function NNlib.batched_mul!(
250244 if size (x, 1 ) != size (y, 1 )
251245 B = max (size (x, 1 ), size (y, 1 ))
252246 if size (x, 1 ) == 1
253- x = Reactant. broadcast_to_size (x, (B, size (x, 2 ), size (x, 3 )))
247+ x = Reactant. TracedUtils . broadcast_to_size (x, (B, size (x, 2 ), size (x, 3 )))
254248 elseif size (y, 1 ) == 1
255- y = Reactant. broadcast_to_size (y, (B, size (y, 2 ), size (y, 3 )))
249+ y = Reactant. TracedUtils . broadcast_to_size (y, (B, size (y, 2 ), size (y, 3 )))
256250 end
257251 end
258252
270264function NNlib. pad_constant (
271265 x:: AnyTracedRArray{T,N} , pad:: NTuple{N,Tuple{Int,Int}} , value
272266) where {T,N}
273- value = Reactant. promote_to (TracedRNumber{T}, value)
267+ value = Reactant. TracedUtils . promote_to (TracedRNumber{T}, value)
274268 low = [i[1 ] for i in pad]
275269 high = [i[2 ] for i in pad]
276270 interior = [0 for i in pad]
@@ -329,7 +323,8 @@ function NNlib.gather!(dst::TracedRArray, src::AnyTracedRArray, idxs::AbstractAr
329323 start_sizes = ntuple (i -> size (src, i), dims)
330324 results = map (CartesianIndices (idxs)) do k
331325 res = @allowscalar src[colons... , Tuple (idxs[k])... ]
332- res isa TracedRNumber && (res = Reactant. broadcast_to_size (res, (1 ,)))
326+ res isa TracedRNumber &&
327+ (res = Reactant. TracedUtils. broadcast_to_size (res, (1 ,)))
333328 return reshape (res, start_sizes... , :)
334329 end
335330 res = reshape (cat (results... ; dims= (dims + 1 )), size (dst))
0 commit comments