From cfe6da970ad17f1c13f68624c79c1e19653616df Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sun, 2 Feb 2025 19:03:31 +0000 Subject: [PATCH 1/3] Fix inference when class and instance match protocol --- mypy/constraints.py | 48 +++++++++++++++++++------------- test-data/unit/check-enum.test | 22 +++++++++++++++ test-data/unit/fixtures/enum.pyi | 9 +++++- 3 files changed, 58 insertions(+), 21 deletions(-) diff --git a/mypy/constraints.py b/mypy/constraints.py index 45a96b993563..e7d93c091bf9 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -756,40 +756,40 @@ def visit_instance(self, template: Instance) -> list[Constraint]: "__call__", template, actual, is_operator=True ) assert call is not None - if mypy.subtypes.is_subtype(actual, erase_typevars(call)): - subres = infer_constraints(call, actual, self.direction) - res.extend(subres) + if ( + self.direction == SUPERTYPE_OF + and mypy.subtypes.is_subtype(actual, erase_typevars(call)) + or self.direction == SUBTYPE_OF + and mypy.subtypes.is_subtype(erase_typevars(call), actual) + ): + res.extend(infer_constraints(call, actual, self.direction)) template.type.inferring.pop() if isinstance(actual, CallableType) and actual.fallback is not None: - if actual.is_type_obj() and template.type.is_protocol: + if ( + actual.is_type_obj() + and template.type.is_protocol + and self.direction == SUPERTYPE_OF + ): ret_type = get_proper_type(actual.ret_type) if isinstance(ret_type, TupleType): ret_type = mypy.typeops.tuple_fallback(ret_type) if isinstance(ret_type, Instance): - if self.direction == SUBTYPE_OF: - subtype = template - else: - subtype = ret_type res.extend( self.infer_constraints_from_protocol_members( - ret_type, template, subtype, template, class_obj=True + ret_type, template, ret_type, template, class_obj=True ) ) actual = actual.fallback if isinstance(actual, TypeType) and template.type.is_protocol: - if isinstance(actual.item, Instance): - if self.direction == SUBTYPE_OF: - subtype = template - else: - subtype = actual.item - res.extend( - self.infer_constraints_from_protocol_members( - actual.item, template, subtype, template, class_obj=True - ) - ) if self.direction == SUPERTYPE_OF: - # Infer constraints for Type[T] via metaclass of T when it makes sense. a_item = actual.item + if isinstance(a_item, Instance): + res.extend( + self.infer_constraints_from_protocol_members( + a_item, template, a_item, template, class_obj=True + ) + ) + # Infer constraints for Type[T] via metaclass of T when it makes sense. if isinstance(a_item, TypeVarType): a_item = get_proper_type(a_item.upper_bound) if isinstance(a_item, Instance) and a_item.type.metaclass_type: @@ -1043,6 +1043,14 @@ def infer_constraints_from_protocol_members( return [] # See #11020 # The above is safe since at this point we know that 'instance' is a subtype # of (erased) 'template', therefore it defines all protocol members + if class_obj: + # For class objects we must only infer constraints if possible, otherwise it + # can lead to confusion between class and instance, for example StrEnum is + # Iterable[str] for an instance, but Iterable[StrEnum] for a class object. + if not mypy.subtypes.is_subtype( + inst, erase_typevars(temp), ignore_pos_arg_names=True + ): + continue res.extend(infer_constraints(temp, inst, self.direction)) if mypy.subtypes.IS_SETTABLE in mypy.subtypes.get_member_flags(member, protocol): # Settable members are invariant, add opposite constraints diff --git a/test-data/unit/check-enum.test b/test-data/unit/check-enum.test index 37c63f43179d..4b7460696aec 100644 --- a/test-data/unit/check-enum.test +++ b/test-data/unit/check-enum.test @@ -2394,3 +2394,25 @@ def do_check(value: E) -> None: [builtins fixtures/primitives.pyi] [typing fixtures/typing-full.pyi] + +[case testStrEnumClassCorrectIterable] +from enum import StrEnum +from typing import Type, TypeVar + +class Choices(StrEnum): + LOREM = "lorem" + IPSUM = "ipsum" + +var = list(Choices) +reveal_type(var) # N: Revealed type is "builtins.list[__main__.Choices]" + +e: type[StrEnum] +reveal_type(list(e)) # N: Revealed type is "builtins.list[enum.StrEnum]" + +T = TypeVar("T", bound=StrEnum) +def list_vals(e: Type[T]) -> list[T]: + reveal_type(list(e)) # N: Revealed type is "builtins.list[T`-1]" + return list(e) + +reveal_type(list_vals(Choices)) # N: Revealed type is "builtins.list[__main__.Choices]" +[builtins fixtures/enum.pyi] diff --git a/test-data/unit/fixtures/enum.pyi b/test-data/unit/fixtures/enum.pyi index 135e9cd16e7c..22e7193da041 100644 --- a/test-data/unit/fixtures/enum.pyi +++ b/test-data/unit/fixtures/enum.pyi @@ -1,5 +1,5 @@ # Minimal set of builtins required to work with Enums -from typing import TypeVar, Generic +from typing import TypeVar, Generic, Iterator, Sequence, overload, Iterable T = TypeVar('T') @@ -13,6 +13,13 @@ class tuple(Generic[T]): class int: pass class str: def __len__(self) -> int: pass + def __iter__(self) -> Iterator[str]: pass class dict: pass class ellipsis: pass + +class list(Sequence[T]): + @overload + def __init__(self) -> None: pass + @overload + def __init__(self, x: Iterable[T]) -> None: pass From f6caf006407857bc179e71fab7b6287bd1251bc2 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Mon, 3 Feb 2025 10:18:10 +0000 Subject: [PATCH 2/3] Minimal fix for exposed bugs --- mypy/constraints.py | 3 +++ mypy/join.py | 12 ++++++++- mypy/subtypes.py | 4 +-- test-data/unit/check-functions.test | 11 +++++--- .../unit/check-parameter-specification.test | 27 +++++++++++++++++++ 5 files changed, 50 insertions(+), 7 deletions(-) diff --git a/mypy/constraints.py b/mypy/constraints.py index e7d93c091bf9..defcac21bc66 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -1051,6 +1051,9 @@ def infer_constraints_from_protocol_members( inst, erase_typevars(temp), ignore_pos_arg_names=True ): continue + # This exception matches the one in subtypes.py, see PR #14121 for context. + if member == "__call__" and instance.type.is_metaclass(): + continue res.extend(infer_constraints(temp, inst, self.direction)) if mypy.subtypes.IS_SETTABLE in mypy.subtypes.get_member_flags(member, protocol): # Settable members are invariant, add opposite constraints diff --git a/mypy/join.py b/mypy/join.py index 166434f58f8d..9fa6e27207f4 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -355,7 +355,8 @@ def visit_unpack_type(self, t: UnpackType) -> UnpackType: def visit_parameters(self, t: Parameters) -> ProperType: if isinstance(self.s, Parameters): - if len(t.arg_types) != len(self.s.arg_types): + if not is_similar_params(t, self.s): + # TODO: it would be prudent to return [*object, **object] instead of Any. return self.default(self.s) from mypy.meet import meet_types @@ -724,6 +725,15 @@ def is_similar_callables(t: CallableType, s: CallableType) -> bool: ) +def is_similar_params(t: Parameters, s: Parameters) -> bool: + # This matches the logic in is_similar_callables() above. + return ( + len(t.arg_types) == len(s.arg_types) + and t.min_args == s.min_args + and (t.var_arg() is not None) == (s.var_arg() is not None) + ) + + def update_callable_ids(c: CallableType, ids: list[TypeVarId]) -> CallableType: tv_map = {} tvs = [] diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 804930fc9d0c..2d48c957308e 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -1796,12 +1796,12 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N ) # If the left hand argument corresponds to two right-hand arguments, - # neither of them can be required. + # both of them can't be required. if ( right_by_name is not None and right_by_pos is not None and right_by_name != right_by_pos - and (right_by_pos.required or right_by_name.required) + and (right_by_pos.required and right_by_name.required) and strict_concatenate_check and not right.imprecise_arg_kinds ): diff --git a/test-data/unit/check-functions.test b/test-data/unit/check-functions.test index 58973307a1ae..d005b200a919 100644 --- a/test-data/unit/check-functions.test +++ b/test-data/unit/check-functions.test @@ -105,16 +105,19 @@ if int(): h = h [case testSubtypingFunctionsDoubleCorrespondence] +def l(x) -> None: ... +def r(__a, *, x) -> None: ... +r = l # E: Incompatible types in assignment (expression has type "Callable[[Any], None]", variable has type "Callable[[Any, NamedArg(Any, 'x')], None]") +[case testSubtypingFunctionsDoubleCorrespondenceNamedOptional] def l(x) -> None: ... -def r(__, *, x) -> None: ... -r = l # E: Incompatible types in assignment (expression has type "Callable[[Any], None]", variable has type "Callable[[Any, NamedArg(Any, 'x')], None]") +def r(__a, *, x = 1) -> None: ... +r = l # E: Incompatible types in assignment (expression has type "Callable[[Any], None]", variable has type "Callable[[Any, DefaultNamedArg(Any, 'x')], None]") [case testSubtypingFunctionsRequiredLeftArgNotPresent] - def l(x, y) -> None: ... def r(x) -> None: ... -r = l # E: Incompatible types in assignment (expression has type "Callable[[Any, Any], None]", variable has type "Callable[[Any], None]") +r = l # E: Incompatible types in assignment (expression has type "Callable[[Any, Any], None]", variable has type "Callable[[Any], None]") [case testSubtypingFunctionsImplicitNames] from typing import Any diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index 352503023f97..f938226f8472 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -2532,3 +2532,30 @@ class GenericWrapper(Generic[P]): def contains(c: Callable[P, None], *args: P.args, **kwargs: P.kwargs) -> None: ... def inherits(*args: P.args, **kwargs: P.kwargs) -> None: ... [builtins fixtures/paramspec.pyi] + +[case testCallbackProtocolClassObjectParamSpec] +from typing import Any, Callable, Protocol, Optional, Generic +from typing_extensions import ParamSpec + +P = ParamSpec("P") + +class App: ... + +class MiddlewareFactory(Protocol[P]): + def __call__(self, app: App, /, *args: P.args, **kwargs: P.kwargs) -> App: + ... + +class Capture(Generic[P]): ... + +class ServerErrorMiddleware(App): + def __init__( + self, + app: App, + handler: Optional[str] = None, + debug: bool = False, + ) -> None: ... + +def fn(f: MiddlewareFactory[P]) -> Capture[P]: ... + +reveal_type(fn(ServerErrorMiddleware)) # N: Revealed type is "__main__.Capture[[handler: Union[builtins.str, None] =, debug: builtins.bool =]]" +[builtins fixtures/paramspec.pyi] From 72661cbc7ada05e60fa58807dd7a9428bb0e2c43 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Mon, 3 Feb 2025 12:26:58 +0000 Subject: [PATCH 3/3] A more principled fix for the subtyping issue --- mypy/subtypes.py | 13 +++++++++---- test-data/unit/check-functions.test | 23 +++++++++++++++++++++-- 2 files changed, 30 insertions(+), 6 deletions(-) diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 2d48c957308e..75cc7e25fde3 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -1719,11 +1719,16 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N ): return False + if trivial_suffix: + # For trivial right suffix we *only* check that every non-star right argument + # has a valid match on the left. + return True + # Phase 1c: Check var args. Right has an infinite series of optional positional # arguments. Get all further positional args of left, and make sure # they're more general than the corresponding member in right. # TODO: are we handling UnpackType correctly here? - if right_star is not None and not trivial_suffix: + if right_star is not None: # Synthesize an anonymous formal argument for the right right_by_position = right.try_synthesizing_arg_from_vararg(None) assert right_by_position is not None @@ -1750,7 +1755,7 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N # Phase 1d: Check kw args. Right has an infinite series of optional named # arguments. Get all further named args of left, and make sure # they're more general than the corresponding member in right. - if right_star2 is not None and not trivial_suffix: + if right_star2 is not None: right_names = {name for name in right.arg_names if name is not None} left_only_names = set() for name, kind in zip(left.arg_names, left.arg_kinds): @@ -1796,12 +1801,12 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N ) # If the left hand argument corresponds to two right-hand arguments, - # both of them can't be required. + # neither of them can be required. if ( right_by_name is not None and right_by_pos is not None and right_by_name != right_by_pos - and (right_by_pos.required and right_by_name.required) + and (right_by_pos.required or right_by_name.required) and strict_concatenate_check and not right.imprecise_arg_kinds ): diff --git a/test-data/unit/check-functions.test b/test-data/unit/check-functions.test index d005b200a919..ccce2cb96a88 100644 --- a/test-data/unit/check-functions.test +++ b/test-data/unit/check-functions.test @@ -106,14 +106,33 @@ if int(): [case testSubtypingFunctionsDoubleCorrespondence] def l(x) -> None: ... -def r(__a, *, x) -> None: ... +def r(__x, *, x) -> None: ... r = l # E: Incompatible types in assignment (expression has type "Callable[[Any], None]", variable has type "Callable[[Any, NamedArg(Any, 'x')], None]") [case testSubtypingFunctionsDoubleCorrespondenceNamedOptional] def l(x) -> None: ... -def r(__a, *, x = 1) -> None: ... +def r(__x, *, x = 1) -> None: ... r = l # E: Incompatible types in assignment (expression has type "Callable[[Any], None]", variable has type "Callable[[Any, DefaultNamedArg(Any, 'x')], None]") +[case testSubtypingFunctionsDoubleCorrespondenceBothNamedOptional] +def l(x = 1) -> None: ... +def r(__x, *, x = 1) -> None: ... +r = l # E: Incompatible types in assignment (expression has type "Callable[[Any], None]", variable has type "Callable[[Any, DefaultNamedArg(Any, 'x')], None]") + +[case testSubtypingFunctionsTrivialSuffixRequired] +def l(__x) -> None: ... +def r(x, *args, **kwargs) -> None: ... + +r = l # E: Incompatible types in assignment (expression has type "Callable[[Any], None]", variable has type "Callable[[Arg(Any, 'x'), VarArg(Any), KwArg(Any)], None]") +[builtins fixtures/dict.pyi] + +[case testSubtypingFunctionsTrivialSuffixOptional] +def l(__x = 1) -> None: ... +def r(x = 1, *args, **kwargs) -> None: ... + +r = l # E: Incompatible types in assignment (expression has type "Callable[[DefaultArg(Any)], None]", variable has type "Callable[[DefaultArg(Any, 'x'), VarArg(Any), KwArg(Any)], None]") +[builtins fixtures/dict.pyi] + [case testSubtypingFunctionsRequiredLeftArgNotPresent] def l(x, y) -> None: ... def r(x) -> None: ...