Skip to content

Perf optimizations and inferred intersections #14605

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lib/elixir/lib/module/types/apply.ex
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ defmodule Module.Types.Apply do
{union(type, fun_from_non_overlapping_clauses(clauses)), fallback?, context}

{{:infer, _, clauses}, context} when length(clauses) <= @max_clauses ->
{union(type, fun_from_overlapping_clauses(clauses)), fallback?, context}
{union(type, fun_from_inferred_clauses(clauses)), fallback?, context}

{_, context} ->
{type, true, context}
Expand Down Expand Up @@ -705,7 +705,7 @@ defmodule Module.Types.Apply do
result =
case info do
{:infer, _, clauses} when length(clauses) <= @max_clauses ->
fun_from_overlapping_clauses(clauses)
fun_from_inferred_clauses(clauses)

_ ->
dynamic(fun(arity))
Expand Down
135 changes: 101 additions & 34 deletions lib/elixir/lib/module/types/descr.ex
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ defmodule Module.Types.Descr do
@not_non_empty_list Map.delete(@term, :list)
@not_list Map.replace!(@not_non_empty_list, :bitmap, @bit_top - @bit_empty_list)

@empty_intersection [0, @none, []]
@empty_difference [0, []]
@empty_intersection [0, @none, [], :fun_bottom]
@empty_difference [0, [], :fun_bottom]

# Type definitions

Expand Down Expand Up @@ -137,16 +137,17 @@ defmodule Module.Types.Descr do
@doc """
Creates a function from overlapping function clauses.
"""
def fun_from_overlapping_clauses(args_clauses) do
def fun_from_inferred_clauses(args_clauses) do
domain_clauses =
Enum.reduce(args_clauses, [], fn {args, return}, acc ->
pivot_overlapping_clause(args_to_domain(args), return, acc)
domain = args |> Enum.map(&upper_bound/1) |> args_to_domain()
pivot_overlapping_clause(domain, upper_bound(return), acc)
end)

funs =
for {domain, return} <- domain_clauses,
args <- domain_to_args(domain),
do: fun(args, return)
do: fun(args, dynamic(return))

Enum.reduce(funs, &intersection/2)
end
Expand Down Expand Up @@ -200,19 +201,19 @@ defmodule Module.Types.Descr do
def domain_to_args(descr) do
case :maps.take(:dynamic, descr) do
:error ->
tuple_elim_negations_static(descr, &Function.identity/1)
unwrap_domain_tuple(descr, fn {:closed, elems} -> elems end)

{dynamic, static} ->
tuple_elim_negations_static(static, &Function.identity/1) ++
tuple_elim_negations_static(dynamic, fn elems -> Enum.map(elems, &dynamic/1) end)
unwrap_domain_tuple(static, fn {:closed, elems} -> elems end) ++
unwrap_domain_tuple(dynamic, fn {:closed, elems} -> Enum.map(elems, &dynamic/1) end)
end
end

defp tuple_elim_negations_static(%{tuple: dnf} = descr, transform) when map_size(descr) == 1 do
Enum.map(dnf, fn {:closed, elements} -> transform.(elements) end)
defp unwrap_domain_tuple(%{tuple: dnf} = descr, transform) when map_size(descr) == 1 do
Enum.map(dnf, transform)
end

defp tuple_elim_negations_static(descr, _transform) when descr == %{}, do: []
defp unwrap_domain_tuple(descr, _transform) when descr == %{}, do: []

defp domain_to_flat_args(domain, arity) do
case domain_to_args(domain) do
Expand Down Expand Up @@ -1170,6 +1171,7 @@ defmodule Module.Types.Descr do

static_arrows == [] ->
# TODO: We need to validate this within the theory
arguments = Enum.map(arguments, &upper_bound/1)
{:ok, dynamic(fun_apply_static(arguments, dynamic_arrows, false))}

true ->
Expand Down Expand Up @@ -1324,9 +1326,9 @@ defmodule Module.Types.Descr do
if subtype?(rets_reached, result), do: result, else: union(result, rets_reached)
end

