From 77679f8cd31d79a08244ece52b72133460aec247 Mon Sep 17 00:00:00 2001 From: hauntsaninja Date: Wed, 1 Jan 2025 01:42:42 -0800 Subject: [PATCH 1/6] Fix nondeterministic type checking involving noncommutative join See https://github.com/python/mypy/issues/16979#issuecomment-1982283536 --- mypy/join.py | 16 ++++++++--- test-data/unit/check-inference.test | 41 +++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 3 deletions(-) diff --git a/mypy/join.py b/mypy/join.py index 166434f58f8d..c87ff1cf5b15 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -8,7 +8,7 @@ import mypy.typeops from mypy.expandtype import expand_type from mypy.maptype import map_instance_to_supertype -from mypy.nodes import CONTRAVARIANT, COVARIANT, INVARIANT, VARIANCE_NOT_READY +from mypy.nodes import CONTRAVARIANT, COVARIANT, INVARIANT, VARIANCE_NOT_READY, TypeInfo from mypy.state import state from mypy.subtypes import ( SubtypeContext, @@ -168,9 +168,19 @@ def join_instances_via_supertype(self, t: Instance, s: Instance) -> ProperType: # Compute the "best" supertype of t when joined with s. # The definition of "best" may evolve; for now it is the one with # the longest MRO. Ties are broken by using the earlier base. - best: ProperType | None = None + + # Go over both sets of bases in case there's an explicit Protocol base. This is important + # to ensure commutativity of join (although in cases where both classes have relevant + # Protocol bases this maybe might still not be commutative) + base_types: dict[TypeInfo, None] = {} for base in t.type.bases: - mapped = map_instance_to_supertype(t, base.type) + base_types[base.type] = None + for base in s.type.bases: + base_types[base.type] = None + + best: ProperType | None = None + for base_type in base_types: + mapped = map_instance_to_supertype(t, base_type) res = self.join_instances(mapped, s) if best is None or is_better(res, best): best = res diff --git a/test-data/unit/check-inference.test b/test-data/unit/check-inference.test index 0da1c092efe8..657797ffe5cc 100644 --- a/test-data/unit/check-inference.test +++ b/test-data/unit/check-inference.test @@ -3886,3 +3886,44 @@ def a4(x: List[str], y: List[Never]) -> None: reveal_type(z2) # N: Revealed type is "builtins.list[builtins.object]" z1[1].append("asdf") # E: "object" has no attribute "append" [builtins fixtures/dict.pyi] + + +[case testNonDeterminismFromNonCommuativeJoinInvolvingProtocolBaseAndPromotableType] +# flags: --python-version 3.11 +# Regression test for https://github.com/python/mypy/issues/16979#issuecomment-1982246306 +from __future__ import annotations + +from typing import Any, Generic, Protocol, TypeVar, overload, cast +from typing_extensions import Never + +T = TypeVar("T") +U = TypeVar("U") + +class _SupportsCompare(Protocol): + def __lt__(self, other: Any, /) -> bool: + return True + +class Comparable(_SupportsCompare): + pass + +class A(Generic[T, U]): + @overload + def __init__(self: A[T, T], a: T, b: T, /) -> None: ... # type: ignore[overload-overlap] + @overload + def __init__(self: A[T, U], a: T, b: U, /) -> Never: ... + def __init__(self, *a) -> None: ... + +comparable: Comparable = Comparable() + +from typing import _promote + +class floatlike: + def __lt__(self, other: floatlike, /) -> bool: ... + +@_promote(floatlike) +class intlike: + def __lt__(self, other: intlike, /) -> bool: ... + +reveal_type(A(intlike(), comparable)) # N: Revealed type is "__main__.A[__main__._SupportsCompare, __main__._SupportsCompare]" +[builtins fixtures/tuple.pyi] +[typing fixtures/typing-medium.pyi] From 3ef78e88b98263d2e2ac31ade358d7986b067ad7 Mon Sep 17 00:00:00 2001 From: hauntsaninja Date: Sun, 25 May 2025 15:25:07 -0700 Subject: [PATCH 2/6] fix --- mypy/join.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mypy/join.py b/mypy/join.py index 5757f8b6a332..1c0b0069dafb 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -172,11 +172,12 @@ def join_instances_via_supertype(self, t: Instance, s: Instance) -> ProperType: # Go over both sets of bases in case there's an explicit Protocol base. This is important # to ensure commutativity of join (although in cases where both classes have relevant # Protocol bases this maybe might still not be commutative) - base_types: dict[TypeInfo, None] = {} + base_types: dict[TypeInfo, None] = {} # dict to deduplicate but preserve order for base in t.type.bases: base_types[base.type] = None for base in s.type.bases: - base_types[base.type] = None + if is_subtype(t, base): + base_types[base.type] = None best: ProperType | None = None for base_type in base_types: From 0b18c18416a345d5fd473d0a05c1d8704a110318 Mon Sep 17 00:00:00 2001 From: hauntsaninja Date: Sun, 25 May 2025 15:54:41 -0700 Subject: [PATCH 3/6] efficiency --- mypy/join.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy/join.py b/mypy/join.py index 1c0b0069dafb..75e32a69d35c 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -176,7 +176,7 @@ def join_instances_via_supertype(self, t: Instance, s: Instance) -> ProperType: for base in t.type.bases: base_types[base.type] = None for base in s.type.bases: - if is_subtype(t, base): + if base.type.is_protocol and is_subtype(t, base): base_types[base.type] = None best: ProperType | None = None From 69c1a654ef8179b13b6536f22577b94d029e6de5 Mon Sep 17 00:00:00 2001 From: hauntsaninja Date: Sun, 25 May 2025 16:46:24 -0700 Subject: [PATCH 4/6] fix false positive --- mypy/join.py | 3 +++ test-data/unit/check-protocols.test | 23 +++++++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/mypy/join.py b/mypy/join.py index 75e32a69d35c..819c6f2df1e0 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -671,6 +671,9 @@ def is_better(t: Type, s: Type) -> bool: if isinstance(t, Instance): if not isinstance(s, Instance): return True + if t.type.is_protocol != s.type.is_protocol and s.type.fullname != "builtins.object": + # mro of protocol is not really relevant + return not t.type.is_protocol # Use len(mro) as a proxy for the better choice. if len(t.type.mro) > len(s.type.mro): return True diff --git a/test-data/unit/check-protocols.test b/test-data/unit/check-protocols.test index 34e3f3e88080..9aa0747b2b76 100644 --- a/test-data/unit/check-protocols.test +++ b/test-data/unit/check-protocols.test @@ -4460,3 +4460,26 @@ f2(a4) # E: Argument 1 to "f2" has incompatible type "A4"; expected "P2" \ # N: foo: expected "B1", got "str" \ # N: foo: expected setter type "C1", got "str" [builtins fixtures/property.pyi] + + +[case testExplicitProtocolJoinPreference] +from typing import Protocol, TypeVar + +T = TypeVar("T") + +class Proto1(Protocol): + def foo(self) -> int: ... +class Proto2(Proto1): + def bar(self) -> str: ... +class Proto3(Proto2): + def baz(self) -> str: ... + +class Base: ... + +class A(Base, Proto3): ... +class B(Base, Proto3): ... + +def join(a: T, b: T) -> T: ... + +def main(a: A, b: B) -> None: + reveal_type(join(a, b)) # N: Revealed type is "__main__.Proto3" From 34e0a0c71e2b709b85314e0b4cadf270845e5dae Mon Sep 17 00:00:00 2001 From: hauntsaninja Date: Sun, 25 May 2025 16:51:10 -0700 Subject: [PATCH 5/6] fix test --- mypy/join.py | 7 ++++--- test-data/unit/check-inference.test | 21 +++++++++++++-------- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/mypy/join.py b/mypy/join.py index 819c6f2df1e0..0a011b4b3077 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -671,9 +671,10 @@ def is_better(t: Type, s: Type) -> bool: if isinstance(t, Instance): if not isinstance(s, Instance): return True - if t.type.is_protocol != s.type.is_protocol and s.type.fullname != "builtins.object": - # mro of protocol is not really relevant - return not t.type.is_protocol + if t.type.is_protocol != s.type.is_protocol: + if t.type.fullname != "builtins.object" and s.type.fullname != "builtins.object": + # mro of protocol is not really relevant + return not t.type.is_protocol # Use len(mro) as a proxy for the better choice. if len(t.type.mro) > len(s.type.mro): return True diff --git a/test-data/unit/check-inference.test b/test-data/unit/check-inference.test index e0435a4145a6..6d8cc901b0c1 100644 --- a/test-data/unit/check-inference.test +++ b/test-data/unit/check-inference.test @@ -3889,7 +3889,7 @@ def a4(x: List[str], y: List[Never]) -> None: [builtins fixtures/dict.pyi] -[case testNonDeterminismFromNonCommuativeJoinInvolvingProtocolBaseAndPromotableType] +[case testDeterminismCommutativityWithJoinInvolvingProtocolBaseAndPromotableType] # flags: --python-version 3.11 # Regression test for https://github.com/python/mypy/issues/16979#issuecomment-1982246306 from __future__ import annotations @@ -3907,13 +3907,6 @@ class _SupportsCompare(Protocol): class Comparable(_SupportsCompare): pass -class A(Generic[T, U]): - @overload - def __init__(self: A[T, T], a: T, b: T, /) -> None: ... # type: ignore[overload-overlap] - @overload - def __init__(self: A[T, U], a: T, b: U, /) -> Never: ... - def __init__(self, *a) -> None: ... - comparable: Comparable = Comparable() from typing import _promote @@ -3925,9 +3918,21 @@ class floatlike: class intlike: def __lt__(self, other: intlike, /) -> bool: ... + +class A(Generic[T, U]): + @overload + def __init__(self: A[T, T], a: T, b: T, /) -> None: ... # type: ignore[overload-overlap] + @overload + def __init__(self: A[T, U], a: T, b: U, /) -> Never: ... + def __init__(self, *a) -> None: ... + +def join(a: T, b: T) -> T: ... + +reveal_type(join(intlike(), comparable)) # N: Revealed type is "__main__._SupportsCompare" reveal_type(A(intlike(), comparable)) # N: Revealed type is "__main__.A[__main__._SupportsCompare, __main__._SupportsCompare]" [builtins fixtures/tuple.pyi] [typing fixtures/typing-medium.pyi] + [case testTupleJoinFallbackInference] foo = [ (1, ("a", "b")), From 4149882cfb530129d2cae3c47714ad0bc7cb41a5 Mon Sep 17 00:00:00 2001 From: hauntsaninja Date: Sat, 31 May 2025 17:32:08 -0700 Subject: [PATCH 6/6] test --- test-data/unit/check-inference.test | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test-data/unit/check-inference.test b/test-data/unit/check-inference.test index 49caa1ccb954..4cf24ef9cb6c 100644 --- a/test-data/unit/check-inference.test +++ b/test-data/unit/check-inference.test @@ -3929,7 +3929,9 @@ class A(Generic[T, U]): def join(a: T, b: T) -> T: ... reveal_type(join(intlike(), comparable)) # N: Revealed type is "__main__._SupportsCompare" +reveal_type(join(comparable, intlike())) # N: Revealed type is "__main__._SupportsCompare" reveal_type(A(intlike(), comparable)) # N: Revealed type is "__main__.A[__main__._SupportsCompare, __main__._SupportsCompare]" +reveal_type(A(comparable, intlike())) # N: Revealed type is "__main__.A[__main__._SupportsCompare, __main__._SupportsCompare]" [builtins fixtures/tuple.pyi] [typing fixtures/typing-medium.pyi]