Skip to content

Commit e9fdf0a

Browse files
authored
Refactor web error handling (#5270)
1 parent 9f659ca commit e9fdf0a

6 files changed

+87
-917
lines changed

aiohttp/client_proto.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ def data_received(self, data: bytes) -> None:
238238
self._payload = payload
239239

240240
if self._skip_payload or message.code in (204, 304):
241-
self.feed_data((message, EMPTY_PAYLOAD), 0) # type: ignore
241+
self.feed_data((message, EMPTY_PAYLOAD), 0)
242242
else:
243243
self.feed_data((message, payload), 0)
244244
if payload is not None:

aiohttp/http_parser.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ def feed_data(
390390
if not payload_parser.done:
391391
self._payload_parser = payload_parser
392392
else:
393-
payload = EMPTY_PAYLOAD # type: ignore
393+
payload = EMPTY_PAYLOAD
394394

395395
messages.append((msg, payload))
396396
else:

aiohttp/streams.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import warnings
44
from typing import Awaitable, Callable, Generic, List, Optional, Tuple, TypeVar
55

6+
from typing_extensions import Final
7+
68
from .base_protocol import BaseProtocol
79
from .helpers import BaseTimerContext, set_exception, set_result
810
from .log import internal_logger
@@ -490,7 +492,10 @@ def _read_nowait(self, n: int) -> bytes:
490492
return b"".join(chunks) if chunks else b""
491493

492494

493-
class EmptyStreamReader(AsyncStreamReaderMixin):
495+
class EmptyStreamReader(StreamReader): # lgtm [py/missing-call-to-init]
496+
def __init__(self) -> None:
497+
pass
498+
494499
def exception(self) -> Optional[BaseException]:
495500
return None
496501

@@ -535,11 +540,11 @@ async def readchunk(self) -> Tuple[bytes, bool]:
535540
async def readexactly(self, n: int) -> bytes:
536541
raise asyncio.IncompleteReadError(b"", n)
537542

538-
def read_nowait(self) -> bytes:
543+
def read_nowait(self, n: int = -1) -> bytes:
539544
return b""
540545

541546

542-
EMPTY_PAYLOAD = EmptyStreamReader()
547+
EMPTY_PAYLOAD: Final[StreamReader] = EmptyStreamReader()
543548

544549

545550
class DataQueue(Generic[_T]):

aiohttp/web_protocol.py

+73-87
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@
1313
Callable,
1414
Deque,
1515
Optional,
16+
Sequence,
1617
Tuple,
1718
Type,
1819
Union,
1920
cast,
2021
)
2122

23+
import attr
2224
import yarl
2325

2426
from .abc import AbstractAccessLogger, AbstractAsyncAccessLogger, AbstractStreamWriter
@@ -62,7 +64,6 @@
6264
Type[AbstractAccessLogger],
6365
]
6466

