Skip to content

Commit 6c4278c

Browse files
committed
💅 Propagate error causes via asyncio protocols
This is supposed to unify setting exceptions on the future objects, allowing to also attach their causes whenever available. It's also supposed to help with tracking down what's happening with #4581.
1 parent 4b2eebd commit 6c4278c

11 files changed

+184
-65
lines changed

aiohttp/_http_parser.pyx

+7-5
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ from multidict import CIMultiDict as _CIMultiDict, CIMultiDictProxy as _CIMultiD
1919
from yarl import URL as _URL
2020

2121
from aiohttp import hdrs
22-
from aiohttp.helpers import DEBUG
22+
from aiohttp.helpers import DEBUG, set_exception
2323

2424
from .http_exceptions import (
2525
BadHttpMessage,
@@ -763,11 +763,13 @@ cdef int cb_on_body(cparser.llhttp_t* parser,
763763
cdef bytes body = at[:length]
764764
try:
765765
pyparser._payload.feed_data(body, length)
766-
except BaseException as exc:
766+
except BaseException as underlying_exc:
767+
reraised_exc = underlying_exc
767768
if pyparser._payload_exception is not None:
768-
pyparser._payload.set_exception(pyparser._payload_exception(str(exc)))
769-
else:
770-
pyparser._payload.set_exception(exc)
769+
reraised_exc = pyparser._payload_exception(str(underlying_exc))
770+
771+
set_exception(pyparser._payload, reraised_exc, underlying_exc)
772+
771773
pyparser._payload_error = 1
772774
return -1
773775
else:

aiohttp/base_protocol.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
from typing import Optional, cast
33

4+
from .helpers import set_exception
45
from .tcp_helpers import tcp_nodelay
56

67

@@ -76,7 +77,11 @@ def connection_lost(self, exc: Optional[BaseException]) -> None:
7677
if exc is None:
7778
waiter.set_result(None)
7879
else:
79-
waiter.set_exception(exc)
80+
set_exception(
81+
waiter,
82+
ConnectionError("Connection lost"),
83+
exc,
84+
)
8085

8186
async def _drain_helper(self) -> None:
8287
if not self.connected:

aiohttp/client_proto.py

+60-19
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@
44

55
from .base_protocol import BaseProtocol
66
from .client_exceptions import (
7+
ClientConnectionError,
78
ClientOSError,
89
ClientPayloadError,
910
ServerDisconnectedError,
1011
SocketTimeoutError,
1112
)
1213
from .helpers import (
1314
BaseTimerContext,
15+
ErrorableMixin,
1416
set_exception,
1517
set_result,
1618
status_code_must_be_empty_body,
@@ -80,41 +82,71 @@ def is_connected(self) -> bool:
8082
def connection_lost(self, exc: Optional[BaseException]) -> None:
8183
self._drop_timeout()
8284

83-
if exc is not None:
84-
set_exception(self.closed, exc)
85-
else:
85+
original_connection_error = exc
86+
reraised_exc = original_connection_error
87+
88+
connection_closed_cleanly = original_connection_error is None
89+
"""Whether the connection got a clean EOF."""
90+
91+
if connection_closed_cleanly:
8692
set_result(self.closed, None)
93+
else:
94+
set_exception(
95+
self.closed,
96+
ClientConnectionError(
97+
f"Connection lost: {original_connection_error !s}",
98+
),
99+
original_connection_error,
100+
)
87101

88102
if self._payload_parser is not None:
89-
with suppress(Exception):
103+
with suppress(Exception): # FIXME: log this somehow?
90104
self._payload_parser.feed_eof()
91105

92106
uncompleted = None
93107
if self._parser is not None:
94108
try:
95109
uncompleted = self._parser.feed_eof()
96-
except Exception as e:
110+
except Exception as underlying_exc:
97111
if self._payload is not None:
98-
exc = ClientPayloadError("Response payload is not completed")
99-
exc.__cause__ = e
100-
self._payload.set_exception(exc)
112+
client_payload_exc_msg = (
113+
f"Response payload is not completed: {underlying_exc !r}"
114+
)
115+
if not connection_closed_cleanly:
116+
client_payload_exc_msg = (
117+
f"{client_payload_exc_msg !s}. "
118+
f"{original_connection_error !r}"
119+
)
120+
set_exception(
121+
self._payload,
122+
ClientPayloadError(client_payload_exc_msg),
123+
underlying_exc,
124+
)
101125

102126
if not self.is_eof():
103-
if isinstance(exc, OSError):
104-
exc = ClientOSError(*exc.args)
105-
if exc is None:
106-
exc = ServerDisconnectedError(uncompleted)
127+
if isinstance(original_connection_error, OSError):
128+
reraised_exc = ClientOSError(*original_connection_error.args)
129+
if connection_closed_cleanly:
130+
reraised_exc = ServerDisconnectedError(uncompleted)
107131
# assigns self._should_close to True as side effect,
108132
# we do it anyway below
109-
self.set_exception(exc)
133+
set_exc_kwargs = (
134+
{}
135+
if connection_closed_cleanly
136+
else {
137+
"exc_cause": original_connection_error,
138+
}
139+
)
140+
assert reraised_exc is not None
141+
self.set_exception(reraised_exc, **set_exc_kwargs)
110142

111143
self._should_close = True
112144
self._parser = None
113145
self._payload = None
114146
self._payload_parser = None
115147
self._reading_paused = False
116148

117-
super().connection_lost(exc)
149+
super().connection_lost(reraised_exc)
118150

119151
def eof_received(self) -> None:
120152
# should call parser.feed_eof() most likely
@@ -128,10 +160,14 @@ def resume_reading(self) -> None:
128160
super().resume_reading()
129161
self._reschedule_timeout()
130162

131-
def set_exception(self, exc: BaseException) -> None:
163+
def set_exception(
164+
self,
165+
exc: BaseException,
166+
exc_cause: BaseException = ErrorableMixin._EXC_SENTINEL,
167+
) -> None:
132168
self._should_close = True
133169
self._drop_timeout()
134-
super().set_exception(exc)
170+
super().set_exception(exc, exc_cause)
135171

136172
def set_parser(self, parser: Any, payload: Any) -> None:
137173
# TODO: actual types are:
@@ -208,7 +244,7 @@ def _on_read_timeout(self) -> None:
208244
exc = SocketTimeoutError("Timeout on reading data from socket")
209245
self.set_exception(exc)
210246
if self._payload is not None:
211-
self._payload.set_exception(exc)
247+
set_exception(self._payload, exc)
212248

213249
def data_received(self, data: bytes) -> None:
214250
self._reschedule_timeout()
@@ -234,14 +270,19 @@ def data_received(self, data: bytes) -> None:
234270
# parse http messages
235271
try:
236272
messages, upgraded, tail = self._parser.feed_data(data)
237-
except BaseException as exc:
273+
except BaseException as underlying_exc:
238274
if self.transport is not None:
239275
# connection.release() could be called BEFORE
240276
# data_received(), the transport is already
241277
# closed in this case
242278
self.transport.close()
243279
# should_close is True after the call
244-
self.set_exception(exc)
280+
self.set_exception(
281+
ClientPayloadError(
282+
f"Unable to parse response payload: {underlying_exc !r}"
283+
),
284+
underlying_exc,
285+
)
245286
return
246287

247288
self._upgraded = upgraded

aiohttp/client_reqrep.py

+26-12
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
noop,
5454
parse_mimetype,
5555
reify,
56+
set_exception,
5657
set_result,
5758
)
5859
from .http import (
@@ -566,20 +567,33 @@ async def write_bytes(
566567

567568
for chunk in self.body:
568569
await writer.write(chunk) # type: ignore[arg-type]
569-
except OSError as exc:
570-
if exc.errno is None and isinstance(exc, asyncio.TimeoutError):
571-
protocol.set_exception(exc)
572-
else:
573-
new_exc = ClientOSError(
574-
exc.errno, "Can not write request body for %s" % self.url
570+
except OSError as underlying_exc:
571+
reraised_exc = underlying_exc
572+
573+
exc_is_not_timeout = underlying_exc.errno is not None or not isinstance(
574+
underlying_exc, asyncio.TimeoutError
575+
)
576+
if exc_is_not_timeout:
577+
reraised_exc = ClientOSError(
578+
underlying_exc.errno,
579+
f"Can not write request body for {self.url !s}",
575580
)
576-
new_exc.__context__ = exc
577-
new_exc.__cause__ = exc
578-
protocol.set_exception(new_exc)
581+
582+
# FIXME: Is setting `__context__` harmful?
583+
# FIXME: Especially, given that `__cause__` is also set...
584+
reraised_exc.__context__ = underlying_exc
585+
586+
set_exception(protocol, reraised_exc, underlying_exc)
579587
except asyncio.CancelledError:
580588
await writer.write_eof()
581-
except Exception as exc:
582-
protocol.set_exception(exc)
589+
except Exception as underlying_exc:
590+
set_exception(
591+
protocol,
592+
ClientConnectionError(
593+
f"Failed to send bytes into the underlying connection {conn !s}",
594+
),
595+
underlying_exc,
596+
)
583597
else:
584598
await writer.write_eof()
585599
protocol.start_timeout()
@@ -1019,7 +1033,7 @@ def _notify_content(self) -> None:
10191033
content = self.content
10201034
# content can be None here, but the types are cheated elsewhere.
10211035
if content and content.exception() is None: # type: ignore[truthy-bool]
1022-
content.set_exception(ClientConnectionError("Connection closed"))
1036+
set_exception(content, ClientConnectionError("Connection closed"))
10231037
self._released = True
10241038

10251039
async def wait_for_close(self) -> None:

aiohttp/helpers.py

+32-3
Original file line numberDiff line numberDiff line change
@@ -797,9 +797,38 @@ def set_result(fut: "asyncio.Future[_T]", result: _T) -> None:
797797
fut.set_result(result)
798798

799799

800-
def set_exception(fut: "asyncio.Future[_T]", exc: BaseException) -> None:
801-
if not fut.done():
802-
fut.set_exception(exc)
800+
class ErrorableMixin:
801+
_EXC_SENTINEL = BaseException()
802+
803+
def set_exception(
804+
self,
805+
exc: BaseException,
806+
exc_cause: BaseException = _EXC_SENTINEL,
807+
):
808+
raise NotImplementedError
809+
810+
811+
def set_exception(
812+
fut: "asyncio.Future[_T] | ErrorableMixin",
813+
exc: BaseException,
814+
exc_cause: BaseException = ErrorableMixin._EXC_SENTINEL,
815+
) -> None:
816+
"""Set future exception.
817+
818+
If the future is marked as complete, this function is a no-op.
819+
820+
:param exc_cause: An exception that is a direct cause of ``exc``.
821+
Only set if provided.
822+
"""
823+
if asyncio.isfuture(fut) and fut.done():
824+
return
825+
826+
exc_is_sentinel = exc_cause is ErrorableMixin._EXC_SENTINEL
827+
exc_causes_itself = exc is exc_cause
828+
if not exc_is_sentinel and not exc_causes_itself:
829+
exc.__cause__ = exc_cause
830+
831+
fut.set_exception(exc)
803832

804833

805834
@functools.total_ordering

aiohttp/http_parser.py

+18-9
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@
3131
DEBUG,
3232
NO_EXTENSIONS,
3333
BaseTimerContext,
34+
ErrorableMixin,
3435
method_must_be_empty_body,
36+
set_exception,
3537
status_code_must_be_empty_body,
3638
)
3739
from .http_exceptions import (
@@ -439,13 +441,16 @@ def get_content_length() -> Optional[int]:
439441
assert self._payload_parser is not None
440442
try:
441443
eof, data = self._payload_parser.feed_data(data[start_pos:], SEP)
442-
except BaseException as exc:
444+
except BaseException as underlying_exc:
445+
reraised_exc = underlying_exc
443446
if self.payload_exception is not None:
444-
self._payload_parser.payload.set_exception(
445-
self.payload_exception(str(exc))
446-
)
447-
else:
448-
self._payload_parser.payload.set_exception(exc)
447+
reraised_exc = self.payload_exception(str(underlying_exc))
448+
449+
set_exception(
450+
self._payload_parser.payload,
451+
reraised_exc,
452+
underlying_exc,
453+
)
449454

450455
eof = True
451456
data = b""
@@ -826,7 +831,7 @@ def feed_data(
826831
exc = TransferEncodingError(
827832
chunk[:pos].decode("ascii", "surrogateescape")
828833
)
829-
self.payload.set_exception(exc)
834+
set_exception(self.payload, exc)
830835
raise exc
831836
size = int(bytes(size_b), 16)
832837

@@ -929,8 +934,12 @@ def __init__(self, out: StreamReader, encoding: Optional[str]) -> None:
929934
else:
930935
self.decompressor = ZLibDecompressor(encoding=encoding)
931936

932-
def set_exception(self, exc: BaseException) -> None:
933-
self.out.set_exception(exc)
937+
def set_exception(
938+
self,
939+
exc: BaseException,
940+
exc_cause: BaseException = ErrorableMixin._EXC_SENTINEL,
941+
) -> None:
942+
set_exception(self.out, exc, exc_cause)
934943

935944
def feed_data(self, chunk: bytes, size: int) -> None:
936945
if not size:

aiohttp/http_websocket.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
from .base_protocol import BaseProtocol
2727
from .compression_utils import ZLibCompressor, ZLibDecompressor
28-
from .helpers import NO_EXTENSIONS
28+
from .helpers import NO_EXTENSIONS, set_exception
2929
from .streams import DataQueue
3030

3131
__all__ = (
@@ -305,7 +305,7 @@ def feed_data(self, data: bytes) -> Tuple[bool, bytes]:
305305
return self._feed_data(data)
306306
except Exception as exc:
307307
self._exc = exc
308-
self.queue.set_exception(exc)
308+
set_exception(self.queue, exc)
309309
return True, b""
310310

311311
def _feed_data(self, data: bytes) -> Tuple[bool, bytes]:

0 commit comments

Comments
 (0)