defp aux_apply(result, input, returns_reached, [{dom, ret} | arrow_intersections]) do
defp aux_apply(result, input, returns_reached, [{args, ret} | arrow_intersections]) do
# Calculate the part of the input not covered by this arrow's domain
dom_subtract = difference(input, args_to_domain(dom))
dom_subtract = difference(input, args_to_domain(args))

# Refine the return type by intersecting with this arrow's return type
ret_refine = intersection(returns_reached, ret)
Expand Down Expand Up @@ -1423,7 +1425,7 @@ defmodule Module.Types.Descr do
# determines emptiness.
length(neg_arguments) == positive_arity and
subtype?(args_to_domain(neg_arguments), positive_domain) and
phi_starter(neg_arguments, negation(neg_return), positives)
phi_starter(neg_arguments, neg_return, positives)
end)
end
end
Expand Down Expand Up @@ -1461,27 +1463,75 @@ defmodule Module.Types.Descr do
#
# See [Castagna and Lanvin (2024)](https://arxiv.org/abs/2408.14345), Theorem 4.2.
defp phi_starter(arguments, return, positives) do
n = length(arguments)
# Arity mismatch: if there is one positive function with a different arity,
# then it cannot be a subtype of the (arguments->type) functions.
if Enum.any?(positives, fn {args, _ret} -> length(args) != n end) do
false
# Optimization: When all positive functions have non-empty domains,
# we can simplify the phi function check to a direct subtyping test.
# This avoids the expensive recursive phi computation by checking only that applying the
# input to the positive intersection yields a subtype of the return
if all_non_empty_domains?([{arguments, return} | positives]) do
fun_apply_static(arguments, [positives], false)
|> subtype?(return)
else
arguments = Enum.map(arguments, &{false, &1})
phi(arguments, {false, return}, positives)
n = length(arguments)
# Arity mismatch: functions with different arities cannot be subtypes
# of the target function type (arguments -> return)
if Enum.any?(positives, fn {args, _ret} -> length(args) != n end) do
false
else
# Initialize memoization cache for the recursive phi computation
arguments = Enum.map(arguments, &{false, &1})
{result, _cache} = phi(arguments, {false, negation(return)}, positives, %{})
result
end
end
end

defp phi(args, {b, t}, []) do
Enum.any?(args, fn {bool, typ} -> bool and empty?(typ) end) or (b and empty?(t))
defp phi(args, {b, t}, [], cache) do
{Enum.any?(args, fn {bool, typ} -> bool and empty?(typ) end) or (b and empty?(t)), cache}
end

defp phi(args, {b, ret}, [{arguments, return} | rest_positive]) do
phi(args, {true, intersection(ret, return)}, rest_positive) and
Enum.all?(Enum.with_index(arguments), fn {type, index} ->
List.update_at(args, index, fn {_, arg} -> {true, difference(arg, type)} end)
|> phi({b, ret}, rest_positive)
end)
defp phi(args, {b, ret}, [{arguments, return} | rest_positive], cache) do
# Create cache key from function arguments
cache_key = {args, {b, ret}, [{arguments, return} | rest_positive]}

case Map.get(cache, cache_key) do
nil ->
# Compute result and cache it
{result1, cache} = phi(args, {true, intersection(ret, return)}, rest_positive, cache)

if not result1 do
# Store false result in cache
cache = Map.put(cache, cache_key, false)
{false, cache}
else
# This doesn't stop if one intermediate result is false?
{result2, cache} =
Enum.with_index(arguments)
|> Enum.reduce_while({true, cache}, fn {type, index}, {acc_result, acc_cache} ->
{new_result, new_cache} =
List.update_at(args, index, fn {_, arg} -> {true, difference(arg, type)} end)
|> phi({b, ret}, rest_positive, acc_cache)

if new_result do
{:cont, {acc_result and new_result, new_cache}}
else
{:halt, {false, new_cache}}
end
end)

result = result1 and result2
# Store result in cache
cache = Map.put(cache, cache_key, result)
{result, cache}
end

cached_result ->
# Return cached result
{cached_result, cache}
end
end

