@@ -36,7 +36,7 @@ struct NeverInlineMeta <: InlineStateMeta end
3636import GPUCompiler: abstract_call_known, GPUInterpreter
3737import Core. Compiler: CallMeta, Effects, NoCallInfo, ArgInfo,
3838 StmtInfo, AbsIntState, EFFECTS_TOTAL,
39- MethodResultPure
39+ MethodResultPure, CallInfo, IRCode
4040
4141function abstract_call_known (meta:: InlineStateMeta , interp:: GPUInterpreter , @nospecialize (f),
4242 arginfo:: ArgInfo , si:: StmtInfo , sv:: AbsIntState , max_methods:: Int )
@@ -69,5 +69,143 @@ function inlining_handler(meta::InlineStateMeta, interp::GPUInterpreter, @nospec
6969 return nothing
7070end
7171
72+ struct MockEnzymeMeta end
7273
73- end
74+ # Having to define this function is annoying
75+ # introduce `abstract type InferenceMeta`
76+ function inlining_handler (meta:: MockEnzymeMeta , interp:: GPUInterpreter , @nospecialize (atype), callinfo)
77+ return nothing
78+ end
79+
80+ function autodiff end
81+
82+ import GPUCompiler: DeferredCallInfo
83+ struct AutodiffCallInfo <: CallInfo
84+ rt
85+ info:: DeferredCallInfo
86+ end
87+
88+ function abstract_call_known (meta:: MockEnzymeMeta , interp:: GPUInterpreter , @nospecialize (f),
89+ arginfo:: ArgInfo , si:: StmtInfo , sv:: AbsIntState , max_methods:: Int )
90+ (; fargs, argtypes) = arginfo
91+
92+ if f === autodiff
93+ if length (argtypes) <= 1
94+ @static if VERSION < v " 1.11.0-"
95+ return CallMeta (Union{}, Effects (), NoCallInfo ())
96+ else
97+ return CallMeta (Union{}, Union{}, Effects (), NoCallInfo ())
98+ end
99+ end
100+
101+ other_fargs = fargs === nothing ? nothing : fargs[2 : end ]
102+ other_arginfo = ArgInfo (other_fargs, argtypes[2 : end ])
103+ call = Core. Compiler. abstract_call (interp, other_arginfo, si, sv, max_methods)
104+ callinfo = DeferredCallInfo (MockEnzymeMeta (), call. rt, call. info)
105+
106+ # Real Enzyme must compute `rt` and `exct` according to enzyme semantics
107+ # and likely perform a unwrapping of fargs...
108+ rt = call. rt
109+
110+ # TODO : Edges? Effects?
111+ @static if VERSION < v " 1.11.0-"
112+ # Can't use call.effects since otherwise this call might be just replaced with rt
113+ return CallMeta (rt, Effects (), AutodiffCallInfo (rt, callinfo))
114+ else
115+ return CallMeta (rt, call. exct, Effects (), AutodiffCallInfo (rt, callinfo))
116+ end
117+ end
118+
119+ return nothing
120+ end
121+
122+ import Core. Compiler: insert_node!, NewInstruction, ReturnNode, Instruction, InliningState, Signature
123+
124+ # We really need a Compiler stdlib
125+ Base. getindex (ir:: IRCode , i) = Core. Compiler. getindex (ir, i)
126+ Base. setindex! (inst:: Instruction , val, i) = Core. Compiler. setindex! (inst, val, i)
127+
128+ const FlagType = VERSION >= v " 1.11.0-" ? UInt32 : UInt8
129+ function Core. Compiler. handle_call! (todo:: Vector{Pair{Int,Any}} , ir:: IRCode , stmt_idx:: Int ,
130+ stmt:: Expr , info:: AutodiffCallInfo , flag:: FlagType ,
131+ sig:: Signature , state:: InliningState )
132+
133+ # Goal:
134+ # The IR we want to inline here is:
135+ # unpack the args ..
136+ # ptr = gpuc.deferred(MockEnzymeMeta(), f, primal_args...)
137+ # ret = ccall("extern __autodiff", llvmcall, RT, Tuple{Ptr{Cvoid, args...}}, ptr, adjoint_args...)
138+
139+ # 0. Obtain primal mi from DeferredCallInfo
140+ # TODO : remove this code duplication
141+ deferred_info = info. info
142+ minfo = deferred_info. info
143+ results = minfo. results
144+ if length (results. matches) != 1
145+ return nothing
146+ end
147+ match = only (results. matches)
148+
149+ # lookup the target mi with correct edge tracking
150+ # TODO : Effects?
151+ case = Core. Compiler. compileable_specialization (
152+ match, Core. Compiler. Effects (), Core. Compiler. InliningEdgeTracker (state), info)
153+ @assert case isa Core. Compiler. InvokeCase
154+ @assert stmt. head === :call
155+
156+ # Now create the IR we want to inline
157+ ir = Core. Compiler. IRCode () # contains a placeholder
158+ args = [Core. Compiler. Argument (i) for i in 2 : length (stmt. args)] # f, args...
159+ idx = 0
160+
161+ # 0. Enzyme proper: Desugar args
162+ primal_args = args
163+ primal_argtypes = match. spec_types. parameters[2 : end ]
164+
165+ adjoint_rt = info. rt
166+ adjoint_args = args # TODO
167+ adjoint_argtypes = primal_argtypes
168+
169+ # 1: Since Julia's inliner goes bottom up we need to pretend that we inlined the deferred call
170+ expr = Expr (:foreigncall ,
171+ " extern gpuc.lookup" ,
172+ Ptr{Cvoid},
173+ Core. svec (#= meta=# Any, #= mi=# Any, #= f=# Any, primal_argtypes... ), # Must use Any for MethodInstance or ftype
174+ 0 ,
175+ QuoteNode (:llvmcall ),
176+ deferred_info. meta,
177+ case. invoke,
178+ primal_args...
179+ )
180+ ptr = insert_node! (ir, (idx += 1 ), NewInstruction (expr, Ptr{Cvoid}))
181+
182+ # 2. Call to magic `__autodiff`
183+ expr = Expr (:foreigncall ,
184+ " extern __autodiff" ,
185+ adjoint_rt,
186+ Core. svec (Any, Ptr{Cvoid}, adjoint_argtypes... ),
187+ 0 ,
188+ QuoteNode (:llvmcall ),
189+ ptr,
190+ adjoint_args...
191+ )
192+ ret = insert_node! (ir, idx, NewInstruction (expr, adjoint_rt))
193+
194+ # Finally replace placeholder return
195+ ir[Core. SSAValue (1 )][:inst ] = Core. ReturnNode (ret)
196+ ir[Core. SSAValue (1 )][:type ] = Ptr{Cvoid}
197+
198+ ir = Core. Compiler. compact! (ir)
199+
200+ # which mi to use here?
201+ # push inlining todos
202+ # TODO : Effects
203+ # aviatesk mentioned using inlining_policy instead...
204+ itodo = Core. Compiler. InliningTodo (case. invoke, ir, Core. Compiler. Effects ())
205+ @assert itodo. linear_inline_eligible
206+ push! (todo, (stmt_idx=> itodo))
207+
208+ return nothing
209+ end
210+
211+ end # module
0 commit comments