diff --git a/tests/check_new_syntax.py b/tests/check_new_syntax.py index aa9c8f73ff96..6bc5581a4124 100755 --- a/tests/check_new_syntax.py +++ b/tests/check_new_syntax.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 import ast +import re import sys from itertools import chain from pathlib import Path @@ -61,6 +62,91 @@ def visit_Subscript(self, node: ast.Subscript) -> None: self.generic_visit(node) + def check_instance_method_for_bad_typevars( + *, + method: ast.FunctionDef | ast.AsyncFunctionDef, + first_arg_annotation: ast.Name | ast.Subscript, + return_annotation: ast.Name, + ) -> None: + if not isinstance(first_arg_annotation, ast.Name): + return + + if first_arg_annotation.id != return_annotation.id: + return + + arg1_annotation_name = first_arg_annotation.id + + if not arg1_annotation_name.startswith("_"): + return + + method.decorator_list.clear() + new_syntax = re.sub(fr"(\W){arg1_annotation_name}(\W)", r"\1Self\2", ast.unparse(method)).replace("\n ", "") + errors.append(f"{path}:{method.lineno}: Use `_typeshed.Self` instead of `{arg1_annotation_name}`, e.g. `{new_syntax}`") + + def check_class_method_for_bad_typevars( + *, + method: ast.FunctionDef | ast.AsyncFunctionDef, + first_arg_annotation: ast.Name | ast.Subscript, + return_annotation: ast.Name, + ) -> None: + if not isinstance(first_arg_annotation, ast.Subscript): + return + + if not isinstance(first_arg_annotation.slice, ast.Name): + return + + if not isinstance(first_arg_annotation.value, ast.Name): + return + + if first_arg_annotation.value.id != "type": + return + + cls_typevar = first_arg_annotation.slice.id + + if cls_typevar == return_annotation.id and cls_typevar.startswith("_"): + method.decorator_list.clear() + new_syntax = re.sub(fr"(\W){cls_typevar}(\W)", r"\1Self\2", ast.unparse(method)).replace("\n ", "") + errors.append(f"{path}:{method.lineno}: Use `_typeshed.Self` instead of `{cls_typevar}`, e.g. `{new_syntax}`") + + class BadTypeVarFinder(ast.NodeVisitor): + def visit_FunctionDef(self, node: ast.FunctionDef) -> None: + return_annotation = node.returns + + if not node.args.args: + return + + first_arg_annotation = node.args.args[0].annotation + + if not isinstance(return_annotation, ast.Name): + return + + if not isinstance(first_arg_annotation, (ast.Name, ast.Subscript)): + return + + if node.name == "__new__": + check_class_method_for_bad_typevars( + method=node, first_arg_annotation=first_arg_annotation, return_annotation=return_annotation + ) + return + + for decorator in node.decorator_list: + if isinstance(decorator, ast.Name): + decorator_name = decorator.id + if decorator_name == "classmethod": + check_class_method_for_bad_typevars( + method=node, first_arg_annotation=first_arg_annotation, return_annotation=return_annotation + ) + return + elif decorator_name == "staticmethod": + return + + check_instance_method_for_bad_typevars( + method=node, first_arg_annotation=first_arg_annotation, return_annotation=return_annotation + ) + return + + visit_AsyncFunctionDef = visit_FunctionDef + class OldSyntaxFinder(ast.NodeVisitor): def visit_ImportFrom(self, node: ast.ImportFrom) -> None: if node.module == "collections.abc": @@ -133,6 +219,13 @@ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: UnionFinder().visit(node.returns) self.generic_visit(node) + # dateutil/relativedelta yields mypy errors about overlapping overloads if we use _typeshed.Self + if path != Path("stubs/python-dateutil/dateutil/relativedelta.pyi"): + + def visit_ClassDef(self, node: ast.ClassDef) -> None: + BadTypeVarFinder().visit(node) + self.generic_visit(node) + class ObjectClassdefFinder(ast.NodeVisitor): def visit_ClassDef(self, node: ast.ClassDef) -> None: if any(isinstance(base, ast.Name) and base.id == "object" for base in node.bases):