Skip to content

Commit 8ee672c

Browse files
JeffBezansonKristofferC
authored andcommitted
fix #30679, call correct method for invoke calls in jl_invoke fallback (#31845)
(cherry picked from commit 4e20cc2)
1 parent e5366e5 commit 8ee672c

File tree

2 files changed

+55
-28
lines changed

2 files changed

+55
-28
lines changed

src/gf.c

+41-28
Original file line numberDiff line numberDiff line change
@@ -35,28 +35,6 @@ JL_DLLEXPORT size_t jl_get_tls_world_age(void)
3535
return jl_get_ptls_states()->world_age;
3636
}
3737

38-
JL_DLLEXPORT jl_value_t *jl_invoke(jl_method_instance_t *meth, jl_value_t **args, uint32_t nargs)
39-
{
40-
jl_callptr_t fptr = meth->invoke;
41-
if (fptr != jl_fptr_trampoline) {
42-
return fptr(meth, args, nargs);
43-
}
44-
else {
45-
// if this hasn't been inferred (compiled) yet,
46-
// inferring it might not be able to handle the world range
47-
// so we just do a generic apply here
48-
// because that might actually be faster
49-
// since it can go through the unrolled caches for this world
50-
// and if inference is successful, this meth would get updated anyways,
51-
// and we'll get the fast path here next time
52-
53-
// TODO: if `meth` came from an `invoke` call, we should make sure
54-
// meth->def is called instead of doing normal dispatch.
55-
56-
return jl_apply(args, nargs);
57-
}
58-
}
59-
6038
/// ----- Handling for Julia callbacks ----- ///
6139

6240
JL_DLLEXPORT int8_t jl_is_in_pure_context(void)
@@ -2200,6 +2178,8 @@ JL_DLLEXPORT jl_value_t *jl_gf_invoke_lookup(jl_value_t *types, size_t world)
22002178
return (jl_value_t*)entry;
22012179
}
22022180