65-
6667
ERROR = RawRequestMessage(
6768
"UNKNOWN", "/", HttpVersion10, {}, {}, True, False, False, False, yarl.URL("/")
6869
)
@@ -95,6 +96,16 @@ async def log(
9596
self.access_logger.log(request, response, self._loop.time() - request_start)
9697

9798

99+
@attr.s(auto_attribs=True, frozen=True, slots=True)
100+
class _ErrInfo:
101+
status: int
102+
exc: BaseException
103+
message: str
104+
105+
106+
_MsgType = Tuple[Union[RawRequestMessage, _ErrInfo], StreamReader]
107+
108+
98109
class RequestHandler(BaseProtocol):
99110
"""HTTP protocol implementation.
100111
@@ -106,30 +117,26 @@ class RequestHandler(BaseProtocol):
106117
status line, bad headers or incomplete payload. If any error occurs,
107118
connection gets closed.
108119
109-
:param keepalive_timeout: number of seconds before closing
110-
keep-alive connection
111-
:type keepalive_timeout: int or None
120+
keepalive_timeout -- number of seconds before closing
121+
keep-alive connection
112122
113-
:param bool tcp_keepalive: TCP keep-alive is on, default is on
123+
tcp_keepalive -- TCP keep-alive is on, default is on
114124
115-
:param logger: custom logger object
116-
:type logger: aiohttp.log.server_logger
125+
logger -- custom logger object
117126
118-
:param access_log_class: custom class for access_logger
119-
:type access_log_class: aiohttp.abc.AbstractAccessLogger
127+
access_log_class -- custom class for access_logger
120128
121-
:param access_log: custom logging object
122-
:type access_log: aiohttp.log.server_logger
129+
access_log -- custom logging object
123130
124-
:param str access_log_format: access log format string
131+
access_log_format -- access log format string
125132
126-
:param loop: Optional event loop
133+
loop -- Optional event loop
127134
128-
:param int max_line_size: Optional maximum header line size
135+
max_line_size -- Optional maximum header line size
129136
130-
:param int max_field_size: Optional maximum header field size
137+
max_field_size -- Optional maximum header field size
131138
132-
:param int max_headers: Optional maximum header size
139+
max_headers -- Optional maximum header size
133140
134141
"""
135142

@@ -149,7 +156,6 @@ class RequestHandler(BaseProtocol):
149156
"_messages",
150157
"_message_tail",
151158
"_waiter",
152-
"_error_handler",
153159
"_task_handler",
154160
"_upgrade",
155161
"_payload_parser",
@@ -180,19 +186,14 @@ def __init__(
180186
lingering_time: float = 10.0,
181187
read_bufsize: int = 2 ** 16,
182188
):
183-
184189
super().__init__(loop)
185190

186191
self._request_count = 0
187192
self._keepalive = False
188193
self._current_request = None # type: Optional[BaseRequest]
189194
self._manager = manager # type: Optional[Server]
190-
self._request_handler = (
191-
manager.request_handler
192-
) # type: Optional[_RequestHandler]
193-
self._request_factory = (
194-
manager.request_factory
195-
) # type: Optional[_RequestFactory]
195+
self._request_handler: Optional[_RequestHandler] = manager.request_handler
196+
self._request_factory: Optional[_RequestFactory] = manager.request_factory
196197

197198
self._tcp_keepalive = tcp_keepalive
198199
# placeholder to be replaced on keepalive timeout setup
@@ -201,11 +202,10 @@ def __init__(
201202
self._keepalive_timeout = keepalive_timeout
202203
self._lingering_time = float(lingering_time)
203204

204-
self._messages: Deque[Tuple[RawRequestMessage, StreamReader]] = deque()
205+
self._messages: Deque[_MsgType] = deque()
205206
self._message_tail = b""
206207

207208
self._waiter = None # type: Optional[asyncio.Future[None]]
208-
self._error_handler = None # type: Optional[asyncio.Task[None]]
209209
self._task_handler = None # type: Optional[asyncio.Task[None]]
210210

211211
self._upgrade = False
@@ -264,9 +264,6 @@ async def shutdown(self, timeout: Optional[float] = 15.0) -> None:
264264
# wait for handlers
265265
with suppress(asyncio.CancelledError, asyncio.TimeoutError):
266266
async with ceil_timeout(timeout):
267-
if self._error_handler is not None and not self._error_handler.done():
268-
await self._error_handler
269-
270267
if self._current_request is not None:
271268
self._current_request._cancel(asyncio.CancelledError())
272269

@@ -313,8 +310,6 @@ def connection_lost(self, exc: Optional[BaseException]) -> None:
313310
exc = ConnectionResetError("Connection lost")
314311
self._current_request._cancel(exc)
315312

316-
if self._error_handler is not None:
317-
self._error_handler.cancel()
318313
if self._task_handler is not None:
319314
self._task_handler.cancel()
320315
if self._waiter is not None:
@@ -343,40 +338,30 @@ def data_received(self, data: bytes) -> None:
343338
if self._force_close or self._close:
344339
return
345340
# parse http messages
341+
messages: Sequence[_MsgType]
346342
if self._payload_parser is None and not self._upgrade:
347343
assert self._request_parser is not None
348344
try:
349345
messages, upgraded, tail = self._request_parser.feed_data(data)
350346
except HttpProcessingError as exc:
351-
# something happened during parsing
352-
self._error_handler = self._loop.create_task(
353-
self.handle_parse_error(
354-
StreamWriter(self, self._loop), 400, exc, exc.message
355-
)
356-
)
357-
self.close()
358-
except Exception as exc:
359-
# 500: internal error
360-
self._error_handler = self._loop.create_task(
361-
self.handle_parse_error(StreamWriter(self, self._loop), 500, exc)
362-
)
363-
self.close()
364-
else:
365-
if messages:
366-
# sometimes the parser returns no messages
367-
for (msg, payload) in messages:
368-
self._request_count += 1
369-
self._messages.append((msg, payload))
370-
371-
waiter = self._waiter
372-
if waiter is not None:
373-
if not waiter.done():
374-
# don't set result twice
375-
waiter.set_result(None)
376-
377-
self._upgrade = upgraded
378-
if upgraded and tail:
379-
self._message_tail = tail
347+
messages = [
348+
(_ErrInfo(status=400, exc=exc, message=exc.message), EMPTY_PAYLOAD)
349+
]
350+
upgraded = False
351+
tail = b""
352+
353+
for msg, payload in messages or ():
354+
self._request_count += 1
355+
self._messages.append((msg, payload))
356+
357+
waiter = self._waiter
358+
if messages and waiter is not None and not waiter.done():
359+
# don't set result twice
360+
waiter.set_result(None)
361+
362+
self._upgrade = upgraded
363+
if upgraded and tail:
364+
self._message_tail = tail
380365

381366
# no parser, just store
382367
elif self._payload_parser is None and self._upgrade and data:
@@ -449,12 +434,13 @@ async def _handle_request(
449434
self,
450435
request: BaseRequest,
451436
start_time: float,
437+
request_handler: Callable[[BaseRequest], Awaitable[StreamResponse]],
452438
) -> Tuple[StreamResponse, bool]:
453439
assert self._request_handler is not None
454440
try:
455441
try:
456442
self._current_request = request
457-
resp = await self._request_handler(request)
443+
resp = await request_handler(request)
458444
finally:
459445
self._current_request = None
460446
except HTTPException as exc:
@@ -513,10 +499,19 @@ async def start(self) -> None:
513499

514500
manager.requests_count += 1
515501
writer = StreamWriter(self, loop)
502+
if isinstance(message, _ErrInfo):
503+
# swap in the ERROR placeholder message so that request_factory can build a request
504+
request_handler = self._make_error_handler(message)
505+
message = ERROR
506+
else:
507+
request_handler = self._request_handler
508+
516509
request = self._request_factory(message, payload, self, writer, handler)
517510
try:
518511
# a new task is used to copy context vars (#3406)
519-
task = self._loop.create_task(self._handle_request(request, start))
512+
task = self._loop.create_task(
513+
self._handle_request(request, start, request_handler)
514+
)
520515
try:
521516
resp, reset = await task
522517
except (asyncio.CancelledError, ConnectionError):
@@ -586,7 +581,7 @@ async def start(self) -> None:
586581
# remove handler, close transport if no handlers left
587582
if not self._force_close:
588583
self._task_handler = None
589-
if self.transport is not None and self._error_handler is None:
584+
if self.transport is not None:
590585
self.transport.close()
591586

592587
async def finish_response(
@@ -639,6 +634,13 @@ def handle_error(
639634
information. It always closes the current connection."""
640635
self.log_exception("Error handling request", exc_info=exc)
641636

637+
# some data already got sent, connection is broken
638+
if request.writer.output_size > 0:
639+
raise ConnectionError(
640+
"Response is sent already, cannot send another response "
641+
"with the error message"
642+
)
643+
642644
ct = "text/plain"
643645
if status == HTTPStatus.INTERNAL_SERVER_ERROR:
644646
title = "{0.value} {0.phrase}".format(HTTPStatus.INTERNAL_SERVER_ERROR)
@@ -667,30 +669,14 @@ def handle_error(
667669
resp = Response(status=status, text=message, content_type=ct)
668670
resp.force_close()
669671

670-
# some data already got sent, connection is broken
671-
if request.writer.output_size > 0 or self.transport is None:
672-
self.force_close()
673-
674672
return resp
675673

676-
async def handle_parse_error(
677-
self,
678-
writer: AbstractStreamWriter,
679-
status: int,
680-
exc: Optional[BaseException] = None,
681-
message: Optional[str] = None,
682-
) -> None:
683-
task = asyncio.current_task()
684-
assert task is not None
685-
request = BaseRequest(
686-
ERROR, EMPTY_PAYLOAD, self, writer, task, self._loop # type: ignore
687-
)
688-
689-
resp = self.handle_error(request, status, exc, message)
690-
await resp.prepare(request)
691-
await resp.write_eof()
692-
693-
if self.transport is not None:
694-
self.transport.close()
674+
def _make_error_handler(
675+
self, err_info: _ErrInfo
676+
) -> Callable[[BaseRequest], Awaitable[StreamResponse]]:
677+
async def handler(request: BaseRequest) -> StreamResponse:
678+
return self.handle_error(
679+
request, err_info.status, err_info.exc, err_info.message
680+
)
695681

696-
self._error_handler = None
682+
return handler

tests/test_streams.py

+4
Original file line numberDiff line numberDiff line change
@@ -1472,3 +1472,7 @@ async def test_stream_reader_iter_chunks_chunked_encoding(protocol: Any) -> None
14721472
async for data, end_of_chunk in stream.iter_chunks():
14731473
assert (data, end_of_chunk) == (next(it), True)
14741474
pytest.raises(StopIteration, next, it)
1475+
1476+
1477+
def test_isinstance_check() -> None:
1478+
assert isinstance(streams.EMPTY_PAYLOAD, streams.StreamReader)

0 commit comments

Comments
 (0)