Skip to content

Commit a62ca9c

Browse files
committed
Add seq choose and tests
1 parent 2fd1efd commit a62ca9c

File tree

4 files changed

+79
-23
lines changed

4 files changed

+79
-23
lines changed

expression/collections/seq.py

+53-9
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,23 @@ def __init__(self, iterable: Iterable[TSource] = []) -> None:
5050
def filter(self, predicate: Callable[[TSource], bool]) -> "Seq[TSource]":
5151
return Seq(filter(predicate)(self))
5252

53+
def choose(self, chooser: Callable[[TSource], Option[TResult]]) -> "Seq[TResult]":
54+
"""Choose items from the sequence.
55+
56+
Applies the given function to each element of the list. Returns
57+
the list comprised of the results x for each element where the
58+
function returns `Some(x)`.
59+
60+
Args:
61+
chooser: The function to generate options from the elements.
62+
63+
Returns:
64+
The list comprising the values selected from the chooser
65+
function.
66+
"""
67+
68+
return Seq(pipe(self, choose(chooser)))
69+
5370
def collect(self, mapping: Callable[[TSource], "Seq[TResult]"]) -> "Seq[TResult]":
5471
return Seq(collect(mapping)(self))
5572

@@ -91,16 +108,16 @@ def map(self, mapper: Callable[[TSource], TResult]) -> "Seq[TResult]":
91108
return Seq(map(mapper)(self))
92109

93110
@overload
94-
def match(self) -> "Case":
111+
def match(self) -> "Case[Iterable[TSource]]":
95112
...
96113

97114
@overload
98115
def match(self, pattern: Any) -> Iterable[Iterable[TSource]]:
99116
...
100117

101118
def match(self, pattern: Any) -> Any:
102-
m = Case(self)
103-
return m.case(pattern) if pattern else m
119+
case: Case[Iterable[TSource]] = Case(self)
120+
return case(pattern) if pattern else case
104121

105122
@overload
106123
def pipe(self, __fn1: Callable[["Seq[TSource]"], TResult]) -> TResult:
@@ -193,6 +210,30 @@ def __iter__(self) -> Iterator[TSource]:
193210
return builtins.iter(self._value)
194211

195212

213+
def choose(chooser: Callable[[TSource], Option[TResult]]) -> Callable[[Iterable[TSource]], Iterable[TResult]]:
214+
"""Choose items from the sequence.
215+
216+
Applies the given function to each element of the list. Returns
217+
the list comprised of the results x for each element where the
218+
function returns `Some(x)`.
219+
220+
Args:
221+
chooser: The function to generate options from the elements.
222+
223+
Returns:
224+
The list comprising the values selected from the chooser
225+
function.
226+
"""
227+
228+
def _choose(source: Iterable[TSource]) -> Iterable[TResult]:
229+
def mapper(x: TSource) -> Iterable[TResult]:
230+
return chooser(x).to_seq()
231+
232+
return pipe(source, collect(mapper))
233+
234+
return _choose
235+
236+
196237
def collect(mapping: Callable[[TSource], Iterable[TResult]]) -> Callable[[Iterable[TSource]], Iterable[TResult]]:
197238
def _collect(source: Iterable[TSource]) -> Iterable[TResult]:
198239
return (x for xs in source for x in mapping(xs))
@@ -437,19 +478,21 @@ def _min_by(source: Iterable[TSource]) -> TSupportsLessThan:
437478
return _min_by
438479

439480

440-
def of(value: Iterable[TSource]) -> Seq[TSource]:
481+
def of(*args: TSource) -> Seq[TSource]:
441482
"""Create sequence from iterable.
442483
443484
Enables fluent dot chaining on the created sequence object.
444485
"""
445-
return Seq(value)
486+
return Seq(args)
487+
446488

489+
def of_iterable(source: Iterable[TSource]) -> Seq[TSource]:
490+
"""Alias to `Seq.of`."""
491+
return Seq(source)
447492

448-
of_list = of
449-
"""Alias to `Seq.of`."""
450493

451-
of_iterable = of
452-
"""Alias to `Seq.of`."""
494+
of_list = of_iterable
495+
"""Alias to `seq.of_iterable`."""
453496

454497

455498
@overload
@@ -570,6 +613,7 @@ def _zip(source2: Iterable[TResult]) -> Iterable[Tuple[TSource, TResult]]:
570613