2181+
jl_value_t *jl_gf_invoke_by_method(jl_method_t *method, jl_value_t **args, size_t nargs);
2182+
22032183
// invoke()
22042184
// this does method dispatch with a set of types to match other than the
22052185
// types of the actual arguments. this means it sometimes does NOT call the
@@ -2212,13 +2192,10 @@ JL_DLLEXPORT jl_value_t *jl_gf_invoke_lookup(jl_value_t *types, size_t world)
22122192
jl_value_t *jl_gf_invoke(jl_value_t *types0, jl_value_t **args, size_t nargs)
22132193
{
22142194
size_t world = jl_get_ptls_states()->world_age;
2215-
jl_svec_t *tpenv = jl_emptysvec;
2216-
jl_tupletype_t *tt = NULL;
22172195
jl_value_t *types = NULL;
2218-
JL_GC_PUSH3(&types, &tpenv, &tt);
2196+
JL_GC_PUSH1(&types);
22192197
jl_value_t *gf = args[0];
22202198
types = jl_argtype_with_function(gf, types0);
2221-
jl_methtable_t *mt = jl_gf_mtable(gf);
22222199
jl_typemap_entry_t *entry = (jl_typemap_entry_t*)jl_gf_invoke_lookup(types, world);
22232200

22242201
if ((jl_value_t*)entry == jl_nothing) {
@@ -2228,10 +2205,19 @@ jl_value_t *jl_gf_invoke(jl_value_t *types0, jl_value_t **args, size_t nargs)
22282205

22292206
// now we have found the matching definition.
22302207
// next look for or create a specialization of this definition.
2208+
JL_GC_POP();
2209+
return jl_gf_invoke_by_method(entry->func.method, args, nargs);
2210+
}
22312211

2232-
jl_method_t *method = entry->func.method;
2212+
jl_value_t *jl_gf_invoke_by_method(jl_method_t *method, jl_value_t **args, size_t nargs)
2213+
{
2214+
size_t world = jl_get_ptls_states()->world_age;
22332215
jl_method_instance_t *mfunc = NULL;
22342216
jl_typemap_entry_t *tm = NULL;
2217+
jl_methtable_t *mt = jl_gf_mtable(args[0]);
2218+
jl_svec_t *tpenv = jl_emptysvec;
2219+
jl_tupletype_t *tt = NULL;
2220+
JL_GC_PUSH2(&tpenv, &tt);
22352221
if (method->invokes.unknown != NULL)
22362222
tm = jl_typemap_assoc_exact(method->invokes, args, nargs, jl_cachearg_offset(mt), world);
22372223
if (tm) {
@@ -2248,7 +2234,7 @@ jl_value_t *jl_gf_invoke(jl_value_t *types0, jl_value_t **args, size_t nargs)
22482234
if (method->invokes.unknown == NULL)
22492235
method->invokes.unknown = jl_nothing;
22502236

2251-
mfunc = cache_method(mt, &method->invokes, entry->func.value, tt, method, world, tpenv, 1);
2237+
mfunc = cache_method(mt, &method->invokes, (jl_value_t*)method, tt, method, world, tpenv, 1);
22522238
JL_UNLOCK(&method->writelock);
22532239
}
22542240
JL_GC_POP();
@@ -2303,6 +2289,33 @@ JL_DLLEXPORT jl_value_t *jl_get_invoke_lambda(jl_methtable_t *mt,
23032289
return (jl_value_t*)mfunc;
23042290
}
23052291

2292+
JL_DLLEXPORT jl_value_t *jl_invoke(jl_method_instance_t *meth, jl_value_t **args, uint32_t nargs)
2293+
{
2294+
jl_callptr_t fptr = meth->invoke;
2295+
if (fptr != jl_fptr_trampoline) {
2296+
return fptr(meth, args, nargs);
2297+
}
2298+
else {
2299+
// if this hasn't been inferred (compiled) yet,
2300+
// inferring it might not be able to handle the world range
2301+
// so we just do a generic apply here
2302+
// because that might actually be faster
2303+
// since it can go through the unrolled caches for this world
2304+
// and if inference is successful, this meth would get updated anyways,
2305+
// and we'll get the fast path here next time
2306+
2307+
jl_method_instance_t *mfunc = jl_lookup_generic_(args, nargs,
2308+
jl_int32hash_fast(jl_return_address()),
2309+
jl_get_ptls_states()->world_age);
2310+
// check whether `jl_apply_generic` would call the right method
2311+
if (mfunc->def.method == meth->def.method)
2312+
return mfunc->invoke(mfunc, args, nargs);
2313+
2314+
// no; came from an `invoke` call
2315+
return jl_gf_invoke_by_method(meth->def.method, args, nargs);
2316+
}
2317+
}
2318+
23062319
// Return value is rooted globally
23072320
jl_function_t *jl_new_generic_function_with_supertype(jl_sym_t *name, jl_module_t *module, jl_datatype_t *st, int iskw)
23082321
{

test/core.jl

+14
Original file line numberDiff line numberDiff line change
@@ -2432,6 +2432,20 @@ const T24460 = Tuple{T,T} where T
24322432
g24460() = invoke(f24460, T24460, 1, 2)
24332433
@test @inferred(g24460()) === 2.0
24342434

2435+
# issue #30679
2436+
@noinline function f30679(::DataType)
2437+
b = IOBuffer()
2438+
write(b, 0x00)
2439+
2
2440+
end
2441+
@noinline function f30679(t::Type{Int})
2442+
x = invoke(f30679, Tuple{DataType}, t)
2443+
b = IOBuffer()
2444+
write(b, 0x00)
2445+
return x + 40
2446+
end
2447+
@test f30679(Int) == 42
2448+
24352449
call_lambda1() = (()->x)(1)
24362450
call_lambda2() = ((x)->x)()
24372451
call_lambda3() = ((x)->x)(1,2)

0 commit comments

Comments
 (0)