4
4
import io
5
5
import json
6
6
from http .cookies import SimpleCookie
7
- from typing import Any , List
7
+ from typing import Any , Awaitable , Callable , List
8
8
from unittest import mock
9
9
from uuid import uuid4
10
10
16
16
import aiohttp
17
17
from aiohttp import client , hdrs , web
18
18
from aiohttp .client import ClientSession
19
+ from aiohttp .client_proto import ResponseHandler
19
20
from aiohttp .client_reqrep import ClientRequest
20
- from aiohttp .connector import BaseConnector , TCPConnector
21
+ from aiohttp .connector import BaseConnector , Connection , TCPConnector , UnixConnector
21
22
from aiohttp .helpers import DEBUG
22
23
from aiohttp .test_utils import make_mocked_coro
24
+ from aiohttp .tracing import Trace
23
25
24
26
25
27
@pytest .fixture
@@ -487,15 +489,17 @@ async def test_ws_connect_allowed_protocols(
487
489
hdrs .CONNECTION : "upgrade" ,
488
490
hdrs .SEC_WEBSOCKET_ACCEPT : ws_key ,
489
491
}
490
- resp .url = URL (f"{ protocol } ://example.com " )
492
+ resp .url = URL (f"{ protocol } ://example" )
491
493
resp .cookies = SimpleCookie ()
492
494
resp .start = mock .AsyncMock ()
493
495
494
496
req = mock .create_autospec (aiohttp .ClientRequest , spec_set = True )
495
497
req_factory = mock .Mock (return_value = req )
496
498
req .send = mock .AsyncMock (return_value = resp )
499
+ # BaseConnector allows all high level protocols by default
500
+ connector = BaseConnector ()
497
501
498
- session = await create_session (request_class = req_factory )
502
+ session = await create_session (connector = connector , request_class = req_factory )
499
503
500
504
connections = []
501
505
original_connect = session ._connector .connect
@@ -515,7 +519,68 @@ async def create_connection(req, traces, timeout):
515
519
"aiohttp.client.os"
516
520
) as m_os :
517
521
m_os .urandom .return_value = key_data
518
- await session .ws_connect (f"{ protocol } ://example.com" )
522
+ await session .ws_connect (f"{ protocol } ://example" )
523
+
524
+ # normally called during garbage collection. triggers an exception
525
+ # if the connection wasn't already closed
526
+ for c in connections :
527
+ c .close ()
528
+ c .__del__ ()
529
+
530
+ await session .close ()
531
+
532
+
533
+ @pytest .mark .parametrize ("protocol" , ["http" , "https" , "ws" , "wss" , "unix" ])
534
+ async def test_ws_connect_unix_socket_allowed_protocols (
535
+ create_session : Callable [..., Awaitable [ClientSession ]],
536
+ create_mocked_conn : Callable [[], ResponseHandler ],
537
+ protocol : str ,
538
+ ws_key : bytes ,
539
+ key_data : bytes ,
540
+ ) -> None :
541
+ resp = mock .create_autospec (aiohttp .ClientResponse )
542
+ resp .status = 101
543
+ resp .headers = {
544
+ hdrs .UPGRADE : "websocket" ,
545
+ hdrs .CONNECTION : "upgrade" ,
546
+ hdrs .SEC_WEBSOCKET_ACCEPT : ws_key ,
547
+ }
548
+ resp .url = URL (f"{ protocol } ://example" )
549
+ resp .cookies = SimpleCookie ()
550
+ resp .start = mock .AsyncMock ()
551
+
552
+ req = mock .create_autospec (aiohttp .ClientRequest , spec_set = True )
553
+ req_factory = mock .Mock (return_value = req )
554
+ req .send = mock .AsyncMock (return_value = resp )
555
+ # UnixConnector allows all high level protocols by default and unix sockets
556
+ session = await create_session (
557
+ connector = UnixConnector (path = "" ), request_class = req_factory
558
+ )
559
+
560
+ connections = []
561
+ assert session ._connector is not None
562
+ original_connect = session ._connector .connect
563
+
564
+ async def connect (
565
+ req : ClientRequest , traces : List [Trace ], timeout : aiohttp .ClientTimeout
566
+ ) -> Connection :
567
+ conn = await original_connect (req , traces , timeout )
568
+ connections .append (conn )
569
+ return conn
570
+
571
+ async def create_connection (
572
+ req : object , traces : object , timeout : object
573
+ ) -> ResponseHandler :
574
+ return create_mocked_conn ()
575
+
576
+ connector = session ._connector
577
+ with mock .patch .object (connector , "connect" , connect ), mock .patch .object (
578
+ connector , "_create_connection" , create_connection
579
+ ), mock .patch .object (connector , "_release" ), mock .patch (
580
+ "aiohttp.client.os"
581
+ ) as m_os :
582
+ m_os .urandom .return_value = key_data
583
+ await session .ws_connect (f"{ protocol } ://example" )
519
584
520
585
# normally called during garbage collection. triggers an exception
521
586
# if the connection wasn't already closed
0 commit comments