Skip to content

Fix nondeterministic type checking caused by nonassociativity of joins #19147

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 26 additions & 13 deletions mypy/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from mypy.constraints import SUBTYPE_OF, SUPERTYPE_OF, Constraint, infer_constraints, neg_op
from mypy.expandtype import expand_type
from mypy.graph_utils import prepare_sccs, strongly_connected_components, topsort
from mypy.join import join_types
from mypy.join import join_type_list
from mypy.meet import meet_type_list, meet_types
from mypy.subtypes import is_subtype
from mypy.typeops import get_all_type_vars
Expand Down Expand Up @@ -247,10 +247,16 @@ def solve_iteratively(
return solutions


def _join_sorted_key(t: Type) -> int:
t = get_proper_type(t)
if isinstance(t, UnionType):
return -1
return 0


def solve_one(lowers: Iterable[Type], uppers: Iterable[Type]) -> Type | None:
"""Solve constraints by finding by using meets of upper bounds, and joins of lower bounds."""
bottom: Type | None = None
top: Type | None = None

candidate: Type | None = None

# Filter out previous results of failed inference, they will only spoil the current pass...
Expand All @@ -267,19 +273,26 @@ def solve_one(lowers: Iterable[Type], uppers: Iterable[Type]) -> Type | None:
candidate.ambiguous = True
return candidate

bottom: Type | None = None
top: Type | None = None

# Process each bound separately, and calculate the lower and upper
# bounds based on constraints. Note that we assume that the constraint
# targets do not have constraint references.
for target in lowers:
if bottom is None:
bottom = target
else:
if type_state.infer_unions:
# This deviates from the general mypy semantics because
# recursive types are union-heavy in 95% of cases.
bottom = UnionType.make_union([bottom, target])
else:
bottom = join_types(bottom, target)
if type_state.infer_unions:
# This deviates from the general mypy semantics because
# recursive types are union-heavy in 95% of cases.
bottom = UnionType.make_union(list(lowers))
else:
# The order of lowers is non-deterministic.
# We attempt to sort lowers because joins are non-associative. For instance:
# join(join(int, str), int | str) == join(object, int | str) == object
# join(int, join(str, int | str)) == join(int, int | str) == int | str
# Note that joins in theory should be commutative, but in practice some bugs mean this is
# also a source of non-deterministic type checking results.
sorted_lowers = sorted(lowers, key=_join_sorted_key)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Though this is probably fine, right now sorted with a key= argument is a bit slow when compiled using mypyc, so using a for loop here instead that moves union types to the front would likely be faster. It's also fine to merge this as is and only replace this with a for loop if there appears to be a measurable performance regression.

if sorted_lowers:
bottom = join_type_list(sorted_lowers)

for target in uppers:
if top is None:
Expand Down
40 changes: 40 additions & 0 deletions test-data/unit/check-generics.test
Original file line number Diff line number Diff line change
Expand Up @@ -3563,3 +3563,43 @@ def foo(x: T):
reveal_type(C) # N: Revealed type is "Overload(def [T, S] (x: builtins.int, y: S`-1) -> __main__.C[__main__.Int[S`-1]], def [T, S] (x: builtins.str, y: S`-1) -> __main__.C[__main__.Str[S`-1]])"
reveal_type(C(0, x)) # N: Revealed type is "__main__.C[__main__.Int[T`-1]]"
reveal_type(C("yes", x)) # N: Revealed type is "__main__.C[__main__.Str[T`-1]]"

[case testDeterminismFromJoinOrderingInSolver]
# Used to fail non-deterministically
# https://github.com/python/mypy/issues/19121
from __future__ import annotations
from typing import Generic, Iterable, Iterator, Self, TypeVar

_T1 = TypeVar("_T1")
_T2 = TypeVar("_T2")
_T3 = TypeVar("_T3")
_T_co = TypeVar("_T_co", covariant=True)

class Base(Iterable[_T1]):
def __iter__(self) -> Iterator[_T1]: ...
class A(Base[_T1]): ...
class B(Base[_T1]): ...
class C(Base[_T1]): ...
class D(Base[_T1]): ...
class E(Base[_T1]): ...

class zip2(Generic[_T_co]):
def __new__(
cls,
iter1: Iterable[_T1],
iter2: Iterable[_T2],
iter3: Iterable[_T3],
) -> zip2[tuple[_T1, _T2, _T3]]: ...
def __iter__(self) -> Self: ...
def __next__(self) -> _T_co: ...

def draw(
colors1: A[str] | B[str] | C[int] | D[int | str],
colors2: A[str] | B[str] | C[int] | D[int | str],
colors3: A[str] | B[str] | C[int] | D[int | str],
) -> None:
for c1, c2, c3 in zip2(colors1, colors2, colors3):
reveal_type(c1) # N: Revealed type is "Union[builtins.int, builtins.str]"
reveal_type(c2) # N: Revealed type is "Union[builtins.int, builtins.str]"
reveal_type(c3) # N: Revealed type is "Union[builtins.int, builtins.str]"
[builtins fixtures/tuple.pyi]
4 changes: 2 additions & 2 deletions test-data/unit/check-recursive-types.test
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ reveal_type(flatten([1, [2, [3]]])) # N: Revealed type is "builtins.list[builti

class Bad: ...
x: Nested[int] = [1, [2, [3]]]
x = [1, [Bad()]] # E: List item 0 has incompatible type "Bad"; expected "Union[int, Nested[int]]"
x = [1, [Bad()]] # E: List item 1 has incompatible type "List[Bad]"; expected "Union[int, Nested[int]]"
[builtins fixtures/isinstancelist.pyi]

[case testRecursiveAliasGenericInferenceNested]
Expand Down Expand Up @@ -605,7 +605,7 @@ class NT(NamedTuple, Generic[T]):
class A: ...
class B(A): ...

nti: NT[int] = NT(key=0, value=NT(key=1, value=A())) # E: Argument "value" to "NT" has incompatible type "A"; expected "Union[int, NT[int]]"
nti: NT[int] = NT(key=0, value=NT(key=1, value=A())) # E: Argument "value" to "NT" has incompatible type "NT[A]"; expected "Union[int, NT[int]]"
reveal_type(nti) # N: Revealed type is "Tuple[builtins.int, Union[builtins.int, ...], fallback=__main__.NT[builtins.int]]"

nta: NT[A]
Expand Down