Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit eb609c6

Browse files
authored
Fix bug in StateFilter.return_expanded() and add some tests. (#12016)
1 parent 31a298f commit eb609c6

File tree

3 files changed

+117
-1
lines changed

3 files changed

+117
-1
lines changed

changelog.d/12016.misc

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix bug in `StateFilter.return_expanded()` and add some tests.

synapse/storage/state.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -204,13 +204,16 @@ def return_expanded(self) -> "StateFilter":
204204
if get_all_members:
205205
# We want to return everything.
206206
return StateFilter.all()
207-
else:
207+
elif EventTypes.Member in self.types:
208208
# We want to return all non-members, but only particular
209209
# memberships
210210
return StateFilter(
211211
types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}),
212212
include_others=True,
213213
)
214+
else:
215+
# We want to return all non-members
216+
return _ALL_NON_MEMBER_STATE_FILTER
214217

215218
def make_sql_filter_clause(self) -> Tuple[str, List[str]]:
216219
"""Converts the filter to an SQL clause.
@@ -528,6 +531,9 @@ def approx_difference(self, other: "StateFilter") -> "StateFilter":
528531

529532

530533
_ALL_STATE_FILTER = StateFilter(types=frozendict(), include_others=True)
534+
_ALL_NON_MEMBER_STATE_FILTER = StateFilter(
535+
types=frozendict({EventTypes.Member: frozenset()}), include_others=True
536+
)
531537
_NONE_STATE_FILTER = StateFilter(types=frozendict(), include_others=False)
532538

533539

tests/storage/test_state.py

+109
Original file line numberDiff line numberDiff line change
@@ -992,3 +992,112 @@ def test_state_filter_difference_simple_cases(self):
992992
StateFilter.none(),
993993
StateFilter.all(),
994994
)
995+
996+
997+
class StateFilterTestCase(TestCase):
998+
def test_return_expanded(self):
999+
"""
1000+
Tests the behaviour of the return_expanded() function that expands
1001+
StateFilters to include more state types (for the sake of cache hit rate).
1002+
"""
1003+
1004+
self.assertEqual(StateFilter.all().return_expanded(), StateFilter.all())
1005+
1006+
self.assertEqual(StateFilter.none().return_expanded(), StateFilter.none())
1007+
1008+
# Concrete-only state filters stay the same
1009+
# (Case: mixed filter)
1010+
self.assertEqual(
1011+
StateFilter.freeze(
1012+
{
1013+
EventTypes.Member: {"@wombat:test", "@alicia:test"},
1014+
"some.other.state.type": {""},
1015+
},
1016+
include_others=False,
1017+
).return_expanded(),
1018+
StateFilter.freeze(
1019+
{
1020+
EventTypes.Member: {"@wombat:test", "@alicia:test"},
1021+
"some.other.state.type": {""},
1022+
},
1023+
include_others=False,
1024+
),
1025+
)
1026+
1027+
# Concrete-only state filters stay the same
1028+
# (Case: non-member-only filter)
1029+
self.assertEqual(
1030+
StateFilter.freeze(
1031+
{"some.other.state.type": {""}}, include_others=False
1032+
).return_expanded(),
1033+
StateFilter.freeze({"some.other.state.type": {""}}, include_others=False),
1034+
)
1035+
1036+
# Concrete-only state filters stay the same
1037+
# (Case: member-only filter)
1038+
self.assertEqual(
1039+
StateFilter.freeze(
1040+
{
1041+
EventTypes.Member: {"@wombat:test", "@alicia:test"},
1042+
},
1043+
include_others=False,
1044+
).return_expanded(),
1045+
StateFilter.freeze(
1046+
{
1047+
EventTypes.Member: {"@wombat:test", "@alicia:test"},
1048+
},
1049+
include_others=False,
1050+
),
1051+
)
1052+
1053+
# Wildcard member-only state filters stay the same
1054+
self.assertEqual(
1055+
StateFilter.freeze(
1056+
{EventTypes.Member: None},
1057+
include_others=False,
1058+
).return_expanded(),
1059+
StateFilter.freeze(
1060+
{EventTypes.Member: None},
1061+
include_others=False,
1062+
),
1063+
)
1064+
1065+
# If there is a wildcard in the non-member portion of the filter,
1066+
# it's expanded to include ALL non-member events.
1067+
# (Case: mixed filter)
1068+
self.assertEqual(
1069+
StateFilter.freeze(
1070+
{
1071+
EventTypes.Member: {"@wombat:test", "@alicia:test"},
1072+
"some.other.state.type": None,
1073+
},
1074+
include_others=False,
1075+
).return_expanded(),
1076+
StateFilter.freeze(
1077+
{EventTypes.Member: {"@wombat:test", "@alicia:test"}},
1078+
include_others=True,
1079+
),
1080+
)
1081+
1082+
# If there is a wildcard in the non-member portion of the filter,
1083+
# it's expanded to include ALL non-member events.
1084+
# (Case: non-member-only filter)
1085+
self.assertEqual(
1086+
StateFilter.freeze(
1087+
{
1088+
"some.other.state.type": None,
1089+
},
1090+
include_others=False,
1091+
).return_expanded(),
1092+
StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
1093+
)
1094+
self.assertEqual(
1095+
StateFilter.freeze(
1096+
{
1097+
"some.other.state.type": None,
1098+
"yet.another.state.type": {"wombat"},
1099+
},
1100+
include_others=False,
1101+
).return_expanded(),
1102+
StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
1103+
)

0 commit comments

Comments
 (0)