Skip to content

Commit fb74025

Browse files
authoredOct 9, 2024··
Preserve source positions for assertion rewriting
Closes #12818
1 parent 294e113 commit fb74025

File tree

4 files changed

+223
-11
lines changed

4 files changed

+223
-11
lines changed
 

‎AUTHORS

+1
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ Feng Ma
160160
Florian Bruhin
161161
Florian Dahlitz
162162
Floris Bruynooghe
163+
Frank Hoffmann
163164
Fraser Stark
164165
Gabriel Landau
165166
Gabriel Reis

‎changelog/12818.bugfix.rst

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Assertion rewriting now preserves the source ranges of the original instructions, making it play well with tools that deal with the ``AST``, like `executing <https://github.com/alexmojaki/executing>`__.

‎src/_pytest/assertion/rewrite.py

+14-7
Original file line numberDiff line numberDiff line change
@@ -792,7 +792,7 @@ def assign(self, expr: ast.expr) -> ast.Name:
792792
"""Give *expr* a name."""
793793
name = self.variable()
794794
self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr))
795-
return ast.Name(name, ast.Load())
795+
return ast.copy_location(ast.Name(name, ast.Load()), expr)
796796

797797
def display(self, expr: ast.expr) -> ast.expr:
798798
"""Call saferepr on the expression."""
@@ -975,7 +975,10 @@ def visit_Assert(self, assert_: ast.Assert) -> list[ast.stmt]:
975975
# Fix locations (line numbers/column offsets).
976976
for stmt in self.statements:
977977
for node in traverse_node(stmt):
978-
ast.copy_location(node, assert_)
978+
if getattr(node, "lineno", None) is None:
979+
# apply the assertion location to all generated ast nodes without source location
980+
# and preserve the location of existing nodes or generated nodes with an correct location.
981+
ast.copy_location(node, assert_)
979982
return self.statements
980983

981984
def visit_NamedExpr(self, name: ast.NamedExpr) -> tuple[ast.NamedExpr, str]:
@@ -1052,15 +1055,17 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> tuple[ast.Name, str]:
10521055
def visit_UnaryOp(self, unary: ast.UnaryOp) -> tuple[ast.Name, str]:
10531056
pattern = UNARY_MAP[unary.op.__class__]
10541057
operand_res, operand_expl = self.visit(unary.operand)
1055-
res = self.assign(ast.UnaryOp(unary.op, operand_res))
1058+
res = self.assign(ast.copy_location(ast.UnaryOp(unary.op, operand_res), unary))
10561059
return res, pattern % (operand_expl,)
10571060

10581061
def visit_BinOp(self, binop: ast.BinOp) -> tuple[ast.Name, str]:
10591062
symbol = BINOP_MAP[binop.op.__class__]
10601063
left_expr, left_expl = self.visit(binop.left)
10611064
right_expr, right_expl = self.visit(binop.right)
10621065
explanation = f"({left_expl} {symbol} {right_expl})"
1063-
res = self.assign(ast.BinOp(left_expr, binop.op, right_expr))
1066+
res = self.assign(
1067+
ast.copy_location(ast.BinOp(left_expr, binop.op, right_expr), binop)
1068+
)
10641069
return res, explanation
10651070

10661071
def visit_Call(self, call: ast.Call) -> tuple[ast.Name, str]:
@@ -1089,7 +1094,7 @@ def visit_Call(self, call: ast.Call) -> tuple[ast.Name, str]:
10891094
arg_expls.append("**" + expl)
10901095

10911096
expl = "{}({})".format(func_expl, ", ".join(arg_expls))
1092-
new_call = ast.Call(new_func, new_args, new_kwargs)
1097+
new_call = ast.copy_location(ast.Call(new_func, new_args, new_kwargs), call)
10931098
res = self.assign(new_call)
10941099
res_expl = self.explanation_param(self.display(res))
10951100
outer_expl = f"{res_expl}\n{{{res_expl} = {expl}\n}}"
@@ -1105,7 +1110,9 @@ def visit_Attribute(self, attr: ast.Attribute) -> tuple[ast.Name, str]:
11051110
if not isinstance(attr.ctx, ast.Load):
11061111
return self.generic_visit(attr)
11071112
value, value_expl = self.visit(attr.value)
1108-
res = self.assign(ast.Attribute(value, attr.attr, ast.Load()))
1113+
res = self.assign(
1114+
ast.copy_location(ast.Attribute(value, attr.attr, ast.Load()), attr)
1115+
)
11091116
res_expl = self.explanation_param(self.display(res))
11101117
pat = "%s\n{%s = %s.%s\n}"
11111118
expl = pat % (res_expl, res_expl, value_expl, attr.attr)
@@ -1146,7 +1153,7 @@ def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]:
11461153
syms.append(ast.Constant(sym))
11471154
expl = f"{left_expl} {sym} {next_expl}"
11481155
expls.append(ast.Constant(expl))
1149-
res_expr = ast.Compare(left_res, [op], [next_res])
1156+
res_expr = ast.copy_location(ast.Compare(left_res, [op], [next_res]), comp)
11501157
self.statements.append(ast.Assign([store_names[i]], res_expr))
11511158
left_res, left_expl = next_res, next_expl
11521159
# Use pytest.assertion.util._reprcompare if that's available.

‎testing/test_assertrewrite.py

+207-4
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
from __future__ import annotations
33

