Skip to content

Commit a10eed0

Browse files
committed
fix #30679, call correct method for invoke calls in jl_invoke fallback
1 parent 8b7c88c commit a10eed0

File tree

1 file changed

+41
-28
lines changed

1 file changed

+41
-28
lines changed

src/gf.c

+41-28
Original file line numberDiff line numberDiff line change
@@ -35,28 +35,6 @@ JL_DLLEXPORT size_t jl_get_tls_world_age(void)
3535
return jl_get_ptls_states()->world_age;
3636
}
3737

38-
JL_DLLEXPORT jl_value_t *jl_invoke(jl_method_instance_t *meth, jl_value_t **args, uint32_t nargs)
39-
{
40-
jl_callptr_t fptr = meth->invoke;
41-
if (fptr != jl_fptr_trampoline) {
42-
return fptr(meth, args, nargs);
43-
}
44-
else {
45-
// if this hasn't been inferred (compiled) yet,
46-
// inferring it might not be able to handle the world range
47-
// so we just do a generic apply here
48-
// because that might actually be faster
49-
// since it can go through the unrolled caches for this world
50-
// and if inference is successful, this meth would get updated anyways,
51-
// and we'll get the fast path here next time
52-
53-
// TODO: if `meth` came from an `invoke` call, we should make sure
54-
// meth->def is called instead of doing normal dispatch.
55-
56-
return jl_apply(args, nargs);
57-
}
58-
}
59-
6038
/// ----- Handling for Julia callbacks ----- ///
6139

6240
JL_DLLEXPORT int8_t jl_is_in_pure_context(void)
@@ -2285,6 +2263,8 @@ JL_DLLEXPORT jl_value_t *jl_gf_invoke_lookup(jl_value_t *types JL_PROPAGATES_ROO
22852263
return (jl_value_t*)entry;
22862264
}
22872265

2266+
jl_value_t *jl_gf_invoke_by_method(jl_method_t *method, jl_value_t **args, size_t nargs);
2267+
22882268
// invoke()
22892269
// this does method dispatch with a set of types to match other than the
22902270
// types of the actual arguments. this means it sometimes does NOT call the
@@ -2297,13 +2277,10 @@ JL_DLLEXPORT jl_value_t *jl_gf_invoke_lookup(jl_value_t *types JL_PROPAGATES_ROO
22972277
jl_value_t *jl_gf_invoke(jl_value_t *types0, jl_value_t **args, size_t nargs)
22982278
{
22992279
size_t world = jl_get_ptls_states()->world_age;
2300-
jl_svec_t *tpenv = jl_emptysvec;
2301-
jl_tupletype_t *tt = NULL;
23022280
jl_value_t *types = NULL;
2303-
JL_GC_PUSH3(&types, &tpenv, &tt);
2281+
JL_GC_PUSH1(&types);
23042282
jl_value_t *gf = args[0];
23052283
types = jl_argtype_with_function(gf, types0);
2306-
jl_methtable_t *mt = jl_gf_mtable(gf);
23072284
jl_typemap_entry_t *entry = (jl_typemap_entry_t*)jl_gf_invoke_lookup(types, world);
23082285

23092286
if ((jl_value_t*)entry == jl_nothing) {
@@ -2313,10 +2290,19 @@ jl_value_t *jl_gf_invoke(jl_value_t *types0, jl_value_t **args, size_t nargs)
23132290

23142291
// now we have found the matching definition.
23152292
// next look for or create a specialization of this definition.
2293+
JL_GC_POP();
2294+
return jl_gf_invoke_by_method(entry->func.method, args, nargs);
2295+
}
23162296

2317-
jl_method_t *method = entry->func.method;
2297+
jl_value_t *jl_gf_invoke_by_method(jl_method_t *method, jl_value_t **args, size_t nargs)
2298+
{
2299+
size_t world = jl_get_ptls_states()->world_age;
23182300
jl_method_instance_t *mfunc = NULL;
23192301
jl_typemap_entry_t *tm = NULL;
2302+
jl_methtable_t *mt = jl_gf_mtable(args[0]);
2303+
jl_svec_t *tpenv = jl_emptysvec;
2304+
jl_tupletype_t *tt = NULL;
2305+
JL_GC_PUSH2(&tpenv, &tt);
23202306
if (method->invokes != NULL)
23212307
tm = jl_typemap_assoc_exact(method->invokes, args, nargs, jl_cachearg_offset(mt), world);
23222308
if (tm) {
@@ -2333,7 +2319,7 @@ jl_value_t *jl_gf_invoke(jl_value_t *types0, jl_value_t **args, size_t nargs)
23332319
if (method->invokes == NULL)
23342320
method->invokes = jl_nothing;
23352321

2336-
mfunc = cache_method(mt, &method->invokes, entry->func.value, tt, method, world, tpenv, 1);
2322+
mfunc = cache_method(mt, &method->invokes, (jl_value_t*)method, tt, method, world, tpenv, 1);
23372323
JL_UNLOCK(&method->writelock);
23382324
}
23392325
JL_GC_POP();
@@ -2389,6 +2375,33 @@ JL_DLLEXPORT jl_value_t *jl_get_invoke_lambda(jl_methtable_t *mt,
23892375
return (jl_value_t*)mfunc;
23902376
}
23912377

2378+
JL_DLLEXPORT jl_value_t *jl_invoke(jl_method_instance_t *meth, jl_value_t **args, uint32_t nargs)
2379+
{
2380+
jl_callptr_t fptr = meth->invoke;
2381+
if (fptr != jl_fptr_trampoline) {
2382+
return fptr(meth, args, nargs);
2383+
}
2384+
else {
2385+
// if this hasn't been inferred (compiled) yet,
2386+
// inferring it might not be able to handle the world range
2387+
// so we just do a generic apply here
2388+
// because that might actually be faster
2389+
// since it can go through the unrolled caches for this world
2390+
// and if inference is successful, this meth would get updated anyways,
2391+
// and we'll get the fast path here next time
2392+
2393+
jl_method_instance_t *mfunc = jl_lookup_generic_(args, nargs,
2394+
jl_int32hash_fast(jl_return_address()),
2395+
jl_get_ptls_states()->world_age);
2396+
// check whether `jl_apply_generic` would call the right method
2397+
if (mfunc->def.method == meth->def.method)
2398+
return mfunc->invoke(mfunc, args, nargs);
2399+
2400+
// no; came from an `invoke` call
2401+
return jl_gf_invoke_by_method(meth->def.method, args, nargs);
2402+
}
2403+
}
2404+
23922405
// Return value is rooted globally
23932406
jl_function_t *jl_new_generic_function_with_supertype(jl_sym_t *name, jl_module_t *module, jl_datatype_t *st, int iskw)
23942407
{

0 commit comments

Comments
 (0)