From 154c54dcdd1c9372c82d90b286c47dec1749d369 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Tue, 1 Oct 2019 12:43:28 +0300 Subject: [PATCH 1/3] Don't cancel web handler on disconnection (#4080) --- CHANGES/4080.feature | 1 + aiohttp/http_websocket.py | 18 +++-- aiohttp/web_protocol.py | 19 ++++- aiohttp/web_request.py | 3 + aiohttp/web_ws.py | 4 + docs/web_advanced.rst | 103 +++---------------------- tests/test_client_ws.py | 10 +-- tests/test_web_protocol.py | 6 +- tests/test_web_websocket.py | 96 ++++------------------- tests/test_web_websocket_functional.py | 26 +------ tests/test_websocket_writer.py | 4 +- 11 files changed, 73 insertions(+), 217 deletions(-) create mode 100644 CHANGES/4080.feature diff --git a/CHANGES/4080.feature b/CHANGES/4080.feature new file mode 100644 index 00000000000..4032817a418 --- /dev/null +++ b/CHANGES/4080.feature @@ -0,0 +1 @@ +Don't cancel web handler on peer disconnection, raise `OSError` on reading/writing instead. diff --git a/aiohttp/http_websocket.py b/aiohttp/http_websocket.py index e331d691aac..12211d5a6a4 100644 --- a/aiohttp/http_websocket.py +++ b/aiohttp/http_websocket.py @@ -13,7 +13,6 @@ from .base_protocol import BaseProtocol from .helpers import NO_EXTENSIONS -from .log import ws_logger from .streams import DataQueue __all__ = ('WS_CLOSED_MESSAGE', 'WS_CLOSING_MESSAGE', 'WS_KEY', @@ -567,8 +566,8 @@ def __init__(self, protocol: BaseProtocol, transport: asyncio.Transport, *, async def _send_frame(self, message: bytes, opcode: int, compress: Optional[int]=None) -> None: """Send a frame over the websocket with message as its payload.""" - if self._closing: - ws_logger.warning('websocket connection is closing.') + if self._closing and not (opcode & WSMsgType.CLOSE): + raise ConnectionResetError('Cannot write to closing transport') rsv = 0 @@ -610,14 +609,14 @@ async def _send_frame(self, message: bytes, opcode: int, mask = mask.to_bytes(4, 'big') message = bytearray(message) _websocket_mask(mask, message) - self.transport.write(header + mask + message) + self._write(header + mask + message) self._output_size += len(header) + len(mask) + len(message) else: if len(message) > MSG_SIZE: - self.transport.write(header) - self.transport.write(message) + self._write(header) + self._write(message) else: - self.transport.write(header + message) + self._write(header + message) self._output_size += len(header) + len(message) @@ -625,6 +624,11 @@ async def _send_frame(self, message: bytes, opcode: int, self._output_size = 0 await self.protocol._drain_helper() + def _write(self, data: bytes) -> None: + if self.transport is None or self.transport.is_closing(): + raise ConnectionResetError('Cannot write to closing transport') + self.transport.write(data) + async def pong(self, message: bytes=b'') -> None: """Send pong message.""" if isinstance(message, str): diff --git a/aiohttp/web_protocol.py b/aiohttp/web_protocol.py index 8796739644f..d670fa24d37 100644 --- a/aiohttp/web_protocol.py +++ b/aiohttp/web_protocol.py @@ -115,7 +115,8 @@ class RequestHandler(BaseProtocol): '_waiter', '_error_handler', '_task_handler', '_upgrade', '_payload_parser', '_request_parser', '_reading_paused', 'logger', 'debug', 'access_log', - 'access_logger', '_close', '_force_close') + 'access_logger', '_close', '_force_close', + '_current_request') def __init__(self, manager: 'Server', *, loop: asyncio.AbstractEventLoop, @@ -135,6 +136,7 @@ def __init__(self, manager: 'Server', *, self._request_count = 0 self._keepalive = False + self._current_request = None # type: Optional[BaseRequest] self._manager = manager # type: Optional[Server] self._request_handler = manager.request_handler # type: Optional[_RequestHandler] # noqa self._request_factory = manager.request_factory # type: Optional[_RequestFactory] # noqa @@ -202,6 +204,9 @@ async def shutdown(self, timeout: Optional[float]=15.0) -> None: not self._error_handler.done()): await self._error_handler + if self._current_request is not None: + self._current_request._cancel(asyncio.CancelledError()) + if (self._task_handler is not None and not self._task_handler.done()): await self._task_handler @@ -241,8 +246,10 @@ def connection_lost(self, exc: Optional[BaseException]) -> None: if self._keepalive_handle is not None: self._keepalive_handle.cancel() - if self._task_handler is not None: - self._task_handler.cancel() + if self._current_request is not None: + if exc is None: + exc = ConnectionResetError("Connetion lost") + self._current_request._cancel(exc) if self._error_handler is not None: self._error_handler.cancel() @@ -378,7 +385,11 @@ async def _handle_request(self, ) -> Tuple[StreamResponse, bool]: assert self._request_handler is not None try: - resp = await self._request_handler(request) + try: + self._current_request = request + resp = await self._request_handler(request) + finally: + self._current_request = None except HTTPException as exc: resp = Response(status=exc.status, reason=exc.reason, diff --git a/aiohttp/web_request.py b/aiohttp/web_request.py index 3b6c7d0c235..7378612267f 100644 --- a/aiohttp/web_request.py +++ b/aiohttp/web_request.py @@ -696,6 +696,9 @@ def __bool__(self) -> bool: async def _prepare_hook(self, response: StreamResponse) -> None: return + def _cancel(self, exc: BaseException) -> None: + self._payload.set_exception(exc) + class Request(BaseRequest): diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index aad245c7b54..2631725f488 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -455,3 +455,7 @@ async def __anext__(self) -> WSMessage: WSMsgType.CLOSED): raise StopAsyncIteration # NOQA return msg + + def _cancel(self, exc: BaseException) -> None: + if self._reader is not None: + self._reader.set_exception(exc) diff --git a/docs/web_advanced.rst b/docs/web_advanced.rst index d36ef7ff68b..f1096258caa 100644 --- a/docs/web_advanced.rst +++ b/docs/web_advanced.rst @@ -20,103 +20,22 @@ But in case of custom regular expressions for *percent encoded*: if you pass Unicode patterns they don't match to *requoted* path. +Peer disconnection +------------------ -Web Handler Cancellation ------------------------- - -.. warning:: - - :term:`web-handler` execution could be canceled on every ``await`` - if client drops connection without reading entire response's BODY. - - The behavior is very different from classic WSGI frameworks like - Flask and Django. - -Sometimes it is a desirable behavior: on processing ``GET`` request the -code might fetch data from database or other web resource, the -fetching is potentially slow. - -Canceling this fetch is very good: the peer dropped connection -already, there is no reason to waste time and resources (memory etc) by -getting data from DB without any chance to send it back to peer. - -But sometimes the cancellation is bad: on ``POST`` request very often -is needed to save data to DB regardless to peer closing. - -Cancellation prevention could be implemented in several ways: - -* Applying :func:`asyncio.shield` to coroutine that saves data into DB. -* Spawning a new task for DB saving -* Using aiojobs_ or other third party library. - -:func:`asyncio.shield` works pretty good. The only disadvantage is you -need to split web handler into exactly two async functions: one -for handler itself and other for protected code. - -For example the following snippet is not safe:: - - async def handler(request): - await asyncio.shield(write_to_redis(request)) - await asyncio.shield(write_to_postgres(request)) - return web.Response(text='OK') - -Cancellation might be occurred just after saving data in REDIS, -``write_to_postgres`` will be not called. - -Spawning a new task is much worse: there is no place to ``await`` -spawned tasks:: - - async def handler(request): - request.loop.create_task(write_to_redis(request)) - return web.Response(text='OK') - -In this case errors from ``write_to_redis`` are not awaited, it leads -to many asyncio log messages *Future exception was never retrieved* -and *Task was destroyed but it is pending!*. - -Moreover on :ref:`aiohttp-web-graceful-shutdown` phase *aiohttp* don't -wait for these tasks, you have a great chance to loose very important -data. - -On other hand aiojobs_ provides an API for spawning new jobs and -awaiting their results etc. It stores all scheduled activity in -internal data structures and could terminate them gracefully:: - - from aiojobs.aiohttp import setup, spawn - - async def coro(timeout): - await asyncio.sleep(timeout) # do something in background - - async def handler(request): - await spawn(request, coro()) - return web.Response() - - app = web.Application() - setup(app) - app.router.add_get('/', handler) - -All not finished jobs will be terminated on -:attr:`Application.on_cleanup` signal. +When a client peer is gone a subsequent reading or writing raises :exc:`OSError` +or more specific exception like :exc:`ConnectionResetError`. -To prevent cancellation of the whole :term:`web-handler` use -``@atomic`` decorator:: +The reason for disconnection is vary; it can be a network issue or explicit +socket closing on the peer side without reading the whole server response. - from aiojobs.aiohttp import atomic +*aiohttp* handles disconnection properly but you can handle it explicitly, e.g.:: - @atomic async def handler(request): - await write_to_db() - return web.Response() - - app = web.Application() - setup(app) - app.router.add_post('/', handler) - -It prevents all ``handler`` async function from cancellation, -``write_to_db`` will be never interrupted. - -.. _aiojobs: http://aiojobs.readthedocs.io/en/latest/ - + try: + text = await request.text() + except OSError: + # disconnected Passing a coroutine into run_app and Gunicorn --------------------------------------------- diff --git a/tests/test_client_ws.py b/tests/test_client_ws.py index 8dc98e7876a..1678fb904c8 100644 --- a/tests/test_client_ws.py +++ b/tests/test_client_ws.py @@ -9,7 +9,6 @@ import aiohttp from aiohttp import client, hdrs from aiohttp.http import WS_KEY -from aiohttp.log import ws_logger from aiohttp.streams import EofStream from aiohttp.test_utils import make_mocked_coro @@ -363,7 +362,7 @@ async def test_close_exc2(loop, ws_key, key_data) -> None: await resp.close() -async def test_send_data_after_close(ws_key, key_data, loop, mocker) -> None: +async def test_send_data_after_close(ws_key, key_data, loop) -> None: resp = mock.Mock() resp.status = 101 resp.headers = { @@ -381,16 +380,13 @@ async def test_send_data_after_close(ws_key, key_data, loop, mocker) -> None: 'http://test.org') resp._writer._closing = True - mocker.spy(ws_logger, 'warning') - for meth, args in ((resp.ping, ()), (resp.pong, ()), (resp.send_str, ('s',)), (resp.send_bytes, (b'b',)), (resp.send_json, ({},))): - await meth(*args) - assert ws_logger.warning.called - ws_logger.warning.reset_mock() + with pytest.raises(ConnectionResetError): + await meth(*args) async def test_send_data_type_errors(ws_key, key_data, loop) -> None: diff --git a/tests/test_web_protocol.py b/tests/test_web_protocol.py index 9b23f17f27b..e7b1784a21d 100644 --- a/tests/test_web_protocol.py +++ b/tests/test_web_protocol.py @@ -819,7 +819,11 @@ async def test_two_data_received_without_waking_up_start_task(srv) -> None: async def test_client_disconnect(aiohttp_server) -> None: async def handler(request): - await request.content.read(10) + buf = b"" + with pytest.raises(ConnectionError): + while len(buf) < 10: + buf += await request.content.read(10) + # return with closed transport means premature client disconnection return web.Response() logger = mock.Mock() diff --git a/tests/test_web_websocket.py b/tests/test_web_websocket.py index cab47b01066..0ded90e268a 100644 --- a/tests/test_web_websocket.py +++ b/tests/test_web_websocket.py @@ -5,7 +5,6 @@ from multidict import CIMultiDict from aiohttp import WSMessage, WSMsgType, signals -from aiohttp.log import ws_logger from aiohttp.streams import EofStream from aiohttp.test_utils import make_mocked_coro, make_mocked_request from aiohttp.web import HTTPBadRequest, WebSocketResponse @@ -226,52 +225,48 @@ def test_closed_after_ctor() -> None: assert ws.close_code is None -async def test_send_str_closed(make_request, mocker) -> None: +async def test_send_str_closed(make_request) -> None: req = make_request('GET', '/') ws = WebSocketResponse() await ws.prepare(req) ws._reader.feed_data(WS_CLOSED_MESSAGE, 0) await ws.close() - mocker.spy(ws_logger, 'warning') - await ws.send_str('string') - assert ws_logger.warning.called + with pytest.raises(ConnectionError): + await ws.send_str('string') -async def test_send_bytes_closed(make_request, mocker) -> None: +async def test_send_bytes_closed(make_request) -> None: req = make_request('GET', '/') ws = WebSocketResponse() await ws.prepare(req) ws._reader.feed_data(WS_CLOSED_MESSAGE, 0) await ws.close() - mocker.spy(ws_logger, 'warning') - await ws.send_bytes(b'bytes') - assert ws_logger.warning.called + with pytest.raises(ConnectionError): + await ws.send_bytes(b'bytes') -async def test_send_json_closed(make_request, mocker) -> None: +async def test_send_json_closed(make_request) -> None: req = make_request('GET', '/') ws = WebSocketResponse() await ws.prepare(req) ws._reader.feed_data(WS_CLOSED_MESSAGE, 0) await ws.close() - mocker.spy(ws_logger, 'warning') - await ws.send_json({'type': 'json'}) - assert ws_logger.warning.called + with pytest.raises(ConnectionError): + await ws.send_json({'type': 'json'}) -async def test_ping_closed(make_request, mocker) -> None: +async def test_ping_closed(make_request) -> None: req = make_request('GET', '/') ws = WebSocketResponse() await ws.prepare(req) ws._reader.feed_data(WS_CLOSED_MESSAGE, 0) await ws.close() - mocker.spy(ws_logger, 'warning') - await ws.ping() - assert ws_logger.warning.called + with pytest.raises(ConnectionError): + await ws.ping() async def test_pong_closed(make_request, mocker) -> None: @@ -281,9 +276,8 @@ async def test_pong_closed(make_request, mocker) -> None: ws._reader.feed_data(WS_CLOSED_MESSAGE, 0) await ws.close() - mocker.spy(ws_logger, 'warning') - await ws.pong() - assert ws_logger.warning.called + with pytest.raises(ConnectionError): + await ws.pong() async def test_close_idempotent(make_request) -> None: @@ -354,40 +348,6 @@ async def test_receive_eofstream_in_reader(make_request, loop) -> None: assert ws.closed -async def test_receive_exc_in_reader(make_request, loop) -> None: - req = make_request('GET', '/') - ws = WebSocketResponse() - await ws.prepare(req) - - ws._reader = mock.Mock() - exc = ValueError() - res = loop.create_future() - res.set_exception(exc) - ws._reader.read = make_mocked_coro(res) - ws._payload_writer.drain = mock.Mock() - ws._payload_writer.drain.return_value = loop.create_future() - ws._payload_writer.drain.return_value.set_result(True) - - msg = await ws.receive() - assert msg.type == WSMsgType.ERROR - assert msg.data is exc - assert ws.exception() is exc - - -async def test_receive_cancelled(make_request, loop) -> None: - req = make_request('GET', '/') - ws = WebSocketResponse() - await ws.prepare(req) - - ws._reader = mock.Mock() - res = loop.create_future() - res.set_exception(asyncio.CancelledError()) - ws._reader.read = make_mocked_coro(res) - - with pytest.raises(asyncio.CancelledError): - await ws.receive() - - async def test_receive_timeouterror(make_request, loop) -> None: req = make_request('GET', '/') ws = WebSocketResponse() @@ -428,33 +388,7 @@ async def test_concurrent_receive(make_request) -> None: await ws.receive() -async def test_close_exc(make_request, loop, mocker) -> None: - req = make_request('GET', '/') - - ws = WebSocketResponse() - await ws.prepare(req) - - ws._reader = mock.Mock() - exc = ValueError() - ws._reader.read.return_value = loop.create_future() - ws._reader.read.return_value.set_exception(exc) - ws._payload_writer.drain = mock.Mock() - ws._payload_writer.drain.return_value = loop.create_future() - ws._payload_writer.drain.return_value.set_result(True) - - await ws.close() - assert ws.closed - assert ws.exception() is exc - - ws._closed = False - ws._reader.read.return_value = loop.create_future() - ws._reader.read.return_value.set_exception(asyncio.CancelledError()) - with pytest.raises(asyncio.CancelledError): - await ws.close() - assert ws.close_code == 1006 - - -async def test_close_exc2(make_request) -> None: +async def test_close_exc(make_request) -> None: req = make_request('GET', '/') ws = WebSocketResponse() diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index 59f5e11018f..df4c051e35e 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -276,19 +276,6 @@ async def handler(request): await asyncio.sleep(0.08) msg = await ws._reader.read() assert msg.type == WSMsgType.CLOSE - await ws.send_str('hang') - - # i am not sure what do we test here - # under uvloop this code raises RuntimeError - try: - await asyncio.sleep(0.08) - await ws.send_str('hang') - await asyncio.sleep(0.08) - await ws.send_str('hang') - await asyncio.sleep(0.08) - await ws.send_str('hang') - except RuntimeError: - pass await asyncio.sleep(0.08) assert (await aborted) @@ -668,19 +655,12 @@ async def handler(request): async def test_heartbeat_no_pong(loop, aiohttp_client, ceil) -> None: - cancelled = False async def handler(request): - nonlocal cancelled - ws = web.WebSocketResponse(heartbeat=0.05) await ws.prepare(request) - try: - await ws.receive() - except asyncio.CancelledError: - cancelled = True - + await ws.receive() return ws app = web.Application() @@ -690,9 +670,7 @@ async def handler(request): ws = await client.ws_connect('/', autoping=False) msg = await ws.receive() assert msg.type == aiohttp.WSMsgType.ping - await ws.receive() - - assert cancelled + await ws.close() async def test_server_ws_async_for(loop, aiohttp_server) -> None: diff --git a/tests/test_websocket_writer.py b/tests/test_websocket_writer.py index 2a25ab1cbff..0fde37aae4b 100644 --- a/tests/test_websocket_writer.py +++ b/tests/test_websocket_writer.py @@ -16,7 +16,9 @@ def protocol(): @pytest.fixture def transport(): - return mock.Mock() + ret = mock.Mock() + ret.is_closing.return_value = False + return ret @pytest.fixture From b6fe9b02c2c75184353013108630697c3dcb3b10 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Fri, 1 Nov 2019 19:54:21 +0200 Subject: [PATCH 2/3] Fix spelling --- aiohttp/web_protocol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiohttp/web_protocol.py b/aiohttp/web_protocol.py index d670fa24d37..3fb15fb8940 100644 --- a/aiohttp/web_protocol.py +++ b/aiohttp/web_protocol.py @@ -248,7 +248,7 @@ def connection_lost(self, exc: Optional[BaseException]) -> None: if self._current_request is not None: if exc is None: - exc = ConnectionResetError("Connetion lost") + exc = ConnectionResetError("Connection lost") self._current_request._cancel(exc) if self._error_handler is not None: From ea440e15623e6222e069bf9072b9749b6194c3d8 Mon Sep 17 00:00:00 2001 From: jack1142 <6032823+jack1142@users.noreply.github.com> Date: Mon, 25 May 2020 01:50:41 +0200 Subject: [PATCH 3/3] I wonder...