Skip to content

Commit 8f1b9d3

Browse files
JeffBezansonKristofferC
authored andcommitted
fix #30679, call correct method for invoke calls in jl_invoke fallback (#30880)
(cherry picked from commit f97c443)
1 parent 1f87b5e commit 8f1b9d3

File tree

2 files changed

+55
-28
lines changed

2 files changed

+55
-28
lines changed

src/gf.c

+41-28
Original file line numberDiff line numberDiff line change
@@ -35,28 +35,6 @@ JL_DLLEXPORT size_t jl_get_tls_world_age(void)
3535
return jl_get_ptls_states()->world_age;
3636
}
3737

38-
JL_DLLEXPORT jl_value_t *jl_invoke(jl_method_instance_t *meth, jl_value_t **args, uint32_t nargs)
39-
{
40-
jl_callptr_t fptr = meth->invoke;
41-
if (fptr != jl_fptr_trampoline) {
42-
return fptr(meth, args, nargs);
43-
}
44-
else {
45-
// if this hasn't been inferred (compiled) yet,
46-
// inferring it might not be able to handle the world range
47-
// so we just do a generic apply here
48-
// because that might actually be faster
49-
// since it can go through the unrolled caches for this world
50-
// and if inference is successful, this meth would get updated anyways,
51-
// and we'll get the fast path here next time
52-
53-
// TODO: if `meth` came from an `invoke` call, we should make sure
54-
// meth->def is called instead of doing normal dispatch.
55-
56-
return jl_apply(args, nargs);
57-
}
58-
}
59-
6038
/// ----- Handling for Julia callbacks ----- ///
6139

6240
JL_DLLEXPORT int8_t jl_is_in_pure_context(void)
@@ -2235,6 +2213,8 @@ JL_DLLEXPORT jl_value_t *jl_gf_invoke_lookup(jl_value_t *types JL_PROPAGATES_ROO
22352213
return (jl_value_t*)entry;
22362214
}
22372215

