Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 07b1c70

Browse files
authored
Initial implementation of MSC3981: recursive relations API (#15315)
Adds an optional keyword argument to the /relations API which will recurse a limited number of event relationships. This will cause the API to return not just the events related to the parent event, but also events related to those related to the parent event, etc. This is disabled by default behind an experimental configuration flag and is currently implemented using prefixed parameters.
1 parent 3b853b1 commit 07b1c70

File tree

6 files changed

+186
-18
lines changed

6 files changed

+186
-18
lines changed

changelog.d/15315.feature

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Experimental support to recursively provide relations per [MSC3981](https://github.com/matrix-org/matrix-spec-proposals/pull/3981).

synapse/config/experimental.py

+5
Original file line numberDiff line numberDiff line change
@@ -192,5 +192,10 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
192192
# MSC2659: Application service ping endpoint
193193
self.msc2659_enabled = experimental.get("msc2659_enabled", False)
194194

195+
# MSC3981: Recurse relations
196+
self.msc3981_recurse_relations = experimental.get(
197+
"msc3981_recurse_relations", False
198+
)
199+
195200
# MSC3970: Scope transaction IDs to devices
196201
self.msc3970_enabled = experimental.get("msc3970_enabled", False)

synapse/handlers/relations.py

+3
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ async def get_relations(
8585
event_id: str,
8686
room_id: str,
8787
pagin_config: PaginationConfig,
88+
recurse: bool,
8889
include_original_event: bool,
8990
relation_type: Optional[str] = None,
9091
event_type: Optional[str] = None,
@@ -98,6 +99,7 @@ async def get_relations(
9899
event_id: Fetch events that relate to this event ID.
99100
room_id: The room the event belongs to.
100101
pagin_config: The pagination config rules to apply, if any.
102+
recurse: Whether to recursively find relations.
101103
include_original_event: Whether to include the parent event.
102104
relation_type: Only fetch events with this relation type, if given.
103105
event_type: Only fetch events with this event type, if given.
@@ -132,6 +134,7 @@ async def get_relations(
132134
direction=pagin_config.direction,
133135
from_token=pagin_config.from_token,
134136
to_token=pagin_config.to_token,
137+
recurse=recurse,
135138
)
136139

137140
events = await self._main_store.get_events_as_list(

synapse/rest/client/relations.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from synapse.api.constants import Direction
2020
from synapse.handlers.relations import ThreadsListInclude
2121
from synapse.http.server import HttpServer
22-
from synapse.http.servlet import RestServlet, parse_integer, parse_string
22+
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
2323
from synapse.http.site import SynapseRequest
2424
from synapse.rest.client._base import client_patterns
2525
from synapse.storage.databases.main.relations import ThreadsNextBatch
@@ -49,6 +49,7 @@ def __init__(self, hs: "HomeServer"):
4949
self.auth = hs.get_auth()
5050
self._store = hs.get_datastores().main
5151
self._relations_handler = hs.get_relations_handler()
52+
self._support_recurse = hs.config.experimental.msc3981_recurse_relations
5253

5354
async def on_GET(
5455
self,
@@ -63,6 +64,12 @@ async def on_GET(
6364
pagination_config = await PaginationConfig.from_request(
6465
self._store, request, default_limit=5, default_dir=Direction.BACKWARDS
6566
)
67+
if self._support_recurse:
68+
recurse = parse_boolean(
69+
request, "org.matrix.msc3981.recurse", default=False
70+
)
71+
else:
72+
recurse = False
6673

6774
# The unstable version of this API returns an extra field for client
6875
# compatibility, see https://github.com/matrix-org/synapse/issues/12930.
@@ -75,6 +82,7 @@ async def on_GET(
7582
event_id=parent_id,
7683
room_id=room_id,
7784
pagin_config=pagination_config,
85+
recurse=recurse,
7886
include_original_event=include_original_event,
7987
relation_type=relation_type,
8088
event_type=event_type,

synapse/storage/databases/main/relations.py

+48-17
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ async def get_relations_for_event(
172172
direction: Direction = Direction.BACKWARDS,
173173
from_token: Optional[StreamToken] = None,
174174
to_token: Optional[StreamToken] = None,
175+
recurse: bool = False,
175176
) -> Tuple[Sequence[_RelatedEvent], Optional[StreamToken]]:
176177
"""Get a list of relations for an event, ordered by topological ordering.
177178
@@ -186,6 +187,7 @@ async def get_relations_for_event(
186187
oldest first (forwards).
187188
from_token: Fetch rows from the given token, or from the start if None.
188189
to_token: Fetch rows up to the given token, or up to the end if None.
190+
recurse: Whether to recursively find relations.
189191
190192
Returns:
191193
A tuple of:
@@ -200,8 +202,8 @@ async def get_relations_for_event(
200202
# Ensure bad limits aren't being passed in.
201203
assert limit >= 0
202204

203-
where_clause = ["relates_to_id = ?", "room_id = ?"]
204-
where_args: List[Union[str, int]] = [event.event_id, room_id]
205+
where_clause = ["room_id = ?"]
206+
where_args: List[Union[str, int]] = [room_id]
205207
is_redacted = event.internal_metadata.is_redacted()
206208

207209
if relation_type is not None:
@@ -229,23 +231,52 @@ async def get_relations_for_event(
229231
if pagination_clause:
230232
where_clause.append(pagination_clause)
231233

232-
sql = """
233-
SELECT event_id, relation_type, sender, topological_ordering, stream_ordering
234-
FROM event_relations
235-
INNER JOIN events USING (event_id)
236-
WHERE %s
237-
ORDER BY topological_ordering %s, stream_ordering %s
238-
LIMIT ?
239-
""" % (
240-
" AND ".join(where_clause),
241-
order,
242-
order,
243-
)
234+
# If a recursive query is requested then the filters are applied after
235+
# recursively following relationships from the requested event to children
236+
# up to 3-relations deep.
237+
#
238+
# If no recursion is needed then the event_relations table is queried
239+
# for direct children of the requested event.
240+
if recurse:
241+
sql = """
242+
WITH RECURSIVE related_events AS (
243+
SELECT event_id, relation_type, relates_to_id, 0 AS depth
244+
FROM event_relations
245+
WHERE relates_to_id = ?
246+
UNION SELECT e.event_id, e.relation_type, e.relates_to_id, depth + 1
247+
FROM event_relations e
248+
INNER JOIN related_events r ON r.event_id = e.relates_to_id
249+
WHERE depth <= 3
250+
)
251+
SELECT event_id, relation_type, sender, topological_ordering, stream_ordering
252+
FROM related_events
253+
INNER JOIN events USING (event_id)
254+
WHERE %s
255+
ORDER BY topological_ordering %s, stream_ordering %s
256+
LIMIT ?;
257+
""" % (
258+
" AND ".join(where_clause),
259+
order,
260+
order,
261+
)
262+
else:
263+
sql = """
264+
SELECT event_id, relation_type, sender, topological_ordering, stream_ordering
265+
FROM event_relations
266+
INNER JOIN events USING (event_id)
267+
WHERE relates_to_id = ? AND %s
268+
ORDER BY topological_ordering %s, stream_ordering %s
269+
LIMIT ?
270+
""" % (
271+
" AND ".join(where_clause),
272+
order,
273+
order,
274+
)
244275

245276
def _get_recent_references_for_event_txn(
246277
txn: LoggingTransaction,
247278
) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
248-
txn.execute(sql, where_args + [limit + 1])
279+
txn.execute(sql, [event.event_id] + where_args + [limit + 1])
249280

250281
events = []
251282
topo_orderings: List[int] = []
@@ -965,7 +996,7 @@ async def get_thread_id(self, event_id: str) -> str:
965996
# relation.
966997
sql = """
967998
WITH RECURSIVE related_events AS (
968-
SELECT event_id, relates_to_id, relation_type, 0 depth
999+
SELECT event_id, relates_to_id, relation_type, 0 AS depth
9691000
FROM event_relations
9701001
WHERE event_id = ?
9711002
UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1
@@ -1025,7 +1056,7 @@ async def get_thread_id_for_receipts(self, event_id: str) -> str:
10251056
sql = """
10261057
SELECT relates_to_id FROM event_relations WHERE relates_to_id = COALESCE((
10271058
WITH RECURSIVE related_events AS (
1028-
SELECT event_id, relates_to_id, relation_type, 0 depth
1059+
SELECT event_id, relates_to_id, relation_type, 0 AS depth
10291060
FROM event_relations
10301061
WHERE event_id = ?
10311062
UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1

tests/rest/client/test_relations.py

+120
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from tests.server import FakeChannel
3131
from tests.test_utils import make_awaitable
3232
from tests.test_utils.event_injection import inject_event
33+
from tests.unittest import override_config
3334

3435

3536
class BaseRelationsTestCase(unittest.HomeserverTestCase):
@@ -949,6 +950,125 @@ def test_pagination_from_sync_and_messages(self) -> None:
949950
)
950951

951952

953+
class RecursiveRelationTestCase(BaseRelationsTestCase):
954+
@override_config({"experimental_features": {"msc3981_recurse_relations": True}})
955+
def test_recursive_relations(self) -> None:
956+
"""Generate a complex, multi-level relationship tree and query it."""
957+
# Create a thread with a few messages in it.
958+
channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
959+
thread_1 = channel.json_body["event_id"]
960+
961+
channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
962+
thread_2 = channel.json_body["event_id"]
963+
964+
# Add annotations.
965+
channel = self._send_relation(
966+
RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_2
967+
)
968+
annotation_1 = channel.json_body["event_id"]
969+
970+
channel = self._send_relation(
971+
RelationTypes.ANNOTATION, "m.reaction", "b", parent_id=thread_1
972+
)
973+
annotation_2 = channel.json_body["event_id"]
974+
975+
# Add a reference to part of the thread, then edit the reference and annotate it.
976+
channel = self._send_relation(
977+
RelationTypes.REFERENCE, "m.room.test", parent_id=thread_2
978+
)
979+
reference_1 = channel.json_body["event_id"]
980+
981+
channel = self._send_relation(
982+
RelationTypes.ANNOTATION, "m.reaction", "c", parent_id=reference_1
983+
)
984+
annotation_3 = channel.json_body["event_id"]
985+
986+
channel = self._send_relation(
987+
RelationTypes.REPLACE,
988+
"m.room.test",
989+
parent_id=reference_1,
990+
)
991+
edit = channel.json_body["event_id"]
992+
993+
# Also more events off the root.
994+
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "d")
995+
annotation_4 = channel.json_body["event_id"]
996+
997+
channel = self.make_request(
998+
"GET",
999+
f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}"
1000+
"?dir=f&limit=20&org.matrix.msc3981.recurse=true",
1001+
access_token=self.user_token,
1002+
)
1003+
self.assertEqual(200, channel.code, channel.json_body)
1004+
1005+
# The above events should be returned in creation order.
1006+
event_ids = [ev["event_id"] for ev in channel.json_body["chunk"]]
1007+
self.assertEqual(
1008+
event_ids,
1009+
[
1010+
thread_1,
1011+
thread_2,
1012+
annotation_1,
1013+
annotation_2,
1014+
reference_1,
1015+
annotation_3,
1016+
edit,
1017+
annotation_4,
1018+
],
1019+
)
1020+
1021+
@override_config({"experimental_features": {"msc3981_recurse_relations": True}})
1022+
def test_recursive_relations_with_filter(self) -> None:
1023+
"""The event_type and rel_type still apply."""
1024+
# Create a thread with a few messages in it.
1025+
channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
1026+
thread_1 = channel.json_body["event_id"]
1027+
1028+
# Add annotations.
1029+
channel = self._send_relation(
1030+
RelationTypes.ANNOTATION, "m.reaction", "b", parent_id=thread_1
1031+
)
1032+
annotation_1 = channel.json_body["event_id"]
1033+
1034+
# Add a reference to part of the thread, then edit the reference and annotate it.
1035+
channel = self._send_relation(
1036+
RelationTypes.REFERENCE, "m.room.test", parent_id=thread_1
1037+
)
1038+
reference_1 = channel.json_body["event_id"]
1039+
1040+
channel = self._send_relation(
1041+
RelationTypes.ANNOTATION, "org.matrix.reaction", "c", parent_id=reference_1
1042+
)
1043+
annotation_2 = channel.json_body["event_id"]
1044+
1045+
# Fetch only annotations, but recursively.
1046+
channel = self.make_request(
1047+
"GET",
1048+
f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}/{RelationTypes.ANNOTATION}"
1049+
"?dir=f&limit=20&org.matrix.msc3981.recurse=true",
1050+
access_token=self.user_token,
1051+
)
1052+
self.assertEqual(200, channel.code, channel.json_body)
1053+
1054+
# The above events should be returned in creation order.
1055+
event_ids = [ev["event_id"] for ev in channel.json_body["chunk"]]
1056+
self.assertEqual(event_ids, [annotation_1, annotation_2])
1057+
1058+
# Fetch only m.reactions, but recursively.
1059+
channel = self.make_request(
1060+
"GET",
1061+
f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}/{RelationTypes.ANNOTATION}/m.reaction"
1062+
"?dir=f&limit=20&org.matrix.msc3981.recurse=true",
1063+
access_token=self.user_token,
1064+
)
1065+
self.assertEqual(200, channel.code, channel.json_body)
1066+
1067+
# The above events should be returned in creation order.
1068+
event_ids = [ev["event_id"] for ev in channel.json_body["chunk"]]
1069+
self.assertEqual(event_ids, [annotation_1])
1070+
1071+
9521072
class BundledAggregationsTestCase(BaseRelationsTestCase):
9531073
"""
9541074
See RelationsTestCase.test_edit for a similar test for edits.

0 commit comments

Comments
 (0)