Skip to content

Commit 1b600f0

Browse files
authored
optimizer: fully support inlining of union-split, partially constant-prop' callsite (#43347)
Makes full use of constant-propagation, by addressing this [TODO](https://github.com/JuliaLang/julia/blob/00734c5fd045316a00d287ca2c0ec1a2eef6e4d1/base/compiler/ssair/inlining.jl#L1212).

Here is a performance improvement from #43287:
```julia
julia> using BenchmarkTools

julia> X = rand(ComplexF32, 64, 64);

julia> dst = reinterpret(reshape, Float32, X);

julia> src = copy(dst);

julia> @btime copyto!($dst, $src);
  50.819 μs (1 allocation: 32 bytes) # v1.6.4
  41.081 μs (0 allocations: 0 bytes) # this commit
```

fixes #43287
1 parent 85a6990 commit 1b600f0

File tree

4 files changed

+115
-91
lines changed

4 files changed

+115
-91
lines changed

base/compiler/abstractinterpretation.jl

+1
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
160160
# by constant analysis, but let's create `ConstCallInfo` if there has been any successful
161161
# constant propagation, since other consumers may be interested in this
162162
if any_const_result && seen == napplicable
163+
@assert napplicable == nmatches(info) == length(const_results)
163164
info = ConstCallInfo(info, const_results)
164165
end
165166

base/compiler/ssair/inlining.jl

+84-75
Original file line numberDiff line numberDiff line change
@@ -689,19 +689,16 @@ function rewrite_apply_exprargs!(
689689
new_sig = with_atype(call_sig(ir, new_stmt)::Signature)
690690
new_info = call.info
691691
if isa(new_info, ConstCallInfo)
692-
maybe_handle_const_call!(
692+
handle_const_call!(
693693
ir, state1.id, new_stmt, new_info, flag,
694-
new_sig, istate, todo) && @goto analyzed
695-
new_info = new_info.call # cascade to the non-constant handling
696-
end
697-
if isa(new_info, MethodMatchInfo) || isa(new_info, UnionSplitInfo)
694+
new_sig, istate, todo)
695+
elseif isa(new_info, MethodMatchInfo) || isa(new_info, UnionSplitInfo)
698696
new_infos = isa(new_info, MethodMatchInfo) ? MethodMatchInfo[new_info] : new_info.matches
699697
# See if we can inline this call to `iterate`
700698
analyze_single_call!(
701699
ir, state1.id, new_stmt, new_infos, flag,
702700
new_sig, istate, todo)
703701
end
704-
@label analyzed
705702
if i != length(thisarginfo.each)
706703
valT = getfield_tfunc(call.rt, Const(1))
707704
val_extracted = insert_node!(ir, idx, NewInstruction(
@@ -1136,139 +1133,150 @@ function process_simple!(ir::IRCode, idx::Int, state::InliningState, todo::Vecto
11361133
return stmt, sig
11371134
end
11381135

1139-
# TODO inline non-`isdispatchtuple`, union-split callsites
1136+
# TODO inline non-`isdispatchtuple`, union-split callsites?
11401137
function analyze_single_call!(
11411138
ir::IRCode, idx::Int, stmt::Expr, infos::Vector{MethodMatchInfo}, flag::UInt8,
11421139
sig::Signature, state::InliningState, todo::Vector{Pair{Int, Any}})
11431140
(; argtypes, atype) = sig
11441141
cases = InliningCase[]
11451142
local signature_union = Bottom
11461143
local only_method = nothing # keep track of whether there is one matching method
1147-
local meth
1144+
local meth::MethodLookupResult
11481145
local fully_covered = true
11491146
for i in 1:length(infos)
1150-
info = infos[i]
1151-
meth = info.results
1147+
meth = infos[i].results
11521148
if meth.ambig
11531149
# Too many applicable methods
11541150
# Or there is a (partial?) ambiguity
1155-
return
1151+
return nothing
11561152
elseif length(meth) == 0
11571153
# No applicable methods; try next union split
11581154
continue
1159-
elseif length(meth) == 1 && only_method !== false
1160-
if only_method === nothing
1161-
only_method = meth[1].method
1162-
elseif only_method !== meth[1].method
1155+
else
1156+
if length(meth) == 1 && only_method !== false
1157+
if only_method === nothing
1158+
only_method = meth[1].method
1159+
elseif only_method !== meth[1].method
1160+
only_method = false
1161+
end
1162+
else
11631163
only_method = false
11641164
end
1165-
else
1166-
only_method = false
11671165
end
11681166
for match in meth
1169-
spec_types = match.spec_types
1170-
signature_union = Union{signature_union, spec_types}
1171-
if !isdispatchtuple(spec_types)
1172-
fully_covered = false
1173-
continue
1174-
end
1175-
item = analyze_method!(match, argtypes, flag, state)
1176-
if item === nothing
1177-
fully_covered = false
1178-
continue
1179-
elseif _any(case->case.sig === spec_types, cases)
1180-
continue
1181-
end
1182-
push!(cases, InliningCase(spec_types, item))
1167+
signature_union = Union{signature_union, match.spec_types}
1168+
fully_covered &= handle_match!(match, argtypes, flag, state, cases)
11831169
end
11841170
end
11851171

1186-
# if the signature is fully or mostly covered and there is only one applicable method,
1172+
# if the signature is fully covered and there is only one applicable method,
11871173
# we can try to inline it even if the signature is not a dispatch tuple
11881174
if length(cases) == 0 && only_method isa Method
11891175
if length(infos) > 1
11901176
(metharg, methsp) = ccall(:jl_type_intersection_with_env, Any, (Any, Any),
11911177
atype, only_method.sig)::SimpleVector
11921178
match = MethodMatch(metharg, methsp::SimpleVector, only_method, true)
11931179
else
1194-
meth = meth::MethodLookupResult
11951180
@assert length(meth) == 1
11961181
match = meth[1]
11971182
end
11981183
item = analyze_method!(match, argtypes, flag, state)
1199-
item === nothing && return
1184+
item === nothing && return nothing
12001185
push!(cases, InliningCase(match.spec_types, item))
12011186
fully_covered = match.fully_covers
12021187
else
12031188
fully_covered &= atype <: signature_union
12041189
end
12051190

1206-
# If we only have one case and that case is fully covered, we may either
1207-
# be able to do the inlining now (for constant cases), or push it directly
1208-
# onto the todo list
1209-
if fully_covered && length(cases) == 1
1210-
handle_single_case!(ir, idx, stmt, cases[1].item, todo)
1211-
elseif length(cases) > 0
1212-
push!(todo, idx=>UnionSplit(fully_covered, atype, cases))
1213-
end
1214-
return nothing
1191+
handle_cases!(ir, idx, stmt, sig, cases, fully_covered, todo)
12151192
end
12161193

1217-
# try to create `InliningCase`s using constant-prop'ed results
1218-
# currently it works only when constant-prop' succeeded for all (union-split) signatures
1219-
# TODO use any of constant-prop'ed results, and leave the other unhandled cases to later
1220-
# TODO this function contains a lot of duplications with `analyze_single_call!`, factor them out
1221-
function maybe_handle_const_call!(
1222-
ir::IRCode, idx::Int, stmt::Expr, info::ConstCallInfo, flag::UInt8,
1194+
# similar to `analyze_single_call!`, but with constant results
1195+
function handle_const_call!(
1196+
ir::IRCode, idx::Int, stmt::Expr, cinfo::ConstCallInfo, flag::UInt8,
12231197
sig::Signature, state::InliningState, todo::Vector{Pair{Int, Any}})
12241198
(; argtypes, atype) = sig
1225-
results = info.results
1226-
cases = InliningCase[] # TODO avoid this allocation for single cases ?
1199+
(; call, results) = cinfo
1200+
infos = isa(call, MethodMatchInfo) ? MethodMatchInfo[call] : call.matches
1201+
cases = InliningCase[]
12271202
local fully_covered = true
12281203
local signature_union = Bottom
1229-
for result in results
1230-
isa(result, InferenceResult) || return false
1231-
(; mi) = item = InliningTodo(result, argtypes)
1232-
spec_types = mi.specTypes
1233-
signature_union = Union{signature_union, spec_types}
1234-
if !isdispatchtuple(spec_types)
1235-
fully_covered = false
1236-
continue
1237-
end
1238-
if !validate_sparams(mi.sparam_vals)
1239-
fully_covered = false
1204+
local j = 0
1205+
for i in 1:length(infos)
1206+
meth = infos[i].results
1207+
if meth.ambig
1208+
# Too many applicable methods
1209+
# Or there is a (partial?) ambiguity
1210+
return nothing
1211+
elseif length(meth) == 0
1212+
# No applicable methods; try next union split
12401213
continue
12411214
end
1242-
state.mi_cache !== nothing && (item = resolve_todo(item, state, flag))
1243-
if item === nothing
1244-
fully_covered = false
1245-
continue
1215+
for match in meth
1216+
j += 1
1217+
result = results[j]
1218+
if result === nothing
1219+
signature_union = Union{signature_union, match.spec_types}
1220+
fully_covered &= handle_match!(match, argtypes, flag, state, cases)
1221+
else
1222+
signature_union = Union{signature_union, result.linfo.specTypes}
1223+
fully_covered &= handle_const_result!(result, argtypes, flag, state, cases)
1224+
end
12461225
end
1247-
push!(cases, InliningCase(spec_types, item))
12481226
end
12491227

12501228
# if the signature is fully covered and there is only one applicable method,
12511229
# we can try to inline it even if the signature is not a dispatch tuple
12521230
if length(cases) == 0 && length(results) == 1
12531231
(; mi) = item = InliningTodo(results[1]::InferenceResult, argtypes)
12541232
state.mi_cache !== nothing && (item = resolve_todo(item, state, flag))
1255-
validate_sparams(mi.sparam_vals) || return true
1256-
item === nothing && return true
1233+
validate_sparams(mi.sparam_vals) || return nothing
1234+
item === nothing && return nothing
12571235
push!(cases, InliningCase(mi.specTypes, item))
12581236
fully_covered = atype <: mi.specTypes
12591237
else
12601238
fully_covered &= atype <: signature_union
12611239
end
12621240

1241+
handle_cases!(ir, idx, stmt, sig, cases, fully_covered, todo)
1242+
end
1243+
1244+
function handle_match!(
1245+
match::MethodMatch, argtypes::Vector{Any}, flag::UInt8, state::InliningState,
1246+
cases::Vector{InliningCase})
1247+
spec_types = match.spec_types
1248+
isdispatchtuple(spec_types) || return false
1249+
item = analyze_method!(match, argtypes, flag, state)
1250+
item === nothing && return false
1251+
_any(case->case.sig === spec_types, cases) && return true
1252+
push!(cases, InliningCase(spec_types, item))
1253+
return true
1254+
end
1255+
1256+
function handle_const_result!(
1257+
result::InferenceResult, argtypes::Vector{Any}, flag::UInt8, state::InliningState,
1258+
cases::Vector{InliningCase})
1259+
(; mi) = item = InliningTodo(result, argtypes)
1260+
spec_types = mi.specTypes
1261+
isdispatchtuple(spec_types) || return false
1262+
validate_sparams(mi.sparam_vals) || return false
1263+
state.mi_cache !== nothing && (item = resolve_todo(item, state, flag))
1264+
item === nothing && return false
1265+
push!(cases, InliningCase(spec_types, item))
1266+
return true
1267+
end
1268+
1269+
function handle_cases!(ir::IRCode, idx::Int, stmt::Expr, sig::Signature,
1270+
cases::Vector{InliningCase}, fully_covered::Bool, todo::Vector{Pair{Int, Any}})
12631271
# If we only have one case and that case is fully covered, we may either
12641272
# be able to do the inlining now (for constant cases), or push it directly
12651273
# onto the todo list
12661274
if fully_covered && length(cases) == 1
12671275
handle_single_case!(ir, idx, stmt, cases[1].item, todo)
12681276
elseif length(cases) > 0
1269-
push!(todo, idx=>UnionSplit(fully_covered, atype, cases))
1277+
push!(todo, idx=>UnionSplit(fully_covered, sig.atype, cases))
12701278
end
1271-
return true
1279+
return nothing
12721280
end
12731281

12741282
function handle_const_opaque_closure_call!(
@@ -1302,7 +1310,8 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
13021310
end
13031311
ir.stmts[idx][:flag] |= IR_FLAG_EFFECT_FREE
13041312
info = info.info
1305-
elseif info === false
1313+
end
1314+
if info === false
13061315
# Inference determined this couldn't be analyzed. Don't question it.
13071316
continue
13081317
end
@@ -1333,10 +1342,10 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
13331342
# if inference arrived here with constant-prop'ed result(s),
13341343
# we can perform a specialized analysis for just this case
13351344
if isa(info, ConstCallInfo)
1336-
maybe_handle_const_call!(
1345+
handle_const_call!(
13371346
ir, idx, stmt, info, flag,
1338-
sig, state, todo) && continue
1339-
info = info.call # cascade to the non-constant handling
1347+
sig, state, todo)
1348+
continue
13401349
end
13411350

13421351
# Ok, now figure out what method to call

base/compiler/stmtinfo.jl

+21-12
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,27 @@ struct UnionSplitInfo
3838
matches::Vector{MethodMatchInfo}
3939
end
4040

41+
nmatches(info::MethodMatchInfo) = length(info.results)
42+
function nmatches(info::UnionSplitInfo)
43+
n = 0
44+
for mminfo in info.matches
45+
n += nmatches(mminfo)
46+
end
47+
return n
48+
end
49+
50+
"""
51+
info::ConstCallInfo
52+
53+
The precision of this call was improved using constant information.
54+
In addition to the original call information `info.call`, this info also keeps
55+
the inference results with constant information `info.results::Vector{Union{Nothing,InferenceResult}}`.
56+
"""
57+
struct ConstCallInfo
58+
call::Union{MethodMatchInfo,UnionSplitInfo}
59+
results::Vector{Union{Nothing,InferenceResult}}
60+
end
61+
4162
"""
4263
info::MethodResultPure
4364
@@ -92,18 +113,6 @@ struct UnionSplitApplyCallInfo
92113
infos::Vector{ApplyCallInfo}
93114
end
94115

95-
"""
96-
info::ConstCallInfo
97-
98-
The precision of this call was improved using constant information.
99-
In addition to the original call information `info.call`, this info also keeps
100-
the inference results with constant information `info.results::Vector{Union{Nothing,InferenceResult}}`.
101-
"""
102-
struct ConstCallInfo
103-
call::Union{MethodMatchInfo,UnionSplitInfo}
104-
results::Vector{Union{Nothing,InferenceResult}}
105-
end
106-
107116
"""
108117
info::InvokeCallInfo
109118

test/compiler/inline.jl

+9-4
Original file line numberDiff line numberDiff line change
@@ -759,13 +759,18 @@ end
759759
import Base: @constprop
760760

761761
# test union-split callsite with successful and unsuccessful constant-prop' results
762-
@constprop :aggressive @inline f42840(xs, a::Int) = xs[a] # should be successful, and inlined
763-
@constprop :none @noinline f42840(xs::AbstractVector, a::Int) = xs[a] # should be unsuccessful, but still statically resolved
762+
# (also for https://github.com/JuliaLang/julia/issues/43287)
763+
@constprop :aggressive @inline f42840(cond::Bool, xs::Tuple, a::Int) = # should be successful, and inlined with constant prop' result
764+
cond ? xs[a] : @noinline(length(xs))
765+
@constprop :none @noinline f42840(::Bool, xs::AbstractVector, a::Int) = # should be unsuccessful, but still statically resolved
766+
xs[a]
764767
let src = code_typed((Union{Tuple{Int,Int,Int}, Vector{Int}},)) do xs
765-
f42840(xs, 2)
768+
f42840(true, xs, 2)
766769
end |> only |> first
767-
# `(xs::Tuple{Int,Int,Int})[a::Const(2)]` => `getfield(xs, 2)`
770+
# `f42840(true, xs::Tuple{Int,Int,Int}, 2)` => `getfield(xs, 2)`
771+
# `f42840(true, xs::Vector{Int}, 2)` => `:invoke f42840(true, xs, 2)`
768772
@test count(iscall((src, getfield)), src.code) == 1
773+
@test count(isinvoke(:length), src.code) == 0
769774
@test count(isinvoke(:f42840), src.code) == 1
770775
end
771776
# a bit weird, but should handle this kind of case as well

0 commit comments

Comments
 (0)