Skip to content

Commit 7f93337

Browse files
authored
Merge branch 'master' into cheaper-no-other-refs
2 parents 2349b9c + 31ce0a5 commit 7f93337

File tree

6 files changed

+242
-6
lines changed

6 files changed

+242
-6
lines changed

docs/api.rst

+1
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ Testing and debugging
200200
---------------------
201201

202202
.. autoclass:: anyio.TaskInfo
203+
.. autoclass:: anyio.pytest_plugin.FreePortFactory
203204
.. autofunction:: anyio.get_current_task
204205
.. autofunction:: anyio.get_running_tasks
205206
.. autofunction:: anyio.wait_all_tasks_blocked

docs/conf.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
exclude_patterns = ["_build"]
2929
pygments_style = "sphinx"
3030
autodoc_default_options = {"members": True, "show-inheritance": True}
31-
autodoc_mock_imports = ["_typeshed"]
31+
autodoc_mock_imports = ["_typeshed", "pytest", "_pytest"]
3232
todo_include_todos = False
3333

3434
html_theme = "sphinx_rtd_theme"

docs/testing.rst

+77
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,83 @@ scoped::
140140
yield
141141
await server.shutdown()
142142

143+
Built-in utility fixtures
144+
-------------------------
145+
146+
Some useful pytest fixtures are provided to make testing network services easier:
147+
148+
* ``free_tcp_port_factory``: session scoped fixture returning a callable
149+
(:class:`~pytest_plugin.FreePortFactory`) that generates unused TCP port numbers
150+
* ``free_udp_port_factory``: session scoped fixture returning a callable
151+
(:class:`~pytest_plugin.FreePortFactory`) that generates unused UDP port numbers
152+
* ``free_tcp_port``: function level fixture that invokes the ``free_tcp_port_factory``
153+
fixture to generate a free TCP port number
154+
* ``free_udp_port``: function level fixture that invokes the ``free_udp_port_factory``
155+
fixture to generate a free UDP port number
156+
157+
The use of these fixtures, in place of hard-coded ports numbers, will avoid errors due
158+
to a port already being allocated. In particular, they are a must for running multiple
159+
instances of the same test suite concurrently, either via ``pytest-xdist`` or ``tox`` or
160+
similar tools which can run the test suite in multiple interpreters in parallel.
161+
162+
For example, you could set up a network listener in an ephemeral port and then connect
163+
to it::
164+
165+
from anyio import connect_tcp, create_task_group, create_tcp_listener
166+
from anyio.abc import SocketStream
167+
168+
169+
async def test_echo(free_tcp_port: int) -> None:
170+
async def handle(client_stream: SocketStream) -> None:
171+
async with client_stream:
172+
payload = await client_stream.receive()
173+
await client_stream.send(payload[::-1])
174+
175+
async with (
176+
await create_tcp_listener(local_port=free_tcp_port) as listener,
177+
create_task_group() as tg
178+
):
179+
tg.start_soon(listener.serve, handle)
180+
181+
async with await connect_tcp("127.0.0.1", free_tcp_port) as stream:
182+
await stream.send(b"hello")
183+
assert await stream.receive() == b"olleh"
184+
185+
tg.cancel_scope.cancel()
186+
187+
.. warning:::: It is possible in rare cases, particularly in local development, that
188+
another process could bind to the port returned by one of these fixtures before your
189+
code can do the same, leading to an :exc:`OSError` with the ``EADDRINUSE`` code. It
190+
is advisable to just rerun the test if this happens.
191+
192+
This is mostly useful with APIs that don't natively offer any way to bind to ephemeral
193+
ports (and retrieve those ports after binding). If you're working with AnyIO's own APIs,
194+
however, you could make use of this native capability::
195+
196+
from anyio import connect_tcp, create_task_group, create_tcp_listener
197+
from anyio.abc import SocketAttribute, SocketStream
198+
199+
async def test_echo() -> None:
200+
async def handle(client_stream: SocketStream) -> None:
201+
async with client_stream:
202+
payload = await client_stream.receive()
203+
await client_stream.send(payload[::-1])
204+
205+
async with (
206+
await create_tcp_listener(local_host="127.0.0.1") as listener,
207+
create_task_group() as tg
208+
):
209+
tg.start_soon(listener.serve, handle)
210+
port = listener.extra(SocketAttribute.local_port)
211+
212+
async with await connect_tcp("127.0.0.1", port) as stream:
213+
await stream.send(b"hello")
214+
assert await stream.receive() == b"olleh"
215+
216+
tg.cancel_scope.cancel()
217+
218+
.. versionadded:: 4.9.0
219+
143220
Technical details
144221
-----------------
145222

