Skip to content

Commit 6e366e1

Browse files
committed
Handle type guards properly in Receiver.filter()
Now the `Receiver` type returned by `Receiver.filter()` will have the narrowed type when a `TypeGuard` is used. Signed-off-by: Leandro Lucarella <[email protected]>
1 parent 486d1be commit 6e366e1

File tree

3 files changed

+87
-1
lines changed

3 files changed

+87
-1
lines changed

RELEASE_NOTES.md

+2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
print('Received from recv2:', selected.message)
2121
```
2222

23+
* `Receiver.filter()` can now properly handle `TypeGuard`s. The resulting receiver will now have the narrowed type when a `TypeGuard` is used.
24+
2325
## Bug Fixes
2426

2527
<!-- Here goes notable bug fixes that are worth a special mention or explanation -->

src/frequenz/channels/_receiver.py

+59-1
Original file line numberDiff line numberDiff line change
@@ -155,14 +155,17 @@
155155

156156
from abc import ABC, abstractmethod
157157
from collections.abc import Callable
158-
from typing import TYPE_CHECKING, Any, Generic, Self, TypeGuard
158+
from typing import TYPE_CHECKING, Any, Generic, Self, TypeGuard, TypeVar, overload
159159

160160
from ._exceptions import Error
161161
from ._generic import MappedMessageT_co, ReceiverMessageT_co
162162

163163
if TYPE_CHECKING:
164164
from ._select import Selected
165165

166+
FilteredMessageT_co = TypeVar("FilteredMessageT_co", covariant=True)
167+
"""Type variable for the filtered message type."""
168+
166169

167170
class Receiver(ABC, Generic[ReceiverMessageT_co]):
168171
"""An endpoint to receive messages."""
@@ -267,11 +270,66 @@ def map(
267270
"""
268271
return _Mapper(receiver=self, mapping_function=mapping_function)
269272

273+
@overload
274+
def filter(
275+
self,
276+
filter_function: Callable[
277+
[ReceiverMessageT_co], TypeGuard[FilteredMessageT_co]
278+
],
279+
/,
280+
) -> Receiver[FilteredMessageT_co]:
281+
"""Apply a type guard on the messages on a receiver.
282+
283+
Tip:
284+
The returned receiver type won't have all the methods of the original
285+
receiver. If you need to access methods of the original receiver that are
286+
not part of the `Receiver` interface you should save a reference to the
287+
original receiver and use that instead.
288+
289+
Args:
290+
filter_function: The function to be applied on incoming messages to
291+
determine if they should be received.
292+
293+
Returns:
294+
A new receiver that only receives messages that pass the filter.
295+
"""
296+
... # pylint: disable=unnecessary-ellipsis
297+
298+
@overload
270299
def filter(
271300
self, filter_function: Callable[[ReceiverMessageT_co], bool], /
272301
) -> Receiver[ReceiverMessageT_co]:
273302
"""Apply a filter function on the messages on a receiver.
274303
304+
Tip:
305+
The returned receiver type won't have all the methods of the original
306+
receiver. If you need to access methods of the original receiver that are
307+
not part of the `Receiver` interface you should save a reference to the
308+
original receiver and use that instead.
309+
310+
Args:
311+
filter_function: The function to be applied on incoming messages to
312+
determine if they should be received.
313+
314+
Returns:
315+
A new receiver that only receives messages that pass the filter.
316+
"""
317+
... # pylint: disable=unnecessary-ellipsis
318+
319+
def filter(
320+
self,
321+
filter_function: (
322+
Callable[[ReceiverMessageT_co], bool]
323+
| Callable[[ReceiverMessageT_co], TypeGuard[FilteredMessageT_co]]
324+
),
325+
/,
326+
) -> Receiver[ReceiverMessageT_co] | Receiver[FilteredMessageT_co]:
327+
"""Apply a filter function on the messages on a receiver.
328+
329+
Note:
330+
You can pass a [type guard][typing.TypeGuard] as the filter function to
331+
narrow the type of the messages that pass the filter.
332+
275333
Tip:
276334
The returned receiver type won't have all the methods of the original
277335
receiver. If you need to access methods of the original receiver that are

tests/test_broadcast.py

+26
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import asyncio
88
from dataclasses import dataclass
9+
from typing import TypeGuard, assert_never
910

1011
import pytest
1112

@@ -248,6 +249,31 @@ async def test_broadcast_filter() -> None:
248249
assert (await receiver.receive()) == 15
249250

250251

252+
async def test_broadcast_filter_type_guard() -> None:
253+
"""Ensure filter type guard works."""
254+
chan = Broadcast[int | str](name="input-chan")
255+
sender = chan.new_sender()
256+
257+
def _is_int(num: int | str) -> TypeGuard[int]:
258+
return isinstance(num, int)
259+
260+
# filter out objects that are not integers.
261+
receiver = chan.new_receiver().filter(_is_int)
262+
263+
await sender.send("hello")
264+
await sender.send(8)
265+
266+
message = await receiver.receive()
267+
assert message == 8
268+
is_int = False
269+
match message:
270+
case int():
271+
is_int = True
272+
case unexpected:
273+
assert_never(unexpected)
274+
assert is_int
275+
276+
251277
async def test_broadcast_receiver_drop() -> None:
252278
"""Ensure deleted receivers get cleaned up."""
253279
chan = Broadcast[int](name="input-chan")

0 commit comments

Comments
 (0)