defp all_non_empty_domains?(positives) do
Enum.all?(positives, fn {args, _ret} -> not empty?(args_to_domain(args)) end)
end

defp fun_union(bdd1, bdd2) do
Expand Down Expand Up @@ -1828,6 +1878,10 @@ defmodule Module.Types.Descr do
# b) If only the last type differs, subtracts it
# 3. Base case: adds dnf2 type to negations of dnf1 type
# The result may be larger than the initial dnf1, which is maintained in the accumulator.
defp list_difference(_, dnf) when dnf == @non_empty_list_top do
0
end

defp list_difference(dnf1, dnf2) do
Enum.reduce(dnf2, dnf1, fn {t2, last2, negs2}, acc_dnf1 ->
last2 = list_tail_unfold(last2)
Expand Down Expand Up @@ -1855,6 +1909,8 @@ defmodule Module.Types.Descr do
end)
end

defp list_empty?(@non_empty_list_top), do: false

defp list_empty?(dnf) do
Enum.all?(dnf, fn {list_type, last_type, negs} ->
last_type = list_tail_unfold(last_type)
Expand Down Expand Up @@ -2115,9 +2171,6 @@ defmodule Module.Types.Descr do

defp dynamic_to_quoted(descr, opts) do
cond do
descr == %{} ->
[]

# We check for :term literally instead of using term_type?
# because we check for term_type? in to_quoted before we
# compute the difference(dynamic, static).
Expand All @@ -2127,6 +2180,9 @@ defmodule Module.Types.Descr do
single = indivisible_bitmap(descr, opts) ->
[single]

empty?(descr) ->
[]

true ->
case non_term_type_to_quoted(descr, opts) do
{:none, _meta, []} = none -> [none]
Expand Down Expand Up @@ -2395,6 +2451,10 @@ defmodule Module.Types.Descr do
if empty?(type), do: throw(:empty), else: type
end

defp map_difference(_, dnf) when dnf == @map_top do
0
end

defp map_difference(dnf1, dnf2) do
Enum.reduce(dnf2, dnf1, fn
# Optimization: we are removing an open map with one field.
Expand Down Expand Up @@ -3045,10 +3105,15 @@ defmodule Module.Types.Descr do
zip_non_empty_intersection!(rest1, rest2, [non_empty_intersection!(type1, type2) | acc])
end

defp tuple_difference(_, dnf) when dnf == @tuple_top do
0
end

defp tuple_difference(dnf1, dnf2) do
Enum.reduce(dnf2, dnf1, fn {tag2, elements2}, dnf1 ->
Enum.reduce(dnf1, [], fn {tag1, elements1}, acc ->
tuple_eliminate_single_negation(tag1, elements1, {tag2, elements2}) ++ acc
tuple_eliminate_single_negation(tag1, elements1, {tag2, elements2})
|> tuple_union(acc)
end)
end)
end
Expand All @@ -3063,8 +3128,10 @@ defmodule Module.Types.Descr do
if (tag == :closed and n < m) or (neg_tag == :closed and n > m) do
[{tag, elements}]
else
tuple_elim_content([], tag, elements, neg_elements) ++
tuple_union(
tuple_elim_content([], tag, elements, neg_elements),
tuple_elim_size(n, m, tag, elements, neg_tag)
)
end
end

Expand Down
8 changes: 6 additions & 2 deletions lib/elixir/lib/module/types/expr.ex
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ defmodule Module.Types.Expr do
add_inferred(acc, args, body)
end)

{fun_from_overlapping_clauses(acc), context}
{fun_from_inferred_clauses(acc), context}
end
end

Expand Down Expand Up @@ -461,7 +461,11 @@ defmodule Module.Types.Expr do
{args_types, context} =
Enum.map_reduce(args, context, &of_expr(&1, @pending, &1, stack, &2))

Apply.fun_apply(fun_type, args_types, call, stack, context)
if stack.mode == :traversal do
{dynamic(), context}
else
Apply.fun_apply(fun_type, args_types, call, stack, context)
end
end

def of_expr({{:., _, [callee, key_or_fun]}, meta, []} = call, expected, expr, stack, context)
Expand Down
Loading