docs/versionhistory.rst

+10
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,16 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
55

66
**UNRELEASED**
77

8+
- Added 4 new fixtures for the AnyIO ``pytest`` plugin:
9+
10+
* ``free_tcp_port_factory``: session scoped fixture returning a callable that
11+
generates unused TCP port numbers
12+
* ``free_udp_port_factory``: session scoped fixture returning a callable that
13+
generates unused UDP port numbers
14+
* ``free_tcp_port``: function scoped fixture that invokes the
15+
``free_tcp_port_factory`` fixture to generate a free TCP port number
16+
* ``free_udp_port``: function scoped fixture that invokes the
17+
``free_udp_port_factory`` fixture to generate a free UDP port number
818
- Added ``stdin`` argument to ``anyio.run_process()`` akin to what
919
``anyio.open_process()``, ``asyncio.create_subprocess_…()``, ``trio.run_process()``,
1020
and ``subprocess.run()`` already accept (PR by @jmehnle)

src/anyio/pytest_plugin.py

+83-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from __future__ import annotations
22

3+
import socket
34
import sys
4-
from collections.abc import Generator, Iterator
5+
from collections.abc import Callable, Generator, Iterator
56
from contextlib import ExitStack, contextmanager
67
from inspect import isasyncgenfunction, iscoroutinefunction, ismethod
78
from typing import Any, cast
@@ -188,3 +189,84 @@ def anyio_backend_options(anyio_backend: Any) -> dict[str, Any]:
188189
return {}
189190
else:
190191
return anyio_backend[1]
192+
193+
194+
class FreePortFactory:
195+
"""
196+
Manages port generation based on specified socket kind, ensuring no duplicate
197+
ports are generated.
198+
199+
This class provides functionality for generating available free ports on the
200+
system. It is initialized with a specific socket kind and can generate ports
201+
for given address families while avoiding reuse of previously generated ports.
202+
203+
Users should not instantiate this class directly, but use the
204+
``free_tcp_port_factory`` and ``free_udp_port_factory`` fixtures instead. For simple
205+
uses cases, ``free_tcp_port`` and ``free_udp_port`` can be used instead.
206+
"""
207+
208+
def __init__(self, kind: socket.SocketKind) -> None:
209+
self._kind = kind
210+
self._generated = set[int]()
211+
212+
@property
213+
def kind(self) -> socket.SocketKind:
214+
"""
215+
The type of socket connection (e.g., :data:`~socket.SOCK_STREAM` or
216+
:data:`~socket.SOCK_DGRAM`) used to bind for checking port availability
217+
218+
"""
219+
return self._kind
220+
221+
def __call__(self, family: socket.AddressFamily | None = None) -> int:
222+
"""
223+
Return an unbound port for the given address family.
224+
225+
:param family: if omitted, both IPv4 and IPv6 addresses will be tried
226+
:return: a port number
227+
228+
"""
229+
if family is not None:
230+
families = [family]
231+
else:
232+
families = [socket.AF_INET]
233+
if socket.has_ipv6:
234+
families.append(socket.AF_INET6)
235+
236+
while True:
237+
port = 0
238+
with ExitStack() as stack:
239+
for family in families:
240+
sock = stack.enter_context(socket.socket(family, self._kind))
241+
addr = "::1" if family == socket.AF_INET6 else "127.0.0.1"
242+
try:
243+
sock.bind((addr, port))
244+
except OSError:
245+
break
246+
247+
if not port:
248+
port = sock.getsockname()[1]
249+
else:
250+
if port not in self._generated:
251+
self._generated.add(port)
252+
return port
253+
254+
255+
@pytest.fixture(scope="session")
256+
def free_tcp_port_factory() -> FreePortFactory:
257+
return FreePortFactory(socket.SOCK_STREAM)
258+
259+
260+
@pytest.fixture(scope="session")
261+
def free_udp_port_factory() -> FreePortFactory:
262+
return FreePortFactory(socket.SOCK_DGRAM)
263+
264+
265+
@pytest.fixture
266+
def free_tcp_port(free_tcp_port_factory: Callable[[], int]) -> int:
267+
return free_tcp_port_factory()
268+
269+
270+
@pytest.fixture
271+
def free_udp_port(free_udp_port_factory: Callable[[], int]) -> int:
272+
return free_udp_port_factory()

