13
13
Callable ,
14
14
Deque ,
15
15
Optional ,
16
+ Sequence ,
16
17
Tuple ,
17
18
Type ,
18
19
Union ,
19
20
cast ,
20
21
)
21
22
23
+ import attr
22
24
import yarl
23
25
24
26
from .abc import AbstractAccessLogger , AbstractAsyncAccessLogger , AbstractStreamWriter
62
64
Type [AbstractAccessLogger ],
63
65
]
64
66
65
-
66
67
ERROR = RawRequestMessage (
67
68
"UNKNOWN" , "/" , HttpVersion10 , {}, {}, True , False , False , False , yarl .URL ("/" )
68
69
)
@@ -95,6 +96,16 @@ async def log(
95
96
self .access_logger .log (request , response , self ._loop .time () - request_start )
96
97
97
98
99
+ @attr .s (auto_attribs = True , frozen = True , slots = True )
100
+ class _ErrInfo :
101
+ status : int
102
+ exc : BaseException
103
+ message : str
104
+
105
+
106
+ _MsgType = Tuple [Union [RawRequestMessage , _ErrInfo ], StreamReader ]
107
+
108
+
98
109
class RequestHandler (BaseProtocol ):
99
110
"""HTTP protocol implementation.
100
111
@@ -106,30 +117,26 @@ class RequestHandler(BaseProtocol):
106
117
status line, bad headers or incomplete payload. If any error occurs,
107
118
connection gets closed.
108
119
109
- :param keepalive_timeout: number of seconds before closing
110
- keep-alive connection
111
- :type keepalive_timeout: int or None
120
+ keepalive_timeout -- number of seconds before closing
121
+ keep-alive connection
112
122
113
- :param bool tcp_keepalive: TCP keep-alive is on, default is on
123
+ tcp_keepalive -- TCP keep-alive is on, default is on
114
124
115
- :param logger: custom logger object
116
- :type logger: aiohttp.log.server_logger
125
+ logger -- custom logger object
117
126
118
- :param access_log_class: custom class for access_logger
119
- :type access_log_class: aiohttp.abc.AbstractAccessLogger
127
+ access_log_class -- custom class for access_logger
120
128
121
- :param access_log: custom logging object
122
- :type access_log: aiohttp.log.server_logger
129
+ access_log -- custom logging object
123
130
124
- :param str access_log_format: access log format string
131
+ access_log_format -- access log format string
125
132
126
- :param loop: Optional event loop
133
+ loop -- Optional event loop
127
134
128
- :param int max_line_size: Optional maximum header line size
135
+ max_line_size -- Optional maximum header line size
129
136
130
- :param int max_field_size: Optional maximum header field size
137
+ max_field_size -- Optional maximum header field size
131
138
132
- :param int max_headers: Optional maximum header size
139
+ max_headers -- Optional maximum header size
133
140
134
141
"""
135
142
@@ -149,7 +156,6 @@ class RequestHandler(BaseProtocol):
149
156
"_messages" ,
150
157
"_message_tail" ,
151
158
"_waiter" ,
152
- "_error_handler" ,
153
159
"_task_handler" ,
154
160
"_upgrade" ,
155
161
"_payload_parser" ,
@@ -180,19 +186,14 @@ def __init__(
180
186
lingering_time : float = 10.0 ,
181
187
read_bufsize : int = 2 ** 16 ,
182
188
):
183
-
184
189
super ().__init__ (loop )
185
190
186
191
self ._request_count = 0
187
192
self ._keepalive = False
188
193
self ._current_request = None # type: Optional[BaseRequest]
189
194
self ._manager = manager # type: Optional[Server]
190
- self ._request_handler = (
191
- manager .request_handler
192
- ) # type: Optional[_RequestHandler]
193
- self ._request_factory = (
194
- manager .request_factory
195
- ) # type: Optional[_RequestFactory]
195
+ self ._request_handler : Optional [_RequestHandler ] = manager .request_handler
196
+ self ._request_factory : Optional [_RequestFactory ] = manager .request_factory
196
197
197
198
self ._tcp_keepalive = tcp_keepalive
198
199
# placeholder to be replaced on keepalive timeout setup
@@ -201,11 +202,10 @@ def __init__(
201
202
self ._keepalive_timeout = keepalive_timeout
202
203
self ._lingering_time = float (lingering_time )
203
204
204
- self ._messages : Deque [Tuple [ RawRequestMessage , StreamReader ] ] = deque ()
205
+ self ._messages : Deque [_MsgType ] = deque ()
205
206
self ._message_tail = b""
206
207
207
208
self ._waiter = None # type: Optional[asyncio.Future[None]]
208
- self ._error_handler = None # type: Optional[asyncio.Task[None]]
209
209
self ._task_handler = None # type: Optional[asyncio.Task[None]]
210
210
211
211
self ._upgrade = False
@@ -264,9 +264,6 @@ async def shutdown(self, timeout: Optional[float] = 15.0) -> None:
264
264
# wait for handlers
265
265
with suppress (asyncio .CancelledError , asyncio .TimeoutError ):
266
266
async with ceil_timeout (timeout ):
267
- if self ._error_handler is not None and not self ._error_handler .done ():
268
- await self ._error_handler
269
-
270
267
if self ._current_request is not None :
271
268
self ._current_request ._cancel (asyncio .CancelledError ())
272
269
@@ -313,8 +310,6 @@ def connection_lost(self, exc: Optional[BaseException]) -> None:
313
310
exc = ConnectionResetError ("Connection lost" )
314
311
self ._current_request ._cancel (exc )
315
312
316
- if self ._error_handler is not None :
317
- self ._error_handler .cancel ()
318
313
if self ._task_handler is not None :
319
314
self ._task_handler .cancel ()
320
315
if self ._waiter is not None :
@@ -343,40 +338,30 @@ def data_received(self, data: bytes) -> None:
343
338
if self ._force_close or self ._close :
344
339
return
345
340
# parse http messages
341
+ messages : Sequence [_MsgType ]
346
342
if self ._payload_parser is None and not self ._upgrade :
347
343
assert self ._request_parser is not None
348
344
try :
349
345
messages , upgraded , tail = self ._request_parser .feed_data (data )
350
346
except HttpProcessingError as exc :
351
- # something happened during parsing
352
- self ._error_handler = self ._loop .create_task (
353
- self .handle_parse_error (
354
- StreamWriter (self , self ._loop ), 400 , exc , exc .message
355
- )
356
- )
357
- self .close ()
358
- except Exception as exc :
359
- # 500: internal error
360
- self ._error_handler = self ._loop .create_task (
361
- self .handle_parse_error (StreamWriter (self , self ._loop ), 500 , exc )
362
- )
363
- self .close ()
364
- else :
365
- if messages :
366
- # sometimes the parser returns no messages
367
- for (msg , payload ) in messages :
368
- self ._request_count += 1
369
- self ._messages .append ((msg , payload ))
370
-
371
- waiter = self ._waiter
372
- if waiter is not None :
373
- if not waiter .done ():
374
- # don't set result twice
375
- waiter .set_result (None )
376
-
377
- self ._upgrade = upgraded
378
- if upgraded and tail :
379
- self ._message_tail = tail
347
+ messages = [
348
+ (_ErrInfo (status = 400 , exc = exc , message = exc .message ), EMPTY_PAYLOAD )
349
+ ]
350
+ upgraded = False
351
+ tail = b""
352
+
353
+ for msg , payload in messages or ():
354
+ self ._request_count += 1
355
+ self ._messages .append ((msg , payload ))
356
+
357
+ waiter = self ._waiter
358
+ if messages and waiter is not None and not waiter .done ():
359
+ # don't set result twice
360
+ waiter .set_result (None )
361
+
362
+ self ._upgrade = upgraded
363
+ if upgraded and tail :
364
+ self ._message_tail = tail
380
365
381
366
# no parser, just store
382
367
elif self ._payload_parser is None and self ._upgrade and data :
@@ -449,12 +434,13 @@ async def _handle_request(
449
434
self ,
450
435
request : BaseRequest ,
451
436
start_time : float ,
437
+ request_handler : Callable [[BaseRequest ], Awaitable [StreamResponse ]],
452
438
) -> Tuple [StreamResponse , bool ]:
453
439
assert self ._request_handler is not None
454
440
try :
455
441
try :
456
442
self ._current_request = request
457
- resp = await self . _request_handler (request )
443
+ resp = await request_handler (request )
458
444
finally :
459
445
self ._current_request = None
460
446
except HTTPException as exc :
@@ -513,10 +499,19 @@ async def start(self) -> None:
513
499
514
500
manager .requests_count += 1
515
501
writer = StreamWriter (self , loop )
502
+ if isinstance (message , _ErrInfo ):
503
+ # make request_factory work
504
+ request_handler = self ._make_error_handler (message )
505
+ message = ERROR
506
+ else :
507
+ request_handler = self ._request_handler
508
+
516
509
request = self ._request_factory (message , payload , self , writer , handler )
517
510
try :
518
511
# a new task is used for copy context vars (#3406)
519
- task = self ._loop .create_task (self ._handle_request (request , start ))
512
+ task = self ._loop .create_task (
513
+ self ._handle_request (request , start , request_handler )
514
+ )
520
515
try :
521
516
resp , reset = await task
522
517
except (asyncio .CancelledError , ConnectionError ):
@@ -586,7 +581,7 @@ async def start(self) -> None:
586
581
# remove handler, close transport if no handlers left
587
582
if not self ._force_close :
588
583
self ._task_handler = None
589
- if self .transport is not None and self . _error_handler is None :
584
+ if self .transport is not None :
590
585
self .transport .close ()
591
586
592
587
async def finish_response (
@@ -639,6 +634,13 @@ def handle_error(
639
634
information. It always closes current connection."""
640
635
self .log_exception ("Error handling request" , exc_info = exc )
641
636
637
+ # some data already got sent, connection is broken
638
+ if request .writer .output_size > 0 :
639
+ raise ConnectionError (
640
+ "Response is sent already, cannot send another response "
641
+ "with the error message"
642
+ )
643
+
642
644
ct = "text/plain"
643
645
if status == HTTPStatus .INTERNAL_SERVER_ERROR :
644
646
title = "{0.value} {0.phrase}" .format (HTTPStatus .INTERNAL_SERVER_ERROR )
@@ -667,30 +669,14 @@ def handle_error(
667
669
resp = Response (status = status , text = message , content_type = ct )
668
670
resp .force_close ()
669
671
670
- # some data already got sent, connection is broken
671
- if request .writer .output_size > 0 or self .transport is None :
672
- self .force_close ()
673
-
674
672
return resp
675
673
676
- async def handle_parse_error (
677
- self ,
678
- writer : AbstractStreamWriter ,
679
- status : int ,
680
- exc : Optional [BaseException ] = None ,
681
- message : Optional [str ] = None ,
682
- ) -> None :
683
- task = asyncio .current_task ()
684
- assert task is not None
685
- request = BaseRequest (
686
- ERROR , EMPTY_PAYLOAD , self , writer , task , self ._loop # type: ignore
687
- )
688
-
689
- resp = self .handle_error (request , status , exc , message )
690
- await resp .prepare (request )
691
- await resp .write_eof ()
692
-
693
- if self .transport is not None :
694
- self .transport .close ()
674
+ def _make_error_handler (
675
+ self , err_info : _ErrInfo
676
+ ) -> Callable [[BaseRequest ], Awaitable [StreamResponse ]]:
677
+ async def handler (request : BaseRequest ) -> StreamResponse :
678
+ return self .handle_error (
679
+ request , err_info .status , err_info .exc , err_info .message
680
+ )
695
681
696
- self . _error_handler = None
682
+ return handler
0 commit comments