320320end
321321
322322struct GPUInterpreter <: CC.AbstractInterpreter
323+ meta:: Any
323324 world:: UInt
324325 method_table:: GPUMethodTableView
325326
336337
337338@static if HAS_INTEGRATED_CACHE
338339function GPUInterpreter (world:: UInt = Base. get_world_counter ();
340+ meta = nothing ,
339341 method_table:: MTType ,
340342 token:: Any ,
341343 inf_params:: CC.InferenceParams ,
@@ -345,26 +347,28 @@ function GPUInterpreter(world::UInt=Base.get_world_counter();
345347 method_table = get_method_table_view (world, method_table)
346348 inf_cache = Vector {CC.InferenceResult} ()
347349
348- return GPUInterpreter (world, method_table,
350+ return GPUInterpreter (meta, world, method_table,
349351 token, inf_cache,
350352 inf_params, opt_params)
351353end
352354
# Copy constructor for the integrated-cache build: each keyword defaults to the
# matching field of `interp`, so callers only spell out the fields they replace.
function GPUInterpreter(interp::GPUInterpreter;
                        meta = interp.meta,
                        world::UInt = interp.world,
                        method_table::GPUMethodTableView = interp.method_table,
                        token::Any = interp.token,
                        inf_cache::Vector{CC.InferenceResult} = interp.inf_cache,
                        inf_params::CC.InferenceParams = interp.inf_params,
                        opt_params::CC.OptimizationParams = interp.opt_params)
    # positional order follows the struct's field order, with `meta` first
    return GPUInterpreter(meta, world, method_table, token, inf_cache, inf_params, opt_params)
end
364367
365368else
366369
367370function GPUInterpreter (world:: UInt = Base. get_world_counter ();
371+ meta= nothing ,
368372 method_table:: MTType ,
369373 code_cache:: CodeCache ,
370374 inf_params:: CC.InferenceParams ,
@@ -374,19 +378,20 @@ function GPUInterpreter(world::UInt=Base.get_world_counter();
374378 method_table = get_method_table_view (world, method_table)
375379 inf_cache = Vector {CC.InferenceResult} ()
376380
377- return GPUInterpreter (world, method_table,
381+ return GPUInterpreter (meta, world, method_table,
378382 code_cache, inf_cache,
379383 inf_params, opt_params)
380384end
381385
# Copy constructor for the build without an integrated cache: each keyword defaults
# to the matching field of `interp`, so callers only spell out the fields they replace.
function GPUInterpreter(interp::GPUInterpreter;
                        meta = interp.meta,
                        world::UInt = interp.world,
                        method_table::GPUMethodTableView = interp.method_table,
                        code_cache::CodeCache = interp.code_cache,
                        inf_cache::Vector{CC.InferenceResult} = interp.inf_cache,
                        inf_params::CC.InferenceParams = interp.inf_params,
                        opt_params::CC.OptimizationParams = interp.opt_params)
    # positional order follows the struct's field order, with `meta` first
    return GPUInterpreter(meta, world, method_table, code_cache, inf_cache, inf_params,
                          opt_params)
end
@@ -437,28 +442,76 @@ function CC.concrete_eval_eligible(interp::GPUInterpreter,
437442end
438443
439444
# Returns `false` when executed normally. `abstract_call_known` for `GPUInterpreter`
# special-cases a zero-argument call to this function and infers it as `Core.Const(true)`.
function within_gpucompiler()
    return false
end
446+
440447# # deferred compilation
441448
# Call info attached to a recognized `gpuc.deferred` call.
struct DeferredCallInfo <: CC.CallInfo
    meta::Any          # constant meta value the deferred call was issued with
    rt::DataType       # return type inferred for the deferred callee
    info::CC.CallInfo  # call info produced by inferring the deferred callee
end
446454
447455# recognize calls to gpuc.deferred and save DeferredCallInfo metadata
# default implementation, extensible through meta argument.
# XXX: (or should we dispatch on `f`)?
#
# Returns a `CC.CallMeta` for the callees special-cased below, and `nothing` for any
# other `f` so that the caller can continue with its regular handling.
function abstract_call_known(meta::Nothing, interp::GPUInterpreter, @nospecialize(f),
                             arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
                             max_methods::Int = CC.get_max_methods(interp, f, sv))
    argtypes = arginfo.argtypes
    if f === var"gpuc.deferred"
        # argtypes[2] carries the meta and argtypes[3:end] the deferred callee plus its
        # arguments; reject malformed calls instead of indexing out of bounds
        if length(argtypes) < 3
            add_remark!(interp, sv, "Skipped gpuc.deferred since too few arguments were passed")
            return _bottom_callmeta()
        end

        argvec = argtypes[3:end]
        call = CC.abstract_call(interp, CC.ArgInfo(nothing, argvec), si, sv, max_methods)

        # the meta has to be known at inference time: either a singleton type or a constant
        metaT = argtypes[2]
        deferred_meta = CC.singleton_type(metaT)
        if deferred_meta === nothing
            if metaT isa Core.Const
                deferred_meta = metaT.val
            else
                # meta is neither a singleton type nor a constant, so the result may
                # depend on runtime configuration
                add_remark!(interp, sv, "Skipped gpuc.deferred since meta not constant")
                return _bottom_callmeta()
            end
        end

        callinfo = DeferredCallInfo(deferred_meta, call.rt, call.info)
        @static if VERSION < v"1.11.0-"
            return CC.CallMeta(Ptr{Cvoid}, CC.Effects(), callinfo)
        else
            return CC.CallMeta(Ptr{Cvoid}, Union{}, CC.Effects(), callinfo)
        end
    elseif f === within_gpucompiler
        # only the zero-argument form is modelled (argtypes then holds just the callee)
        if length(argtypes) != 1
            return _bottom_callmeta()
        end
        @static if VERSION < v"1.11.0-"
            return CC.CallMeta(Core.Const(true), CC.EFFECTS_TOTAL, CC.MethodResultPure())
        else
            return CC.CallMeta(Core.Const(true), Union{}, CC.EFFECTS_TOTAL,
                               CC.MethodResultPure())
        end
    end
    return nothing
end

# `CallMeta` with return type `Union{}`, used for calls that are not modelled.
# The four-argument `CC.CallMeta` form (with an extra `Union{}`) is used from 1.11 on.
function _bottom_callmeta()
    @static if VERSION < v"1.11.0-"
        return CC.CallMeta(Union{}, CC.Effects(), CC.NoCallInfo())
    else
        return CC.CallMeta(Union{}, Union{}, CC.Effects(), CC.NoCallInfo())
    end
end
503+
# `Core.Compiler` hook for calls with a known callee `f`.
#
# Resolution order:
#   1. the handler selected by `interp.meta`;
#   2. if that declined (returned `nothing`) and a meta is set, the default handler
#      registered for `meta === nothing`;
#   3. otherwise the generic `CC.AbstractInterpreter` implementation.
function CC.abstract_call_known(interp::GPUInterpreter, @nospecialize(f),
                                arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
                                max_methods::Int = CC.get_max_methods(interp, f, sv))
    candidate = abstract_call_known(interp.meta, interp, f, arginfo, si, sv, max_methods)
    if candidate === nothing && interp.meta !== nothing
        # fall back to the default handler; calling with `interp.meta` again would only
        # repeat the attempt that just declined
        candidate = abstract_call_known(nothing, interp, f, arginfo, si, sv, max_methods)
    end
    if candidate !== nothing
        return candidate
    end

    return @invoke CC.abstract_call_known(interp::CC.AbstractInterpreter, f,
        arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
        max_methods::Int)
end
@@ -485,32 +538,39 @@ function CC.handle_call!(todo::Vector{Pair{Int,Any}}, ir::CC.IRCode, idx::CC.Int
485538 args = Any[
486539 " extern gpuc.lookup" ,
487540 Ptr{Cvoid},
488- Core. svec (Any, Any, match. spec_types. parameters[2 : end ]. .. ), # Must use Any for MethodInstance or ftype
541+ Core. svec (Any, Any, Any, match. spec_types. parameters[2 : end ]. .. ), # Must use Any for MethodInstance or ftype
489542 0 ,
490543 QuoteNode (:llvmcall ),
544+ info. meta,
491545 case. invoke,
492- stmt. args[2 : end ]. ..
546+ stmt. args[3 : end ]. ..
493547 ]
494548 stmt. head = :foreigncall
495549 stmt. args = args
496550 return nothing
497551end
498552
# A deferred compilation request: the method instance to compile together with the
# meta value taken from the `gpuc.lookup` foreigncall it was found in.
struct Edge
    meta::Any
    mi::MethodInstance
end
557+
# Wrapper around the list of deferred edges collected from a piece of IR.
struct DeferredEdges
    edges::Vector{Edge}
end
502561
503562function find_deferred_edges (ir:: CC.IRCode )
504- edges = MethodInstance []
563+ edges = Edge []
505564 # XXX : can we add this instead in handle_call?
506565 for stmt in ir. stmts
507566 inst = stmt[:inst ]
508567 inst isa Expr || continue
509568 expr = inst:: Expr
510569 if expr. head === :foreigncall &&
511570 expr. args[1 ] == " extern gpuc.lookup"
512- deferred_mi = expr. args[6 ]
513- push! (edges, deferred_mi)
571+ deferred_meta = expr. args[6 ]
572+ deferred_mi = expr. args[7 ]
573+ push! (edges, Edge (deferred_meta, deferred_mi))
514574 end
515575 end
516576 unique! (edges)
@@ -542,6 +602,116 @@ function CC.finish(interp::GPUInterpreter, opt::CC.OptimizationState, ir::CC.IRC
542602end
543603end
544604
605+ import . CC: CallInfo
# Call info wrapper that marks a call as "do not inline"; the inlining policy overload
# for `GPUInterpreter` refuses to inline any call carrying it.
struct NoInlineCallInfo <: CallInfo
    info::CallInfo  # wrapped call
    tt::Any         # ::Type
    kind::Symbol

    function NoInlineCallInfo(@nospecialize(info::CallInfo), @nospecialize(tt), kind::Symbol)
        return new(info, tt, kind)
    end
end
613+
# Split queries on a `NoInlineCallInfo` are answered by the wrapped call info.
function CC.nsplit_impl(info::NoInlineCallInfo)
    return CC.nsplit(info.info)
end

function CC.getsplit_impl(info::NoInlineCallInfo, idx::Int)
    return CC.getsplit(info.info, idx)
end

function CC.getresult_impl(info::NoInlineCallInfo, idx::Int)
    return CC.getresult(info.info, idx)
end
# Call info wrapper that marks a call as "always inline"; the inlining policy overload
# for `GPUInterpreter` accepts any call carrying it.
struct AlwaysInlineCallInfo <: CallInfo
    info::CallInfo  # wrapped call
    tt::Any         # ::Type

    function AlwaysInlineCallInfo(@nospecialize(info::CallInfo), @nospecialize(tt))
        return new(info, tt)
    end
end
622+
# Split queries on an `AlwaysInlineCallInfo` are answered by the wrapped call info.
# All three go through the `CC` alias, matching the `NoInlineCallInfo` methods.
CC.nsplit_impl(info::AlwaysInlineCallInfo) = CC.nsplit(info.info)
CC.getsplit_impl(info::AlwaysInlineCallInfo, idx::Int) = CC.getsplit(info.info, idx)
CC.getresult_impl(info::AlwaysInlineCallInfo, idx::Int) = CC.getresult(info.info, idx)
626+
627+
# Default inlining handler, selected when `meta` is `nothing`. Returning `nothing`
# means "no override": `abstract_call_gf_by_type` then keeps the inferred call info.
inlining_handler(meta::Nothing, interp::GPUInterpreter, @nospecialize(atype), callinfo) =
    nothing
631+
632+ using Core. Compiler: ArgInfo, StmtInfo, AbsIntState
# Runs the generic inference for a generic-function call, then lets the inlining
# handlers replace the resulting call info. The handler for `interp.meta` is asked
# first (when a meta is set), then the default handler; if both return `nothing` the
# inferred call info is kept unchanged.
function CC.abstract_call_gf_by_type(interp::GPUInterpreter, @nospecialize(f), arginfo::ArgInfo,
                                     si::StmtInfo, @nospecialize(atype), sv::AbsIntState,
                                     max_methods::Int)
    inferred = @invoke CC.abstract_call_gf_by_type(interp::CC.AbstractInterpreter, f::Any,
        arginfo::ArgInfo, si::StmtInfo, atype::Any, sv::AbsIntState, max_methods::Int)

    meta = interp.meta
    override = meta === nothing ? nothing :
               inlining_handler(meta, interp, atype, inferred.info)
    if override === nothing
        override = inlining_handler(nothing, interp, atype, inferred.info)
    end
    info = override === nothing ? inferred.info : override

    # `CallMeta` carries the exception type as an extra argument from 1.11 on
    @static if VERSION < v"1.11-"
        return CC.CallMeta(inferred.rt, inferred.effects, info)
    else
        return CC.CallMeta(inferred.rt, inferred.exct, inferred.effects, info)
    end
end
655+
# Inlining policy overloads honoring `NoInlineCallInfo` / `AlwaysInlineCallInfo`.
# Before 1.12.0-DEV.45 the hook is `CC.inlining_policy`, from then on it is
# `CC.src_inlining_policy`.
@static if VERSION < v"1.12.0-DEV.45"
let # overload `inlining_policy`
    # choose the argument list matching the running Julia version: `method_sig` is
    # spliced into the method definition, `invoke_sig` into the `@invoke` fallback
    @static if VERSION < v"1.11.0-DEV.879"
        method_sig = :(
            interp::GPUInterpreter,
            @nospecialize(src),
            @nospecialize(info::CC.CallInfo),
            stmt_flag::UInt8,
            mi::MethodInstance,
            argtypes::Vector{Any},
        )
        invoke_sig = :(
            interp::CC.AbstractInterpreter,
            src::Any,
            info::CC.CallInfo,
            stmt_flag::UInt8,
            mi::MethodInstance,
            argtypes::Vector{Any},
        )
    else
        method_sig = :(
            interp::GPUInterpreter,
            @nospecialize(src),
            @nospecialize(info::CC.CallInfo),
            stmt_flag::UInt32,
        )
        invoke_sig = :(
            interp::CC.AbstractInterpreter,
            src::Any,
            info::CC.CallInfo,
            stmt_flag::UInt32,
        )
    end
    @eval function CC.inlining_policy($(method_sig.args...))
        # `nothing` rejects inlining, returning `src` accepts it
        if info isa NoInlineCallInfo
            @safe_debug "Blocking inlining" info.tt info.kind
            return nothing
        end
        if info isa AlwaysInlineCallInfo
            @safe_debug "Forcing inlining for" info.tt
            return src
        end
        return @invoke CC.inlining_policy($(invoke_sig.args...))
    end
end
else
function CC.src_inlining_policy(interp::GPUInterpreter,
        @nospecialize(src), @nospecialize(info::CC.CallInfo), stmt_flag::UInt32)
    # `false` rejects inlining, `true` accepts it
    if info isa NoInlineCallInfo
        @safe_debug "Blocking inlining" info.tt info.kind
        return false
    end
    if info isa AlwaysInlineCallInfo
        @safe_debug "Forcing inlining for" info.tt
        return true
    end
    return @invoke CC.src_inlining_policy(interp::CC.AbstractInterpreter, src,
        info::CC.CallInfo, stmt_flag::UInt32)
end
end
714+
545715
546716# # world view of the cache
547717using Core. Compiler: WorldView
@@ -704,7 +874,7 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
704874 source = pop! (worklist)
705875 haskey (compiled, source) && continue # We have fulfilled the request already
706876 # Create a new compiler job for this edge, reusing the config settings from the initial one
707- job2 = CompilerJob (source, job. config)
877+ job2 = CompilerJob (source, job. config) # TODO : GPUInterpreter.meta in config?
708878 llvm_mod2, outstanding = compile_method_instance (job2, compiled)
709879 append! (worklist, outstanding) # merge worklist with new outstanding edges
710880 @assert context (llvm_mod) == context (llvm_mod2)
@@ -844,7 +1014,9 @@ function compile_method_instance(@nospecialize(job::CompilerJob), compiled::IdDi
8441014 end
8451015 end
8461016 if edges != = nothing
847- for deferred_mi in (edges:: DeferredEdges ). edges
1017+ for edge in (edges:: DeferredEdges ). edges
1018+ # TODO
1019+ deferred_mi = edge. mi
8481020 if ! haskey (compiled, deferred_mi)
8491021 push! (outstanding, deferred_mi)
8501022 end
0 commit comments