571614
__all__ = [
572615
"Seq",
616+
"choose",
573617
"concat",
574618
"collect",
575619
"empty",

tests/test_asyncseq.py

-4
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
1-
import asyncio
2-
from typing import Callable, Generator, Iterable, List
3-
41
import pytest
52
from expression.collections.asyncseq import AsyncSeq
6-
from expression.core import pipe
73
from hypothesis import given
84
from hypothesis import strategies as st
95

tests/test_result.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22

33
import pytest
44
from expression import effect
5-
from expression.core import Error, Ok, Result, Try, match, result
6-
from expression.core.result_try import Failure, Success
5+
from expression.core import Error, Failure, Ok, Result, Success, Try, match, result
76
from expression.extra.result import pipeline, sequence
87
from hypothesis import given
98
from hypothesis import strategies as st

tests/test_seq.py

+25-8
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import functools
22
from itertools import accumulate
3-
from typing import Callable, Generator, Iterable, List, Tuple
3+
from typing import Callable, Generator, Iterable, List, Optional, Tuple
44

55
import pytest
66
from expression import effect
77
from expression.collections import Seq, seq
8-
from expression.core import Nothing, Option, Some, pipe
8+
from expression.core import Nothing, Option, Some, option, pipe
99
from hypothesis import given
1010
from hypothesis import strategies as st
1111

@@ -80,22 +80,22 @@ def test_seq_head_empty_source():
8080

8181
@given(st.lists(st.integers(), min_size=1))
8282
def test_seq_head_fluent(xs: List[int]):
83-
value = seq.of(xs).head()
83+
value = seq.of_iterable(xs).head()
8484

8585
assert value == xs[0]
8686

8787

8888
@given(st.lists(st.integers(), min_size=1), st.integers())
8989
def test_seq_fold_pipe(xs: List[int], s: int):
9090
folder: Callable[[int, int], int] = lambda s, v: s + v
91-
value = pipe(seq.of(xs), seq.fold(folder, s))
91+
value = pipe(seq.of_iterable(xs), seq.fold(folder, s))
9292

9393
assert value == sum(xs) + s
9494

9595

9696
@given(st.lists(st.integers(), min_size=1), st.integers())
9797
def test_seq_fold_fluent(xs: List[int], s: int):
98-
value = seq.of(xs).fold(lambda s, v: s + v, s)
98+
value = seq.of_iterable(xs).fold(lambda s, v: s + v, s)
9999

100100
assert value == sum(xs) + s
101101

@@ -115,15 +115,15 @@ def unfolder(state: int) -> Option[Tuple[int, int]]:
115115
@given(st.lists(st.integers(), min_size=1), st.integers())
116116
def test_seq_scan_pipe(xs: List[int], s: int):
117117
func: Callable[[int, int], int] = lambda s, v: s + v
118-
value = pipe(seq.of(xs), seq.scan(func, s))
118+
value = pipe(seq.of_iterable(xs), seq.scan(func, s))
119119

120120
assert list(value) == list(accumulate(xs, func, initial=s))
121121

122122

123123
@given(st.lists(st.integers(), min_size=1), st.integers())
124124
def test_seq_scan_fluent(xs: List[int], s: int):
125125
func: Callable[[int, int], int] = lambda s, v: s + v
126-
value = seq.of(xs).scan(func, s)
126+
value = seq.of_iterable(xs).scan(func, s)
127127

128128
assert list(value) == list(accumulate(xs, func, initial=s))
129129

@@ -158,14 +158,31 @@ def test_seq_collect(xs: List[int]):
158158

159159
@given(st.lists(st.integers()))
160160
def test_seq_pipeline(xs: List[int]):
161-
ys = seq.of(xs).pipe(
161+
ys = seq.of_iterable(xs).pipe(
162162
seq.map(lambda x: x * 10),
163163
seq.filter(lambda x: x > 100),
164164
seq.fold(lambda s, x: s + x, 0),
165165
)
166166
assert ys == functools.reduce(lambda s, x: s + x, filter(lambda x: x > 100, map(lambda x: x * 10, xs)), 0)
167167

168168

169+
def test_seq_choose_option():
170+
xs = seq.of(None, 42)
171+
172+
chooser = seq.choose(option.of_optional)
173+
ys = pipe(xs, chooser)
174+
175+
assert list(ys) == [42]
176+
177+
178+
def test_seq_choose_option_fluent():
179+
xs = seq.of(None, 42)
180+
181+
ys = xs.choose(option.of_optional)
182+
183+
assert list(ys) == [42]
184+
185+
169186
@given(st.lists(st.integers()))
170187
def test_seq_infinite(xs: List[int]):
171188
ys = pipe(xs, seq.zip(seq.infinite))

0 commit comments

Comments
 (0)