Skip to content

Combine the revealed types of multiple iteration steps in a more robust manner. #19324

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,8 +643,8 @@ def accept_loop(

for error_info in watcher.yield_error_infos():
self.msg.fail(*error_info[:2], code=error_info[2])
for note_info in watcher.yield_note_infos(self.options):
self.note(*note_info)
for note_info, context in watcher.yield_note_infos():
self.msg.reveal_type(note_info, context)

# If exit_condition is set, assume it must be False on exit from the loop:
if exit_condition:
Expand Down Expand Up @@ -3037,7 +3037,7 @@ def is_noop_for_reachability(self, s: Statement) -> bool:
if isinstance(s.expr, EllipsisExpr):
return True
elif isinstance(s.expr, CallExpr):
with self.expr_checker.msg.filter_errors():
with self.expr_checker.msg.filter_errors(filter_revealed_type=True):
typ = get_proper_type(
self.expr_checker.accept(
s.expr, allow_none_return=True, always_allow_any=True
Expand Down Expand Up @@ -4987,8 +4987,8 @@ def visit_try_stmt(self, s: TryStmt) -> None:

for error_info in watcher.yield_error_infos():
self.msg.fail(*error_info[:2], code=error_info[2])
for note_info in watcher.yield_note_infos(self.options):
self.msg.note(*note_info)
for note_info, context in watcher.yield_note_infos():
self.msg.reveal_type(note_info, context)

def visit_try_without_finally(self, s: TryStmt, try_frame: bool) -> None:
"""Type check a try statement, ignoring the finally block.
Expand Down
60 changes: 27 additions & 33 deletions mypy/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from mypy.nodes import Context
from mypy.options import Options
from mypy.scope import Scope
from mypy.typeops import make_simplified_union
from mypy.types import Type
from mypy.util import DEFAULT_SOURCE_OFFSET, is_typeshed_file
from mypy.version import __version__ as mypy_version

Expand Down Expand Up @@ -166,18 +168,24 @@ class ErrorWatcher:
out by one of the ErrorWatcher instances.
"""

# public attribute for the special treatment of `reveal_type` by
# `MessageBuilder.reveal_type`:
filter_revealed_type: bool

def __init__(
self,
errors: Errors,
*,
filter_errors: bool | Callable[[str, ErrorInfo], bool] = False,
save_filtered_errors: bool = False,
filter_deprecated: bool = False,
filter_revealed_type: bool = False,
) -> None:
self.errors = errors
self._has_new_errors = False
self._filter = filter_errors
self._filter_deprecated = filter_deprecated
self.filter_revealed_type = filter_revealed_type
self._filtered: list[ErrorInfo] | None = [] if save_filtered_errors else None

def __enter__(self) -> Self:
Expand Down Expand Up @@ -236,15 +244,15 @@ class IterationDependentErrors:
# the error report occurs but really all unreachable lines.
unreachable_lines: list[set[int]]

# One set of revealed types for each `reveal_type` statement. Each created set can
# grow during the iteration. Meaning of the tuple items: function_or_member, line,
# column, end_line, end_column:
revealed_types: dict[tuple[str | None, int, int, int, int], set[str]]
# One list of revealed types for each `reveal_type` statement. Each created list
# can grow during the iteration. Meaning of the tuple items: line, column,
# end_line, end_column:
revealed_types: dict[tuple[int, int, int | None, int | None], list[Type]]

def __init__(self) -> None:
self.uselessness_errors = []
self.unreachable_lines = []
self.revealed_types = defaultdict(set)
self.revealed_types = defaultdict(list)


class IterationErrorWatcher(ErrorWatcher):
Expand Down Expand Up @@ -287,15 +295,6 @@ def on_error(self, file: str, info: ErrorInfo) -> bool:
iter_errors.unreachable_lines[-1].update(range(info.line, info.end_line + 1))
return True

if info.code == codes.MISC and info.message.startswith("Revealed type is "):
key = info.function_or_member, info.line, info.column, info.end_line, info.end_column
types = info.message.split('"')[1]
if types.startswith("Union["):
iter_errors.revealed_types[key].update(types[6:-1].split(", "))
else:
iter_errors.revealed_types[key].add(types)
return True

return super().on_error(file, info)

def yield_error_infos(self) -> Iterator[tuple[str, Context, ErrorCode]]:
Expand All @@ -318,21 +317,14 @@ def yield_error_infos(self) -> Iterator[tuple[str, Context, ErrorCode]]:
context.end_column = error_info[5]
yield error_info[1], context, error_info[0]

def yield_note_infos(self, options: Options) -> Iterator[tuple[str, Context]]:
def yield_note_infos(self) -> Iterator[tuple[Type, Context]]:
"""Yield all types revealed in at least one iteration step."""

for note_info, types in self.iteration_dependent_errors.revealed_types.items():
sorted_ = sorted(types, key=lambda typ: typ.lower())
if len(types) == 1:
revealed = sorted_[0]
elif options.use_or_syntax():
revealed = " | ".join(sorted_)
else:
revealed = f"Union[{', '.join(sorted_)}]"
context = Context(line=note_info[1], column=note_info[2])
context.end_line = note_info[3]
context.end_column = note_info[4]
yield f'Revealed type is "{revealed}"', context
context = Context(line=note_info[0], column=note_info[1])
context.end_line = note_info[2]
context.end_column = note_info[3]
yield make_simplified_union(types), context


class Errors:
Expand Down Expand Up @@ -596,18 +588,20 @@ def _add_error_info(self, file: str, info: ErrorInfo) -> None:
if info.code in (IMPORT, IMPORT_UNTYPED, IMPORT_NOT_FOUND):
self.seen_import_error = True

@property
def watchers(self) -> Iterator[ErrorWatcher]:
"""Yield the `ErrorWatcher` stack from top to bottom."""
i = len(self._watchers)
while i > 0:
i -= 1
yield self._watchers[i]

def _filter_error(self, file: str, info: ErrorInfo) -> bool:
"""
process ErrorWatcher stack from top to bottom,
stopping early if error needs to be filtered out
"""
i = len(self._watchers)
while i > 0:
i -= 1
w = self._watchers[i]
if w.on_error(file, info):
return True
return False
return any(w.on_error(file, info) for w in self.watchers)

def add_error_info(self, info: ErrorInfo) -> None:
file, lines = info.origin
Expand Down
22 changes: 21 additions & 1 deletion mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from mypy import errorcodes as codes, message_registry
from mypy.erasetype import erase_type
from mypy.errorcodes import ErrorCode
from mypy.errors import ErrorInfo, Errors, ErrorWatcher
from mypy.errors import ErrorInfo, Errors, ErrorWatcher, IterationErrorWatcher
from mypy.nodes import (
ARG_NAMED,
ARG_NAMED_OPT,
Expand Down Expand Up @@ -188,12 +188,14 @@ def filter_errors(
filter_errors: bool | Callable[[str, ErrorInfo], bool] = True,
save_filtered_errors: bool = False,
filter_deprecated: bool = False,
filter_revealed_type: bool = False,
) -> ErrorWatcher:
return ErrorWatcher(
self.errors,
filter_errors=filter_errors,
save_filtered_errors=save_filtered_errors,
filter_deprecated=filter_deprecated,
filter_revealed_type=filter_revealed_type,
)

def add_errors(self, errors: list[ErrorInfo]) -> None:
Expand Down Expand Up @@ -1738,6 +1740,24 @@ def invalid_signature_for_special_method(
)

def reveal_type(self, typ: Type, context: Context) -> None:

# Search for an error watcher that modifies the "normal" behaviour (we do not
# rely on the normal `ErrorWatcher` filtering approach because we might need to
# collect the original types for a later unionised response):
for watcher in self.errors.watchers:
# The `reveal_type` statement should be ignored:
if watcher.filter_revealed_type:
return
# The `reveal_type` statement might be visited iteratively due to being
# placed in a loop or so. Hence, we collect the respective types of
# individual iterations so that we can report them all in one step later:
if isinstance(watcher, IterationErrorWatcher):
watcher.iteration_dependent_errors.revealed_types[
(context.line, context.column, context.end_line, context.end_column)
].append(typ)
return

# Nothing special here; just create the note:
visitor = TypeStrVisitor(options=self.options)
self.note(f'Revealed type is "{typ.accept(visitor)}"', context)

Expand Down
4 changes: 2 additions & 2 deletions test-data/unit/check-inference.test
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ for var2 in [g, h, i, j, k, l]:
reveal_type(var2) # N: Revealed type is "Union[builtins.int, builtins.str]"

for var3 in [m, n, o, p, q, r]:
reveal_type(var3) # N: Revealed type is "Union[Any, builtins.int]"
reveal_type(var3) # N: Revealed type is "Union[builtins.int, Any]"

T = TypeVar("T", bound=Type[Foo])

Expand Down Expand Up @@ -1247,7 +1247,7 @@ class X(TypedDict):

x: X
for a in ("hourly", "daily"):
reveal_type(a) # N: Revealed type is "Union[Literal['daily']?, Literal['hourly']?]"
reveal_type(a) # N: Revealed type is "Union[Literal['hourly']?, Literal['daily']?]"
reveal_type(x[a]) # N: Revealed type is "builtins.int"
reveal_type(a.upper()) # N: Revealed type is "builtins.str"
c = a
Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/check-narrowing.test
Original file line number Diff line number Diff line change
Expand Up @@ -2346,7 +2346,7 @@ def f() -> bool: ...

y = None
while f():
reveal_type(y) # N: Revealed type is "Union[builtins.int, None]"
reveal_type(y) # N: Revealed type is "Union[None, builtins.int]"
y = 1
reveal_type(y) # N: Revealed type is "Union[builtins.int, None]"

Expand Down
4 changes: 2 additions & 2 deletions test-data/unit/check-redefine2.test
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,7 @@ def f1() -> None:
def f2() -> None:
x = None
while int():
reveal_type(x) # N: Revealed type is "Union[builtins.str, None]"
reveal_type(x) # N: Revealed type is "Union[None, builtins.str]"
if int():
x = ""
reveal_type(x) # N: Revealed type is "Union[None, builtins.str]"
Expand Down Expand Up @@ -923,7 +923,7 @@ class X(TypedDict):

x: X
for a in ("hourly", "daily"):
reveal_type(a) # N: Revealed type is "Union[Literal['daily']?, Literal['hourly']?]"
reveal_type(a) # N: Revealed type is "Union[Literal['hourly']?, Literal['daily']?]"
reveal_type(x[a]) # N: Revealed type is "builtins.int"
reveal_type(a.upper()) # N: Revealed type is "builtins.str"
c = a
Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/check-typevar-tuple.test
Original file line number Diff line number Diff line change
Expand Up @@ -989,7 +989,7 @@ from typing_extensions import Unpack

def pipeline(*xs: Unpack[Tuple[int, Unpack[Tuple[float, ...]], bool]]) -> None:
for x in xs:
reveal_type(x) # N: Revealed type is "Union[builtins.float, builtins.int]"
reveal_type(x) # N: Revealed type is "Union[builtins.int, builtins.float]"
[builtins fixtures/tuple.pyi]

[case testFixedUnpackItemInInstanceArguments]
Expand Down
7 changes: 4 additions & 3 deletions test-data/unit/check-union-error-syntax.test
Original file line number Diff line number Diff line change
Expand Up @@ -62,17 +62,18 @@ x = 3 # E: Incompatible types in assignment (expression has type "Literal[3]", v
try:
x = 1
x = ""
x = {1: ""}
finally:
reveal_type(x) # N: Revealed type is "Union[builtins.int, builtins.str]"
reveal_type(x) # N: Revealed type is "Union[builtins.int, builtins.str, builtins.dict[builtins.int, builtins.str]]"
[builtins fixtures/isinstancelist.pyi]

[case testOrSyntaxRecombined]
# flags: --python-version 3.10 --no-force-union-syntax --allow-redefinition-new --local-partial-types
# The following revealed type is recombined because the finally body is visited twice.
# ToDo: Improve this recombination logic, especially (but not only) for the "or syntax".
try:
x = 1
x = ""
x = {1: ""}
finally:
reveal_type(x) # N: Revealed type is "builtins.int | builtins.str | builtins.str"
reveal_type(x) # N: Revealed type is "builtins.int | builtins.str | builtins.dict[builtins.int, builtins.str]"
[builtins fixtures/isinstancelist.pyi]