tests/test_pytest_plugin.py

+70-4
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,22 @@
11
from __future__ import annotations
22

3+
import socket
4+
from collections.abc import Sequence
5+
36
import pytest
47
from _pytest.logging import LogCaptureFixture
58
from _pytest.pytester import Pytester
69

710
from anyio import get_all_backends
11+
from anyio.pytest_plugin import FreePortFactory
812

9-
pytestmark = pytest.mark.filterwarnings(
10-
"ignore:The TerminalReporter.writer attribute is deprecated"
11-
":pytest.PytestDeprecationWarning:"
12-
)
13+
pytestmark = [
14+
pytest.mark.filterwarnings(
15+
"ignore:The TerminalReporter.writer attribute is deprecated"
16+
":pytest.PytestDeprecationWarning:"
17+
),
18+
pytest.mark.anyio,
19+
]
1320

1421
pytest_args = "-v", "-p", "anyio", "-p", "no:asyncio", "-p", "no:trio"
1522

@@ -561,3 +568,62 @@ async def test_params(fixt):
561568

562569
result = testdir.runpytest(*pytest_args)
563570
result.assert_outcomes(passed=len(get_all_backends()) * 2)
571+
572+
573+
class TestFreePortFactory:
574+
@pytest.fixture(scope="class")
575+
def families(self) -> Sequence[tuple[socket.AddressFamily, str]]:
576+
from .test_sockets import has_ipv6
577+
578+
families: list[tuple[socket.AddressFamily, str]] = [
579+
(socket.AF_INET, "127.0.0.1")
580+
]
581+
if has_ipv6:
582+
families.append((socket.AF_INET6, "::1"))
583+
584+
return families
585+
586+
async def test_tcp_factory(
587+
self,
588+
families: Sequence[tuple[socket.AddressFamily, str]],
589+
free_tcp_port_factory: FreePortFactory,
590+
) -> None:
591+
generated_ports = {free_tcp_port_factory() for _ in range(5)}
592+
assert all(isinstance(port, int) for port in generated_ports)
593+
assert len(generated_ports) == 5
594+
for port in generated_ports:
595+
for family, addr in families:
596+
with socket.socket(family, socket.SOCK_STREAM) as sock:
597+
try:
598+
sock.bind((addr, port))
599+
except OSError:
600+
pass
601+
602+
async def test_udp_factory(
603+
self,
604+
families: Sequence[tuple[socket.AddressFamily, str]],
605+
free_udp_port_factory: FreePortFactory,
606+
) -> None:
607+
generated_ports = {free_udp_port_factory() for _ in range(5)}
608+
assert all(isinstance(port, int) for port in generated_ports)
609+
assert len(generated_ports) == 5
610+
for port in generated_ports:
611+
for family, addr in families:
612+
with socket.socket(family, socket.SOCK_DGRAM) as sock:
613+
sock.bind((addr, port))
614+
615+
async def test_free_tcp_port(
616+
self, families: Sequence[tuple[socket.AddressFamily, str]], free_tcp_port: int
617+
) -> None:
618+
assert isinstance(free_tcp_port, int)
619+
for family, addr in families:
620+
with socket.socket(family, socket.SOCK_STREAM) as sock:
621+
sock.bind((addr, free_tcp_port))
622+
623+
async def test_free_udp_port(
624+
self, families: Sequence[tuple[socket.AddressFamily, str]], free_udp_port: int
625+
) -> None:
626+
assert isinstance(free_udp_port, int)
627+
for family, addr in families:
628+
with socket.socket(family, socket.SOCK_DGRAM) as sock:
629+
sock.bind((addr, free_udp_port))

0 commit comments

Comments
 (0)