44
import ast
5+
import dis
56
import errno
67
from functools import partial
78
import glob
89
import importlib
10+
import inspect
911
import marshal
1012
import os
1113
from pathlib import Path
@@ -131,10 +133,211 @@ def test_location_is_set(self) -> None:
131133
continue
132134
for n in [node, *ast.iter_child_nodes(node)]:
133135
assert isinstance(n, (ast.stmt, ast.expr))
134-
assert n.lineno == 3
135-
assert n.col_offset == 0
136-
assert n.end_lineno == 6
137-
assert n.end_col_offset == 3
136+
for location in [
137+
(n.lineno, n.col_offset),
138+
(n.end_lineno, n.end_col_offset),
139+
]:
140+
assert (3, 0) <= location <= (6, 3)
141+
142+
def test_positions_are_preserved(self) -> None:
143+
"""Ensure AST positions are preserved during rewriting (#12818)."""
144+
145+
def preserved(code: str) -> None:
146+
s = textwrap.dedent(code)
147+
locations = []
148+
149+
def loc(msg: str | None = None) -> None:
150+
frame = inspect.currentframe()
151+
assert frame
152+
frame = frame.f_back
153+
assert frame
154+
frame = frame.f_back
155+
assert frame
156+
157+
offset = frame.f_lasti
158+
159+
instructions = {i.offset: i for i in dis.get_instructions(frame.f_code)}
160+
161+
# skip CACHE instructions
162+
while offset not in instructions and offset >= 0:
163+
offset -= 1
164+
165+
instruction = instructions[offset]
166+
if sys.version_info >= (3, 11):
167+
position = instruction.positions
168+
else:
169+
position = instruction.starts_line
170+
171+
locations.append((msg, instruction.opname, position))
172+
173+
globals = {"loc": loc}
174+
175+
m = rewrite(s)
176+
mod = compile(m, "<string>", "exec")
177+
exec(mod, globals, globals)
178+
transformed_locations = locations
179+
locations = []
180+
181+
mod = compile(s, "<string>", "exec")
182+
exec(mod, globals, globals)
183+
original_locations = locations
184+
185+
assert len(original_locations) > 0
186+
assert original_locations == transformed_locations
187+
188+
preserved("""
189+
def f():
190+
loc()
191+
return 8
192+
193+
assert f() in [8]
194+
assert (f()
195+
in
196+
[8])
197+
""")
198+
199+
preserved("""
200+
class T:
201+
def __init__(self):
202+
loc("init")
203+
def __getitem__(self,index):
204+
loc("getitem")
205+
return index
206+
207+
assert T()[5] == 5
208+
assert (T
209+
()
210+
[5]
211+
==
212+
5)
213+
""")
214+
215+
for name, op in [
216+
("pos", "+"),
217+
("neg", "-"),
218+
("invert", "~"),
219+
]:
220+
preserved(f"""
221+
class T:
222+
def __{name}__(self):
223+
loc("{name}")
224+
return "{name}"
225+
226+
assert {op}T() == "{name}"
227+
assert ({op}
228+
T
229+
()
230+
==
231+
"{name}")
232+
""")
233+
234+
for name, op in [
235+
("add", "+"),
236+
("sub", "-"),
237+
("mul", "*"),
238+
("truediv", "/"),
239+
("floordiv", "//"),
240+
("mod", "%"),
241+
("pow", "**"),
242+
("lshift", "<<"),
243+
("rshift", ">>"),
244+
("or", "|"),
245+
("xor", "^"),
246+
("and", "&"),
247+
("matmul", "@"),
248+
]:
249+
preserved(f"""
250+
class T:
251+
def __{name}__(self,other):
252+
loc("{name}")
253+
return other
254+
255+
def __r{name}__(self,other):
256+
loc("r{name}")
257+
return other
258+
259+
assert T() {op} 2 == 2
260+
assert 2 {op} T() == 2
261+
262+
assert (T
263+
()
264+
{op}
265+
2
266+
==
267+
2)
268+
269+
assert (2
270+
{op}
271+
T
272+
()
273+
==
274+
2)
275+
""")
276+
277+
for name, op in [
278+
("eq", "=="),
279+
("ne", "!="),
280+
("lt", "<"),
281+
("le", "<="),
282+
("gt", ">"),
283+
("ge", ">="),
284+
]:
285+
preserved(f"""
286+
class T:
287+
def __{name}__(self,other):
288+
loc()
289+
return True
290+
291+
assert T() {op} 5
292+
assert (T
293+
()
294+
{op}
295+
5)
296+
""")
297+
298+
for name, op in [
299+
("eq", "=="),
300+
("ne", "!="),
301+
("lt", ">"),
302+
("le", ">="),
303+
("gt", "<"),
304+
("ge", "<="),
305+
("contains", "in"),
306+
]:
307+
preserved(f"""
308+
class T:
309+
def __{name}__(self,other):
310+
loc()
311+
return True
312+
313+
assert 5 {op} T()
314+
assert (5
315+
{op}
316+
T
317+
())
318+
""")
319+
320+
preserved("""
321+
def func(value):
322+
loc("func")
323+
return value
324+
325+
class T:
326+
def __iter__(self):
327+
loc("iter")
328+
return iter([5])
329+
330+
assert func(*T()) == 5
331+
""")
332+
333+
preserved("""
334+
class T:
335+
def __getattr__(self,name):
336+
loc()
337+
return name
338+
339+
assert T().attr == "attr"
340+
""")
138341

139342
def test_dont_rewrite(self) -> None:
140343
s = """'PYTEST_DONT_REWRITE'\nassert 14"""

0 commit comments

Comments
 (0)
Please sign in to comment.