2216+
jl_value_t *jl_gf_invoke_by_method(jl_method_t *method, jl_value_t **args, size_t nargs);
2217+
22382218
// invoke()
22392219
// this does method dispatch with a set of types to match other than the
22402220
// types of the actual arguments. this means it sometimes does NOT call the
@@ -2247,13 +2227,10 @@ JL_DLLEXPORT jl_value_t *jl_gf_invoke_lookup(jl_value_t *types JL_PROPAGATES_ROO
22472227
jl_value_t *jl_gf_invoke(jl_value_t *types0, jl_value_t **args, size_t nargs)
22482228
{
22492229
size_t world = jl_get_ptls_states()->world_age;
2250-
jl_svec_t *tpenv = jl_emptysvec;
2251-
jl_tupletype_t *tt = NULL;
22522230
jl_value_t *types = NULL;
2253-
JL_GC_PUSH3(&types, &tpenv, &tt);
2231+
JL_GC_PUSH1(&types);
22542232
jl_value_t *gf = args[0];
22552233
types = jl_argtype_with_function(gf, types0);
2256-
jl_methtable_t *mt = jl_gf_mtable(gf);
22572234
jl_typemap_entry_t *entry = (jl_typemap_entry_t*)jl_gf_invoke_lookup(types, world);
22582235

22592236
if ((jl_value_t*)entry == jl_nothing) {
@@ -2263,10 +2240,19 @@ jl_value_t *jl_gf_invoke(jl_value_t *types0, jl_value_t **args, size_t nargs)
22632240

22642241
// now we have found the matching definition.
22652242
// next look for or create a specialization of this definition.
2243+
JL_GC_POP();
2244+
return jl_gf_invoke_by_method(entry->func.method, args, nargs);
2245+
}
22662246

2267-
jl_method_t *method = entry->func.method;
2247+
jl_value_t *jl_gf_invoke_by_method(jl_method_t *method, jl_value_t **args, size_t nargs)
2248+
{
2249+
size_t world = jl_get_ptls_states()->world_age;
22682250
jl_method_instance_t *mfunc = NULL;
22692251
jl_typemap_entry_t *tm = NULL;
2252+
jl_methtable_t *mt = jl_gf_mtable(args[0]);
2253+
jl_svec_t *tpenv = jl_emptysvec;
2254+
jl_tupletype_t *tt = NULL;
2255+
JL_GC_PUSH2(&tpenv, &tt);
22702256
if (method->invokes != NULL)
22712257
tm = jl_typemap_assoc_exact(method->invokes, args, nargs, jl_cachearg_offset(mt), world);
22722258
if (tm) {
@@ -2283,7 +2269,7 @@ jl_value_t *jl_gf_invoke(jl_value_t *types0, jl_value_t **args, size_t nargs)
22832269
if (method->invokes == NULL)
22842270
method->invokes = jl_nothing;
22852271

2286-
mfunc = cache_method(mt, &method->invokes, entry->func.value, tt, method, world, tpenv, 1);
2272+
mfunc = cache_method(mt, &method->invokes, (jl_value_t*)method, tt, method, world, tpenv, 1);
22872273
JL_UNLOCK(&method->writelock);
22882274
}
22892275
JL_GC_POP();
@@ -2339,6 +2325,33 @@ JL_DLLEXPORT jl_value_t *jl_get_invoke_lambda(jl_methtable_t *mt,
23392325
return (jl_value_t*)mfunc;
23402326
}
23412327

2328+
JL_DLLEXPORT jl_value_t *jl_invoke(jl_method_instance_t *meth, jl_value_t **args, uint32_t nargs)
2329+
{
2330+
jl_callptr_t fptr = meth->invoke;
2331+
if (fptr != jl_fptr_trampoline) {
2332+
return fptr(meth, args, nargs);
2333+
}
2334+
else {
2335+
// if this hasn't been inferred (compiled) yet,
2336+
// inferring it might not be able to handle the world range
2337+
// so we just do a generic apply here
2338+
// because that might actually be faster
2339+
// since it can go through the unrolled caches for this world
2340+
// and if inference is successful, this meth would get updated anyways,
2341+
// and we'll get the fast path here next time
2342+
2343+
jl_method_instance_t *mfunc = jl_lookup_generic_(args, nargs,
2344+
jl_int32hash_fast(jl_return_address()),
2345+
jl_get_ptls_states()->world_age);
2346+
// check whether `jl_apply_generic` would call the right method
2347+
if (mfunc->def.method == meth->def.method)
2348+
return mfunc->invoke(mfunc, args, nargs);
2349+
2350+
// no; came from an `invoke` call
2351+
return jl_gf_invoke_by_method(meth->def.method, args, nargs);
2352+
}
2353+
}
2354+
23422355
// Return value is rooted globally
23432356
jl_function_t *jl_new_generic_function_with_supertype(jl_sym_t *name, jl_module_t *module, jl_datatype_t *st, int iskw)
23442357
{

test/core.jl

+14
Original file line numberDiff line numberDiff line change
@@ -2435,6 +2435,20 @@ const T24460 = Tuple{T,T} where T
24352435
g24460() = invoke(f24460, T24460, 1, 2)
24362436
@test @inferred(g24460()) === 2.0
24372437

2438+
# issue #30679
2439+
@noinline function f30679(::DataType)
2440+
b = IOBuffer()
2441+
write(b, 0x00)
2442+
2
2443+
end
2444+
@noinline function f30679(t::Type{Int})
2445+
x = invoke(f30679, Tuple{DataType}, t)
2446+
b = IOBuffer()
2447+
write(b, 0x00)
2448+
return x + 40
2449+
end
2450+
@test f30679(Int) == 42
2451+
24382452
call_lambda1() = (()->x)(1)
24392453
call_lambda2() = ((x)->x)()
24402454
call_lambda3() = ((x)->x)(1,2)

0 commit comments

Comments
 (0)