Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit df4b1e9

Browse files
authored
Pass room_id to get_auth_chain_difference (#8879)
This is so that we can choose which algorithm to use based on the room ID.
1 parent b774c55 commit df4b1e9

File tree

6 files changed

+33
-17
lines changed

6 files changed

+33
-17
lines changed

changelog.d/8879.misc

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Pass `room_id` to `get_auth_chain_difference`.

synapse/state/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -783,7 +783,7 @@ def get_events(
783783
)
784784

785785
def get_auth_chain_difference(
786-
self, state_sets: List[Set[str]]
786+
self, room_id: str, state_sets: List[Set[str]]
787787
) -> Awaitable[Set[str]]:
788788
"""Given sets of state events figure out the auth chain difference (as
789789
per state res v2 algorithm).
@@ -796,4 +796,4 @@ def get_auth_chain_difference(
796796
An awaitable that resolves to a set of event IDs.
797797
"""
798798

799-
return self.store.get_auth_chain_difference(state_sets)
799+
return self.store.get_auth_chain_difference(room_id, state_sets)

synapse/state/v2.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,9 @@ async def resolve_events_with_store(
9797

9898
# Also fetch all auth events that appear in only some of the state sets'
9999
# auth chains.
100-
auth_diff = await _get_auth_chain_difference(state_sets, event_map, state_res_store)
100+
auth_diff = await _get_auth_chain_difference(
101+
room_id, state_sets, event_map, state_res_store
102+
)
101103

102104
full_conflicted_set = set(
103105
itertools.chain(
@@ -236,6 +238,7 @@ async def _get_power_level_for_sender(
236238

237239

238240
async def _get_auth_chain_difference(
241+
room_id: str,
239242
state_sets: Sequence[StateMap[str]],
240243
event_map: Dict[str, EventBase],
241244
state_res_store: "synapse.state.StateResolutionStore",
@@ -332,7 +335,9 @@ async def _get_auth_chain_difference(
332335
difference_from_event_map = ()
333336
state_sets_ids = [set(state_set.values()) for state_set in state_sets]
334337

335-
difference = await state_res_store.get_auth_chain_difference(state_sets_ids)
338+
difference = await state_res_store.get_auth_chain_difference(
339+
room_id, state_sets_ids
340+
)
336341
difference.update(difference_from_event_map)
337342

338343
return difference

synapse/storage/databases/main/event_federation.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,9 @@ def _get_auth_chain_ids_txn(
137137

138138
return list(results)
139139

140-
async def get_auth_chain_difference(self, state_sets: List[Set[str]]) -> Set[str]:
140+
async def get_auth_chain_difference(
141+
self, room_id: str, state_sets: List[Set[str]]
142+
) -> Set[str]:
141143
"""Given sets of state events figure out the auth chain difference (as
142144
per state res v2 algorithm).
143145

tests/state/test_v2.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -623,7 +623,9 @@ def test_simple(self):
623623

624624
store = TestStateResolutionStore(persisted_events)
625625

626-
diff_d = _get_auth_chain_difference(state_sets, unpersited_events, store)
626+
diff_d = _get_auth_chain_difference(
627+
ROOM_ID, state_sets, unpersited_events, store
628+
)
627629
difference = self.successResultOf(defer.ensureDeferred(diff_d))
628630

629631
self.assertEqual(difference, {c.event_id})
@@ -662,7 +664,9 @@ def test_multiple_unpersisted_chain(self):
662664

663665
store = TestStateResolutionStore(persisted_events)
664666

665-
diff_d = _get_auth_chain_difference(state_sets, unpersited_events, store)
667+
diff_d = _get_auth_chain_difference(
668+
ROOM_ID, state_sets, unpersited_events, store
669+
)
666670
difference = self.successResultOf(defer.ensureDeferred(diff_d))
667671

668672
self.assertEqual(difference, {d.event_id, c.event_id})
@@ -707,7 +711,9 @@ def test_unpersisted_events_different_sets(self):
707711

708712
store = TestStateResolutionStore(persisted_events)
709713

710-
diff_d = _get_auth_chain_difference(state_sets, unpersited_events, store)
714+
diff_d = _get_auth_chain_difference(
715+
ROOM_ID, state_sets, unpersited_events, store
716+
)
711717
difference = self.successResultOf(defer.ensureDeferred(diff_d))
712718

713719
self.assertEqual(difference, {d.event_id, e.event_id})
@@ -773,7 +779,7 @@ def _get_auth_chain(self, event_ids: List[str]) -> List[str]:
773779

774780
return list(result)
775781

776-
def get_auth_chain_difference(self, auth_sets):
782+
def get_auth_chain_difference(self, room_id, auth_sets):
777783
chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
778784

779785
common = set(chains[0]).intersection(*chains[1:])

tests/storage/test_event_federation.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -202,39 +202,41 @@ def insert_event(txn, event_id, stream_ordering):
202202
# Now actually test that various combinations give the right result:
203203

204204
difference = self.get_success(
205-
self.store.get_auth_chain_difference([{"a"}, {"b"}])
205+
self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}])
206206
)
207207
self.assertSetEqual(difference, {"a", "b"})
208208

209209
difference = self.get_success(
210-
self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}])
210+
self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}])
211211
)
212212
self.assertSetEqual(difference, {"a", "b", "c", "e", "f"})
213213

214214
difference = self.get_success(
215-
self.store.get_auth_chain_difference([{"a", "c"}, {"b"}])
215+
self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b"}])
216216
)
217217
self.assertSetEqual(difference, {"a", "b", "c"})
218218

219219
difference = self.get_success(
220-
self.store.get_auth_chain_difference([{"a", "c"}, {"b", "c"}])
220+
self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b", "c"}])
221221
)
222222
self.assertSetEqual(difference, {"a", "b"})
223223

224224
difference = self.get_success(
225-
self.store.get_auth_chain_difference([{"a"}, {"b"}, {"d"}])
225+
self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"d"}])
226226
)
227227
self.assertSetEqual(difference, {"a", "b", "d", "e"})
228228

229229
difference = self.get_success(
230-
self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}, {"d"}])
230+
self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}, {"d"}])
231231
)
232232
self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"})
233233

234234
difference = self.get_success(
235-
self.store.get_auth_chain_difference([{"a"}, {"b"}, {"e"}])
235+
self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"e"}])
236236
)
237237
self.assertSetEqual(difference, {"a", "b"})
238238

239-
difference = self.get_success(self.store.get_auth_chain_difference([{"a"}]))
239+
difference = self.get_success(
240+
self.store.get_auth_chain_difference(room_id, [{"a"}])
241+
)
240242
self.assertSetEqual(difference, set())

0 commit comments

Comments
 (0)