Skip to content

Commit 72f41aa

Browse files
[PR #8632/b2691f2 backport][3.10] Fix connecting to npipe://, tcp://, and unix:// urls (#8637)
Co-authored-by: Sam Bull <[email protected]>
1 parent bf83dbe commit 72f41aa

File tree

5 files changed

+137
-14
lines changed

5 files changed

+137
-14
lines changed

CHANGES/8632.bugfix.rst

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fixed connecting to ``npipe://``, ``tcp://``, and ``unix://`` urls -- by :user:`bdraco`.

aiohttp/client.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
)
7676
from .client_ws import ClientWebSocketResponse as ClientWebSocketResponse
7777
from .connector import (
78+
HTTP_AND_EMPTY_SCHEMA_SET,
7879
BaseConnector as BaseConnector,
7980
NamedPipeConnector as NamedPipeConnector,
8081
TCPConnector as TCPConnector,
@@ -209,9 +210,6 @@ class ClientTimeout:
209210

210211
# https://www.rfc-editor.org/rfc/rfc9110#section-9.2.2
211212
IDEMPOTENT_METHODS = frozenset({"GET", "HEAD", "OPTIONS", "TRACE", "PUT", "DELETE"})
212-
HTTP_SCHEMA_SET = frozenset({"http", "https", ""})
213-
WS_SCHEMA_SET = frozenset({"ws", "wss"})
214-
ALLOWED_PROTOCOL_SCHEMA_SET = HTTP_SCHEMA_SET | WS_SCHEMA_SET
215213

216214
_RetType = TypeVar("_RetType")
217215
_CharsetResolver = Callable[[ClientResponse, bytes], str]
@@ -517,7 +515,8 @@ async def _request(
517515
except ValueError as e:
518516
raise InvalidUrlClientError(str_or_url) from e
519517

520-
if url.scheme not in ALLOWED_PROTOCOL_SCHEMA_SET:
518+
assert self._connector is not None
519+
if url.scheme not in self._connector.allowed_protocol_schema_set:
521520
raise NonHttpUrlClientError(url)
522521

523522
skip_headers = set(self._skip_auto_headers)
@@ -655,7 +654,6 @@ async def _request(
655654
real_timeout.connect,
656655
ceil_threshold=real_timeout.ceil_threshold,
657656
):
658-
assert self._connector is not None
659657
conn = await self._connector.connect(
660658
req, traces=traces, timeout=real_timeout
661659
)
@@ -752,7 +750,7 @@ async def _request(
752750
) from e
753751

754752
scheme = parsed_redirect_url.scheme
755-
if scheme not in HTTP_SCHEMA_SET:
753+
if scheme not in HTTP_AND_EMPTY_SCHEMA_SET:
756754
resp.close()
757755
raise NonHttpUrlRedirectClientError(r_url)
758756
elif not scheme:

aiohttp/connector.py

+16
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,14 @@
6363
SSLContext = object # type: ignore[misc,assignment]
6464

6565

66+
EMPTY_SCHEMA_SET = frozenset({""})
67+
HTTP_SCHEMA_SET = frozenset({"http", "https"})
68+
WS_SCHEMA_SET = frozenset({"ws", "wss"})
69+
70+
HTTP_AND_EMPTY_SCHEMA_SET = HTTP_SCHEMA_SET | EMPTY_SCHEMA_SET
71+
HIGH_LEVEL_SCHEMA_SET = HTTP_AND_EMPTY_SCHEMA_SET | WS_SCHEMA_SET
72+
73+
6674
__all__ = ("BaseConnector", "TCPConnector", "UnixConnector", "NamedPipeConnector")
6775

6876

@@ -211,6 +219,8 @@ class BaseConnector:
211219
# abort transport after 2 seconds (cleanup broken connections)
212220
_cleanup_closed_period = 2.0
213221

222+
allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET
223+
214224
def __init__(
215225
self,
216226
*,
@@ -760,6 +770,8 @@ class TCPConnector(BaseConnector):
760770
loop - Optional event loop.
761771
"""
762772

773+
allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"tcp"})
774+
763775
def __init__(
764776
self,
765777
*,
@@ -1458,6 +1470,8 @@ class UnixConnector(BaseConnector):
14581470
loop - Optional event loop.
14591471
"""
14601472

1473+
allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"unix"})
1474+
14611475
def __init__(
14621476
self,
14631477
path: str,
@@ -1514,6 +1528,8 @@ class NamedPipeConnector(BaseConnector):
15141528
loop - Optional event loop.
15151529
"""
15161530

1531+
allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"npipe"})
1532+
15171533
def __init__(
15181534
self,
15191535
path: str,

tests/test_client_session.py

+70-5
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import io
55
import json
66
from http.cookies import SimpleCookie
7-
from typing import Any, List
7+
from typing import Any, Awaitable, Callable, List
88
from unittest import mock
99
from uuid import uuid4
1010

@@ -16,10 +16,12 @@
1616
import aiohttp
1717
from aiohttp import client, hdrs, web
1818
from aiohttp.client import ClientSession
19+
from aiohttp.client_proto import ResponseHandler
1920
from aiohttp.client_reqrep import ClientRequest
20-
from aiohttp.connector import BaseConnector, TCPConnector
21+
from aiohttp.connector import BaseConnector, Connection, TCPConnector, UnixConnector
2122
from aiohttp.helpers import DEBUG
2223
from aiohttp.test_utils import make_mocked_coro
24+
from aiohttp.tracing import Trace
2325

2426

2527
@pytest.fixture
@@ -487,15 +489,17 @@ async def test_ws_connect_allowed_protocols(
487489
hdrs.CONNECTION: "upgrade",
488490
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
489491
}
490-
resp.url = URL(f"{protocol}://example.com")
492+
resp.url = URL(f"{protocol}://example")
491493
resp.cookies = SimpleCookie()
492494
resp.start = mock.AsyncMock()
493495

494496
req = mock.create_autospec(aiohttp.ClientRequest, spec_set=True)
495497
req_factory = mock.Mock(return_value=req)
496498
req.send = mock.AsyncMock(return_value=resp)
499+
# BaseConnector allows all high level protocols by default
500+
connector = BaseConnector()
497501

498-
session = await create_session(request_class=req_factory)
502+
session = await create_session(connector=connector, request_class=req_factory)
499503

500504
connections = []
501505
original_connect = session._connector.connect
@@ -515,7 +519,68 @@ async def create_connection(req, traces, timeout):
515519
"aiohttp.client.os"
516520
) as m_os:
517521
m_os.urandom.return_value = key_data
518-
await session.ws_connect(f"{protocol}://example.com")
522+
await session.ws_connect(f"{protocol}://example")
523+
524+
# normally called during garbage collection. triggers an exception
525+
# if the connection wasn't already closed
526+
for c in connections:
527+
c.close()
528+
c.__del__()
529+
530+
await session.close()
531+
532+
533+
@pytest.mark.parametrize("protocol", ["http", "https", "ws", "wss", "unix"])
534+
async def test_ws_connect_unix_socket_allowed_protocols(
535+
create_session: Callable[..., Awaitable[ClientSession]],
536+
create_mocked_conn: Callable[[], ResponseHandler],
537+
protocol: str,
538+
ws_key: bytes,
539+
key_data: bytes,
540+
) -> None:
541+
resp = mock.create_autospec(aiohttp.ClientResponse)
542+
resp.status = 101
543+
resp.headers = {
544+
hdrs.UPGRADE: "websocket",
545+
hdrs.CONNECTION: "upgrade",
546+
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
547+
}
548+
resp.url = URL(f"{protocol}://example")
549+
resp.cookies = SimpleCookie()
550+
resp.start = mock.AsyncMock()
551+
552+
req = mock.create_autospec(aiohttp.ClientRequest, spec_set=True)
553+
req_factory = mock.Mock(return_value=req)
554+
req.send = mock.AsyncMock(return_value=resp)
555+
# UnixConnector allows all high level protocols by default and unix sockets
556+
session = await create_session(
557+
connector=UnixConnector(path=""), request_class=req_factory
558+
)
559+
560+
connections = []
561+
assert session._connector is not None
562+
original_connect = session._connector.connect
563+
564+
async def connect(
565+
req: ClientRequest, traces: List[Trace], timeout: aiohttp.ClientTimeout
566+
) -> Connection:
567+
conn = await original_connect(req, traces, timeout)
568+
connections.append(conn)
569+
return conn
570+
571+
async def create_connection(
572+
req: object, traces: object, timeout: object
573+
) -> ResponseHandler:
574+
return create_mocked_conn()
575+
576+
connector = session._connector
577+
with mock.patch.object(connector, "connect", connect), mock.patch.object(
578+
connector, "_create_connection", create_connection
579+
), mock.patch.object(connector, "_release"), mock.patch(
580+
"aiohttp.client.os"
581+
) as m_os:
582+
m_os.urandom.return_value = key_data
583+
await session.ws_connect(f"{protocol}://example")
519584

520585
# normally called during garbage collection. triggers an exception
521586
# if the connection wasn't already closed

tests/test_connector.py

+46-3
Original file line numberDiff line numberDiff line change
@@ -1481,7 +1481,19 @@ async def test_tcp_connector_ctor() -> None:
14811481
assert conn.family == 0
14821482

14831483

1484-
async def test_tcp_connector_ctor_fingerprint_valid(loop) -> None:
1484+
async def test_tcp_connector_allowed_protocols(loop: asyncio.AbstractEventLoop) -> None:
1485+
conn = aiohttp.TCPConnector()
1486+
assert conn.allowed_protocol_schema_set == {"", "tcp", "http", "https", "ws", "wss"}
1487+
1488+
1489+
async def test_invalid_ssl_param() -> None:
1490+
with pytest.raises(TypeError):
1491+
aiohttp.TCPConnector(ssl=object()) # type: ignore[arg-type]
1492+
1493+
1494+
async def test_tcp_connector_ctor_fingerprint_valid(
1495+
loop: asyncio.AbstractEventLoop,
1496+
) -> None:
14851497
valid = aiohttp.Fingerprint(hashlib.sha256(b"foo").digest())
14861498
conn = aiohttp.TCPConnector(ssl=valid, loop=loop)
14871499
assert conn._ssl is valid
@@ -1639,8 +1651,23 @@ async def test_ctor_with_default_loop(loop) -> None:
16391651
assert loop is conn._loop
16401652

16411653

1642-
async def test_connect_with_limit(loop, key) -> None:
1643-
proto = mock.Mock()
1654+
async def test_base_connector_allows_high_level_protocols(
1655+
loop: asyncio.AbstractEventLoop,
1656+
) -> None:
1657+
conn = aiohttp.BaseConnector()
1658+
assert conn.allowed_protocol_schema_set == {
1659+
"",
1660+
"http",
1661+
"https",
1662+
"ws",
1663+
"wss",
1664+
}
1665+
1666+
1667+
async def test_connect_with_limit(
1668+
loop: asyncio.AbstractEventLoop, key: ConnectionKey
1669+
) -> None:
1670+
proto = create_mocked_conn(loop)
16441671
proto.is_connected.return_value = True
16451672

16461673
req = ClientRequest(
@@ -2412,6 +2439,14 @@ async def handler(request):
24122439

24132440
connector = aiohttp.UnixConnector(unix_sockname)
24142441
assert unix_sockname == connector.path
2442+
assert connector.allowed_protocol_schema_set == {
2443+
"",
2444+
"http",
2445+
"https",
2446+
"ws",
2447+
"wss",
2448+
"unix",
2449+
}
24152450

24162451
session = client.ClientSession(connector=connector)
24172452
r = await session.get(url)
@@ -2437,6 +2472,14 @@ async def handler(request):
24372472

24382473
connector = aiohttp.NamedPipeConnector(pipe_name)
24392474
assert pipe_name == connector.path
2475+
assert connector.allowed_protocol_schema_set == {
2476+
"",
2477+
"http",
2478+
"https",
2479+
"ws",
2480+
"wss",
2481+
"npipe",
2482+
}
24402483

24412484
session = client.ClientSession(connector=connector)
24422485
r = await session.get(url)

0 commit comments

Comments
 (0)