Skip to content

Commit 2253843

Browse files
authored
[GPUCodegen] Re-infer deleted methods if needed (#35655)
1 parent b72e191 commit 2253843

File tree

6 files changed

+37
-9
lines changed

6 files changed

+37
-9
lines changed

src/aotcompile.cpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -250,9 +250,12 @@ static void makeSafeName(GlobalObject &G)
250250

251251

252252
// takes the running content that has collected in the shadow module and dump it to disk
253-
// this builds the object file portion of the sysimage files for fast startup
253+
// this builds the object file portion of the sysimage files for fast startup, and can
254+
// also be used be extern consumers like GPUCompiler.jl to obtain a module containing
255+
// all reachable & inferrrable functions. The `policy` flag switches between the defaul
256+
// mode `0` and the extern mode `1`.
254257
extern "C" JL_DLLEXPORT
255-
void *jl_create_native(jl_array_t *methods, const jl_cgparams_t cgparams)
258+
void *jl_create_native(jl_array_t *methods, const jl_cgparams_t cgparams, int _policy)
256259
{
257260
jl_native_code_desc_t *data = new jl_native_code_desc_t;
258261
jl_codegen_params_t params;
@@ -263,12 +266,17 @@ void *jl_create_native(jl_array_t *methods, const jl_cgparams_t cgparams)
263266
JL_GC_PUSH1(&src);
264267
JL_LOCK(&codegen_lock);
265268

269+
CompilationPolicy policy = (CompilationPolicy) _policy;
270+
266271
// compile all methods for the current world and type-inference world
267272
size_t compile_for[] = { jl_typeinf_world, jl_world_counter };
268273
for (int worlds = 0; worlds < 2; worlds++) {
269274
params.world = compile_for[worlds];
270275
if (!params.world)
271276
continue;
277+
// Don't emit methods for the typeinf_world with extern policy
278+
if (policy == CompilationPolicy::Extern && params.world == jl_typeinf_world)
279+
continue;
272280
size_t i, l;
273281
for (i = 0, l = jl_array_len(methods); i < l; i++) {
274282
// each item in this list is either a MethodInstance indicating something
@@ -311,8 +319,9 @@ void *jl_create_native(jl_array_t *methods, const jl_cgparams_t cgparams)
311319
}
312320
}
313321
}
322+
314323
// finally, make sure all referenced methods also get compiled or fixed up
315-
jl_compile_workqueue(emitted, params);
324+
jl_compile_workqueue(emitted, params, policy);
316325
}
317326
JL_GC_POP();
318327

src/codegen.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6534,8 +6534,10 @@ jl_compile_result_t jl_emit_codeinst(
65346534

65356535
void jl_compile_workqueue(
65366536
std::map<jl_code_instance_t*, jl_compile_result_t> &emitted,
6537-
jl_codegen_params_t &params)
6537+
jl_codegen_params_t &params, CompilationPolicy policy)
65386538
{
6539+
jl_code_info_t *src = NULL;
6540+
JL_GC_PUSH1(&src);
65396541
while (!params.workqueue.empty()) {
65406542
jl_code_instance_t *codeinst;
65416543
Function *protodecl;
@@ -6565,7 +6567,17 @@ void jl_compile_workqueue(
65656567
decls = &std::get<1>(result);
65666568
}
65676569
else {
6568-
result = jl_emit_codeinst(codeinst, NULL, params);
6570+
// Reinfer the function. The JIT came along and removed the inferred
6571+
// method body. See #34993
6572+
if (policy == CompilationPolicy::Extern &&
6573+
codeinst->inferred && codeinst->inferred == jl_nothing) {
6574+
src = jl_type_infer(codeinst->def, jl_world_counter, 0);
6575+
if (src)
6576+
result = jl_emit_code(codeinst->def, src, src->rettype, params);
6577+
}
6578+
else {
6579+
result = jl_emit_codeinst(codeinst, NULL, params);
6580+
}
65696581
if (std::get<0>(result))
65706582
decls = &std::get<1>(result);
65716583
else
@@ -6619,6 +6631,7 @@ void jl_compile_workqueue(
66196631
}
66206632
}
66216633
}
6634+
JL_GC_POP();
66226635
}
66236636

66246637

src/jitlayers.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ static jl_callptr_t _jl_compile_codeinst(
9797
jl_compile_result_t result = jl_emit_codeinst(codeinst, src, params);
9898
if (std::get<0>(result))
9999
emitted[codeinst] = std::move(result);
100-
jl_compile_workqueue(emitted, params);
100+
jl_compile_workqueue(emitted, params, CompilationPolicy::Default);
101101

102102
jl_add_to_ee();
103103
StringMap<std::unique_ptr<Module>*> NewExports;

src/jitlayers.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,15 @@ jl_compile_result_t jl_emit_codeinst(
7575
jl_code_info_t *src,
7676
jl_codegen_params_t &params);
7777

78+
enum CompilationPolicy {
79+
Default = 0,
80+
Extern = 1
81+
};
82+
7883
void jl_compile_workqueue(
7984
std::map<jl_code_instance_t*, jl_compile_result_t> &emitted,
80-
jl_codegen_params_t &params);
85+
jl_codegen_params_t &params,
86+
CompilationPolicy policy);
8187

8288
Function *jl_cfunction_object(jl_function_t *f, jl_value_t *rt, jl_tupletype_t *argt,
8389
jl_codegen_params_t &params);

src/julia_internal.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -615,7 +615,7 @@ JL_DLLEXPORT jl_value_t *jl_dump_fptr_asm(uint64_t fptr, int raw_mc, const char*
615615
JL_DLLEXPORT jl_value_t *jl_dump_llvm_asm(void *F, const char* asm_variant, const char *debuginfo);
616616
JL_DLLEXPORT jl_value_t *jl_dump_function_ir(void *f, char strip_ir_metadata, char dump_module, const char *debuginfo);
617617

618-
void *jl_create_native(jl_array_t *methods, const jl_cgparams_t cgparams);
618+
void *jl_create_native(jl_array_t *methods, const jl_cgparams_t cgparams, int policy);
619619
void jl_dump_native(void *native_code,
620620
const char *bc_fname, const char *unopt_bc_fname, const char *obj_fname, const char *asm_fname,
621621
const char *sysimg_data, size_t sysimg_len);

src/precompile.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ void *jl_precompile(int all)
405405
}
406406
}
407407
m = NULL;
408-
void *native_code = jl_create_native(m2, jl_default_cgparams);
408+
void *native_code = jl_create_native(m2, jl_default_cgparams, 0);
409409
JL_GC_POP();
410410
return native_code;
411411
}

0 commit comments

Comments
 (0)