diff --git a/mypy/build.py b/mypy/build.py index 3ddbe40f1a90..6513414dd298 100644 --- a/mypy/build.py +++ b/mypy/build.py @@ -56,7 +56,7 @@ from mypy.fscache import FileSystemCache from mypy.metastore import MetadataStore, FilesystemMetadataStore, SqliteMetadataStore from mypy.typestate import TypeState, reset_global_state -from mypy.renaming import VariableRenameVisitor +from mypy.renaming import VariableRenameVisitor, LimitedVariableRenameVisitor from mypy.config_parser import parse_mypy_comments from mypy.freetree import free_tree from mypy.stubinfo import legacy_bundled_packages, is_legacy_bundled_package @@ -2119,9 +2119,12 @@ def semantic_analysis_pass1(self) -> None: analyzer.visit_file(self.tree, self.xpath, self.id, options) # TODO: Do this while constructing the AST? self.tree.names = SymbolTable() - if options.allow_redefinition: - # Perform renaming across the AST to allow variable redefinitions - self.tree.accept(VariableRenameVisitor()) + if not self.tree.is_stub: + # Always perform some low-key variable renaming + self.tree.accept(LimitedVariableRenameVisitor()) + if options.allow_redefinition: + # Perform more renaming across the AST to allow variable redefinitions + self.tree.accept(VariableRenameVisitor()) def add_dependency(self, dep: str) -> None: if dep not in self.dependencies_set: diff --git a/mypy/renaming.py b/mypy/renaming.py index c200e94d58e7..aa16d6dc8dd2 100644 --- a/mypy/renaming.py +++ b/mypy/renaming.py @@ -1,11 +1,11 @@ from contextlib import contextmanager -from typing import Dict, Iterator, List +from typing import Dict, Iterator, List, Set from typing_extensions import Final from mypy.nodes import ( Block, AssignmentStmt, NameExpr, MypyFile, FuncDef, Lvalue, ListExpr, TupleExpr, WhileStmt, ForStmt, BreakStmt, ContinueStmt, TryStmt, WithStmt, MatchStmt, StarExpr, - ImportFrom, MemberExpr, IndexExpr, Import, ClassDef + ImportFrom, MemberExpr, IndexExpr, Import, ImportAll, ClassDef ) from mypy.patterns import AsPattern from mypy.traverser import TraverserVisitor @@ -262,15 +262,9 @@ def flush_refs(self) -> None: # as it will be publicly visible outside the module. to_rename = refs[:-1] for i, item in enumerate(to_rename): - self.rename_refs(item, i) + rename_refs(item, i) self.refs.pop() - def rename_refs(self, names: List[NameExpr], index: int) -> None: - name = names[0].name - new_name = name + "'" * (index + 1) - for expr in names: - expr.name = new_name - # Helpers for determining which assignments define new variables def clear(self) -> None: @@ -392,3 +386,162 @@ def record_assignment(self, name: str, can_be_redefined: bool) -> bool: else: # Assigns to an existing variable. return False + + +class LimitedVariableRenameVisitor(TraverserVisitor): + """Perform some limited variable renaming in with statements. + + This allows reusing a variable in multiple with statements with + different types. For example, the two instances of 'x' can have + incompatible types: + + with C() as x: + f(x) + with D() as x: + g(x) + + The above code gets renamed conceptually into this (not valid Python!): + + with C() as x': + f(x') + with D() as x: + g(x) + + If there's a reference to a variable defined in 'with' outside the + statement, or if there's any trickiness around variable visibility + (e.g. function definitions), we give up and won't perform renaming. + + The main use case is to allow binding both readable and writable + binary files into the same variable. These have different types: + + with open(fnam, 'rb') as f: ... + with open(fnam, 'wb') as f: ... + """ + + def __init__(self) -> None: + # Short names of variables bound in with statements using "as" + # in a surrounding scope + self.bound_vars: List[str] = [] + # Names that can't be safely renamed, per scope ('*' means that + # no names can be renamed) + self.skipped: List[Set[str]] = [] + # References to variables that we may need to rename. List of + # scopes; each scope is a mapping from name to list of collections + # of names that refer to the same logical variable. + self.refs: List[Dict[str, List[List[NameExpr]]]] = [] + + def visit_mypy_file(self, file_node: MypyFile) -> None: + """Rename variables within a file. + + This is the main entry point to this class. + """ + with self.enter_scope(): + for d in file_node.defs: + d.accept(self) + + def visit_func_def(self, fdef: FuncDef) -> None: + self.reject_redefinition_of_vars_in_scope() + with self.enter_scope(): + for arg in fdef.arguments: + self.record_skipped(arg.variable.name) + super().visit_func_def(fdef) + + def visit_class_def(self, cdef: ClassDef) -> None: + self.reject_redefinition_of_vars_in_scope() + with self.enter_scope(): + super().visit_class_def(cdef) + + def visit_with_stmt(self, stmt: WithStmt) -> None: + for expr in stmt.expr: + expr.accept(self) + old_len = len(self.bound_vars) + for target in stmt.target: + if target is not None: + self.analyze_lvalue(target) + for target in stmt.target: + if target: + target.accept(self) + stmt.body.accept(self) + + while len(self.bound_vars) > old_len: + self.bound_vars.pop() + + def analyze_lvalue(self, lvalue: Lvalue) -> None: + if isinstance(lvalue, NameExpr): + name = lvalue.name + if name in self.bound_vars: + # Name bound in a surrounding with statement, so it can be renamed + self.visit_name_expr(lvalue) + else: + var_info = self.refs[-1] + if name not in var_info: + var_info[name] = [] + var_info[name].append([]) + self.bound_vars.append(name) + elif isinstance(lvalue, (ListExpr, TupleExpr)): + for item in lvalue.items: + self.analyze_lvalue(item) + elif isinstance(lvalue, MemberExpr): + lvalue.expr.accept(self) + elif isinstance(lvalue, IndexExpr): + lvalue.base.accept(self) + lvalue.index.accept(self) + elif isinstance(lvalue, StarExpr): + self.analyze_lvalue(lvalue.expr) + + def visit_import(self, imp: Import) -> None: + # We don't support renaming imports + for id, as_id in imp.ids: + self.record_skipped(as_id or id) + + def visit_import_from(self, imp: ImportFrom) -> None: + # We don't support renaming imports + for id, as_id in imp.names: + self.record_skipped(as_id or id) + + def visit_import_all(self, imp: ImportAll) -> None: + # Give up, since we don't know all imported names yet + self.reject_redefinition_of_vars_in_scope() + + def visit_name_expr(self, expr: NameExpr) -> None: + name = expr.name + if name in self.bound_vars: + # Record reference so that it can be renamed later + for scope in reversed(self.refs): + if name in scope: + scope[name][-1].append(expr) + else: + self.record_skipped(name) + + @contextmanager + def enter_scope(self) -> Iterator[None]: + self.skipped.append(set()) + self.refs.append({}) + yield None + self.flush_refs() + + def reject_redefinition_of_vars_in_scope(self) -> None: + self.record_skipped('*') + + def record_skipped(self, name: str) -> None: + self.skipped[-1].add(name) + + def flush_refs(self) -> None: + ref_dict = self.refs.pop() + skipped = self.skipped.pop() + if '*' not in skipped: + for name, refs in ref_dict.items(): + if len(refs) <= 1 or name in skipped: + continue + # At module top level we must not rename the final definition, + # as it may be publicly visible + to_rename = refs[:-1] + for i, item in enumerate(to_rename): + rename_refs(item, i) + + +def rename_refs(names: List[NameExpr], index: int) -> None: + name = names[0].name + new_name = name + "'" * (index + 1) + for expr in names: + expr.name = new_name diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 4c0324b9bec7..0d4f46d53924 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -1175,3 +1175,44 @@ def foo(value) -> int: # E: Missing return statement return 1 case 2: return 2 + +[case testWithStatementScopeAndMatchStatement] +from m import A, B + +with A() as x: + pass +with B() as x: \ + # E: Incompatible types in assignment (expression has type "B", variable has type "A") + pass + +with A() as y: + pass +with B() as y: \ + # E: Incompatible types in assignment (expression has type "B", variable has type "A") + pass + +with A() as z: + pass +with B() as z: \ + # E: Incompatible types in assignment (expression has type "B", variable has type "A") + pass + +with A() as zz: + pass +with B() as zz: \ + # E: Incompatible types in assignment (expression has type "B", variable has type "A") + pass + +match x: + case str(y) as z: + zz = y + +[file m.pyi] +from typing import Any + +class A: + def __enter__(self) -> A: ... + def __exit__(self, x, y, z) -> None: ... +class B: + def __enter__(self) -> B: ... + def __exit__(self, x, y, z) -> None: ... diff --git a/test-data/unit/check-statements.test b/test-data/unit/check-statements.test index 62d82f94a6c1..5819ace83603 100644 --- a/test-data/unit/check-statements.test +++ b/test-data/unit/check-statements.test @@ -1555,7 +1555,6 @@ class LiteralReturn: return False [builtins fixtures/bool.pyi] - [case testWithStmtBoolExitReturnInStub] import stub @@ -1572,6 +1571,370 @@ class C3: def __exit__(self, x, y, z) -> Optional[bool]: pass [builtins fixtures/bool.pyi] +[case testWithStmtScopeBasics] +from m import A, B + +def f1() -> None: + with A() as x: + reveal_type(x) # N: Revealed type is "m.A" + with B() as x: + reveal_type(x) # N: Revealed type is "m.B" + +def f2() -> None: + with A() as x: + reveal_type(x) # N: Revealed type is "m.A" + y = x # Use outside with makes the scope function-level + with B() as x: \ + # E: Incompatible types in assignment (expression has type "B", variable has type "A") + reveal_type(x) # N: Revealed type is "m.A" + +[file m.pyi] +class A: + def __enter__(self) -> A: ... + def __exit__(self, x, y, z) -> None: ... +class B: + def __enter__(self) -> B: ... + def __exit__(self, x, y, z) -> None: ... + +[case testWithStmtScopeAndFuncDef] +from m import A, B + +with A() as x: + reveal_type(x) # N: Revealed type is "m.A" + +def f() -> None: + pass # Don't support function definition in the middle + +with B() as x: \ + # E: Incompatible types in assignment (expression has type "B", variable has type "A") + reveal_type(x) # N: Revealed type is "m.A" + +[file m.pyi] +class A: + def __enter__(self) -> A: ... + def __exit__(self, x, y, z) -> None: ... +class B: + def __enter__(self) -> B: ... + def __exit__(self, x, y, z) -> None: ... + +[case testWithStmtScopeAndFuncDef2] +from m import A, B + +def f() -> None: + pass # function before with is unsupported + +with A() as x: + reveal_type(x) # N: Revealed type is "m.A" + +with B() as x: \ + # E: Incompatible types in assignment (expression has type "B", variable has type "A") + reveal_type(x) # N: Revealed type is "m.A" + +[file m.pyi] +class A: + def __enter__(self) -> A: ... + def __exit__(self, x, y, z) -> None: ... +class B: + def __enter__(self) -> B: ... + def __exit__(self, x, y, z) -> None: ... + +[case testWithStmtScopeAndFuncDef3] +from m import A, B + +with A() as x: + reveal_type(x) # N: Revealed type is "m.A" + +with B() as x: \ + # E: Incompatible types in assignment (expression has type "B", variable has type "A") + reveal_type(x) # N: Revealed type is "m.A" + +def f() -> None: + pass # function after with is unsupported + +[file m.pyi] +class A: + def __enter__(self) -> A: ... + def __exit__(self, x, y, z) -> None: ... +class B: + def __enter__(self) -> B: ... + def __exit__(self, x, y, z) -> None: ... + +[case testWithStmtScopeAndFuncDef4] +from m import A, B + +with A() as x: + def f() -> None: + pass # Function within with is unsupported + + reveal_type(x) # N: Revealed type is "m.A" + +with B() as x: \ + # E: Incompatible types in assignment (expression has type "B", variable has type "A") + reveal_type(x) # N: Revealed type is "m.A" + +[file m.pyi] +class A: + def __enter__(self) -> A: ... + def __exit__(self, x, y, z) -> None: ... +class B: + def __enter__(self) -> B: ... + def __exit__(self, x, y, z) -> None: ... + +[case testWithStmtScopeAndImport1] +from m import A, B, x + +with A() as x: \ + # E: Incompatible types in assignment (expression has type "A", variable has type "B") + reveal_type(x) # N: Revealed type is "m.B" + +with B() as x: + reveal_type(x) # N: Revealed type is "m.B" + +[file m.pyi] +x: B + +class A: + def __enter__(self) -> A: ... + def __exit__(self, x, y, z) -> None: ... +class B: + def __enter__(self) -> B: ... + def __exit__(self, x, y, z) -> None: ... + +[case testWithStmtScopeAndImport2] +from m import A, B +import m as x + +with A() as x: \ + # E: Incompatible types in assignment (expression has type "A", variable has type Module) + pass + +with B() as x: \ + # E: Incompatible types in assignment (expression has type "B", variable has type Module) + pass + +[file m.pyi] +class A: + def __enter__(self) -> A: ... + def __exit__(self, x, y, z) -> None: ... +class B: + def __enter__(self) -> B: ... + def __exit__(self, x, y, z) -> None: ... +[builtins fixtures/module.pyi] + +[case testWithStmtScopeAndImportStar] +from m import A, B +from m import * + +with A() as x: + pass + +with B() as x: \ + # E: Incompatible types in assignment (expression has type "B", variable has type "A") + pass + +[file m.pyi] +class A: + def __enter__(self) -> A: ... + def __exit__(self, x, y, z) -> None: ... +class B: + def __enter__(self) -> B: ... + def __exit__(self, x, y, z) -> None: ... + +[case testWithStmtScopeNestedWith1] +from m import A, B + +with A() as x: + with B() as x: \ + # E: Incompatible types in assignment (expression has type "B", variable has type "A") + reveal_type(x) # N: Revealed type is "m.A" + +with B() as x: + with A() as x: \ + # E: Incompatible types in assignment (expression has type "A", variable has type "B") + reveal_type(x) # N: Revealed type is "m.B" + +[file m.pyi] +class A: + def __enter__(self) -> A: ... + def __exit__(self, x, y, z) -> None: ... +class B: + def __enter__(self) -> B: ... + def __exit__(self, x, y, z) -> None: ... + +[case testWithStmtScopeNestedWith2] +from m import A, B + +with A() as x: + with A() as y: + reveal_type(y) # N: Revealed type is "m.A" + with B() as y: + reveal_type(y) # N: Revealed type is "m.B" + +[file m.pyi] +class A: + def __enter__(self) -> A: ... + def __exit__(self, x, y, z) -> None: ... +class B: + def __enter__(self) -> B: ... + def __exit__(self, x, y, z) -> None: ... + +[case testWithStmtScopeInnerAndOuterScopes] +from m import A, B + +x = A() # Outer scope should have no impact + +with A() as x: + pass + +def f() -> None: + with A() as x: + reveal_type(x) # N: Revealed type is "m.A" + with B() as x: + reveal_type(x) # N: Revealed type is "m.B" + +y = x + +with A() as x: + pass + +[file m.pyi] +class A: + def __enter__(self) -> A: ... + def __exit__(self, x, y, z) -> None: ... +class B: + def __enter__(self) -> B: ... + def __exit__(self, x, y, z) -> None: ... + +[case testWithStmtScopeMultipleContextManagers] +from m import A, B + +with A() as x, B() as y: + reveal_type(x) # N: Revealed type is "m.A" + reveal_type(y) # N: Revealed type is "m.B" +with B() as x, A() as y: + reveal_type(x) # N: Revealed type is "m.B" + reveal_type(y) # N: Revealed type is "m.A" + +[file m.pyi] +class A: + def __enter__(self) -> A: ... + def __exit__(self, x, y, z) -> None: ... +class B: + def __enter__(self) -> B: ... + def __exit__(self, x, y, z) -> None: ... + +[case testWithStmtScopeMultipleAssignment] +from m import A, B + +with A() as (x, y): + reveal_type(x) # N: Revealed type is "m.A" + reveal_type(y) # N: Revealed type is "builtins.int" +with B() as [x, y]: + reveal_type(x) # N: Revealed type is "m.B" + reveal_type(y) # N: Revealed type is "builtins.str" + +[file m.pyi] +from typing import Tuple + +class A: + def __enter__(self) -> Tuple[A, int]: ... + def __exit__(self, x, y, z) -> None: ... +class B: + def __enter__(self) -> Tuple[B, str]: ... + def __exit__(self, x, y, z) -> None: ... +[builtins fixtures/tuple.pyi] + +[case testWithStmtScopeComplexAssignments] +from m import A, B, f + +with A() as x: + pass +with B() as x: \ + # E: Incompatible types in assignment (expression has type "B", variable has type "A") + pass +with B() as f(x).x: + pass + +with A() as y: + pass +with B() as y: \ + # E: Incompatible types in assignment (expression has type "B", variable has type "A") + pass +with B() as f(y)[0]: + pass + +[file m.pyi] +def f(x): ... + +class A: + def __enter__(self) -> A: ... + def __exit__(self, x, y, z) -> None: ... +class B: + def __enter__(self) -> B: ... + def __exit__(self, x, y, z) -> None: ... + +[case testWithStmtScopeAndClass] +from m import A, B + +with A() as x: + pass + +class C: + with A() as y: + pass + with B() as y: + pass + +with B() as x: \ + # E: Incompatible types in assignment (expression has type "B", variable has type "A") + pass + +[file m.pyi] +class A: + def __enter__(self) -> A: ... + def __exit__(self, x, y, z) -> None: ... +class B: + def __enter__(self) -> B: ... + def __exit__(self, x, y, z) -> None: ... + +[case testWithStmtScopeInnerScopeReference] +from m import A, B + +with A() as x: + def f() -> A: + return x + f() + +with B() as x: \ + # E: Incompatible types in assignment (expression has type "B", variable has type "A") + pass +[file m.pyi] +class A: + def __enter__(self) -> A: ... + def __exit__(self, x, y, z) -> None: ... +class B: + def __enter__(self) -> B: ... + def __exit__(self, x, y, z) -> None: ... + +[case testWithStmtScopeAndLambda] +from m import A, B + +# This is technically not correct, since the lambda can outlive the with +# statement, but this behavior seems more intuitive. + +with A() as x: + lambda: reveal_type(x) # N: Revealed type is "m.A" + +with B() as x: + pass +[file m.pyi] +class A: + def __enter__(self) -> A: ... + def __exit__(self, x, y, z) -> None: ... +class B: + def __enter__(self) -> B: ... + def __exit__(self, x, y, z) -> None: ... + -- Chained assignment -- ------------------ diff --git a/test-data/unit/semanal-statements.test b/test-data/unit/semanal-statements.test index 97dd5f048834..fdc5ca2bbbdd 100644 --- a/test-data/unit/semanal-statements.test +++ b/test-data/unit/semanal-statements.test @@ -1058,3 +1058,59 @@ MypyFile:1( AssignmentStmt:6( NameExpr(x [__main__.x]) StrExpr())) + +[case testSimpleWithRenaming] +with 0 as y: + z = y +with 1 as y: + y = 1 +[out] +MypyFile:1( + WithStmt:1( + Expr( + IntExpr(0)) + Target( + NameExpr(y'* [__main__.y'])) + Block:1( + AssignmentStmt:2( + NameExpr(z* [__main__.z]) + NameExpr(y' [__main__.y'])))) + WithStmt:3( + Expr( + IntExpr(1)) + Target( + NameExpr(y* [__main__.y])) + Block:3( + AssignmentStmt:4( + NameExpr(y [__main__.y]) + IntExpr(1))))) + +[case testSimpleWithRenamingFailure] +with 0 as y: + z = y +zz = y +with 1 as y: + y = 1 +[out] +MypyFile:1( + WithStmt:1( + Expr( + IntExpr(0)) + Target( + NameExpr(y* [__main__.y])) + Block:1( + AssignmentStmt:2( + NameExpr(z* [__main__.z]) + NameExpr(y [__main__.y])))) + AssignmentStmt:3( + NameExpr(zz* [__main__.zz]) + NameExpr(y [__main__.y])) + WithStmt:4( + Expr( + IntExpr(1)) + Target( + NameExpr(y [__main__.y])) + Block:4( + AssignmentStmt:5( + NameExpr(y [__main__.y]) + IntExpr(1)))))