36
36
37
37
import attr
38
38
39
- from synapse .api .constants import EventTypes , HistoryVisibility , Membership
39
+ from synapse .api .constants import (
40
+ EventTypes ,
41
+ EventUnsignedContentFields ,
42
+ HistoryVisibility ,
43
+ Membership ,
44
+ )
40
45
from synapse .events import EventBase
41
46
from synapse .events .snapshot import EventContext
42
- from synapse .events .utils import prune_event
47
+ from synapse .events .utils import clone_event , prune_event
43
48
from synapse .logging .opentracing import trace
44
49
from synapse .storage .controllers import StorageControllers
45
50
from synapse .storage .databases .main import DataStore
@@ -77,6 +82,7 @@ async def filter_events_for_client(
77
82
is_peeking : bool = False ,
78
83
always_include_ids : FrozenSet [str ] = frozenset (),
79
84
filter_send_to_client : bool = True ,
85
+ msc4115_membership_on_events : bool = False ,
80
86
) -> List [EventBase ]:
81
87
"""
82
88
Check which events a user is allowed to see. If the user can see the event but its
@@ -95,9 +101,12 @@ async def filter_events_for_client(
95
101
filter_send_to_client: Whether we're checking an event that's going to be
96
102
sent to a client. This might not always be the case since this function can
97
103
also be called to check whether a user can see the state at a given point.
104
+ msc4115_membership_on_events: Whether to include the requesting user's
105
+ membership in the "unsigned" data, per MSC4115.
98
106
99
107
Returns:
100
- The filtered events.
108
+ The filtered events. If `msc4115_membership_on_events` is true, the `unsigned`
109
+ data is annotated with the membership state of `user_id` at each event.
101
110
"""
102
111
# Filter out events that have been soft failed so that we don't relay them
103
112
# to clients.
@@ -134,21 +143,54 @@ async def filter_events_for_client(
134
143
)
135
144
136
145
def allowed (event : EventBase ) -> Optional [EventBase ]:
137
- return _check_client_allowed_to_see_event (
146
+ state_after_event = event_id_to_state .get (event .event_id )
147
+ filtered = _check_client_allowed_to_see_event (
138
148
user_id = user_id ,
139
149
event = event ,
140
150
clock = storage .main .clock ,
141
151
filter_send_to_client = filter_send_to_client ,
142
152
sender_ignored = event .sender in ignore_list ,
143
153
always_include_ids = always_include_ids ,
144
154
retention_policy = retention_policies [room_id ],
145
- state = event_id_to_state . get ( event . event_id ) ,
155
+ state = state_after_event ,
146
156
is_peeking = is_peeking ,
147
157
sender_erased = erased_senders .get (event .sender , False ),
148
158
)
159
+ if filtered is None :
160
+ return None
161
+
162
+ if not msc4115_membership_on_events :
163
+ return filtered
164
+
165
+ # Annotate the event with the user's membership after the event.
166
+ #
167
+ # Normally we just look in `state_after_event`, but if the event is an outlier
168
+ # we won't have such a state. The only outliers that are returned here are the
169
+ # user's own membership event, so we can just inspect that.
170
+
171
+ user_membership_event : Optional [EventBase ]
172
+ if event .type == EventTypes .Member and event .state_key == user_id :
173
+ user_membership_event = event
174
+ elif state_after_event is not None :
175
+ user_membership_event = state_after_event .get ((EventTypes .Member , user_id ))
176
+ else :
177
+ # unreachable!
178
+ raise Exception ("Missing state for event that is not user's own membership" )
179
+
180
+ user_membership = (
181
+ user_membership_event .membership
182
+ if user_membership_event
183
+ else Membership .LEAVE
184
+ )
149
185
150
- # Check each event: gives an iterable of None or (a potentially modified)
151
- # EventBase.
186
+ # Copy the event before updating the unsigned data: this shouldn't be persisted
187
+ # to the cache!
188
+ cloned = clone_event (filtered )
189
+ cloned .unsigned [EventUnsignedContentFields .MSC4115_MEMBERSHIP ] = user_membership
190
+
191
+ return cloned
192
+
193
+ # Check each event: gives an iterable of None or (a modified) EventBase.
152
194
filtered_events = map (allowed , events )
153
195
154
196
# Turn it into a list and remove None entries before returning.
@@ -396,7 +438,13 @@ def _check_client_allowed_to_see_event(
396
438
397
439
@attr .s (frozen = True , slots = True , auto_attribs = True )
398
440
class _CheckMembershipReturn :
399
- "Return value of _check_membership"
441
+ """Return value of `_check_membership`.
442
+
443
+ Attributes:
444
+ allowed: Whether the user should be allowed to see the event.
445
+ joined: Whether the user was joined to the room at the event.
446
+ """
447
+
400
448
allowed : bool
401
449
joined : bool
402
450
@@ -408,12 +456,7 @@ def _check_membership(
408
456
state : StateMap [EventBase ],
409
457
is_peeking : bool ,
410
458
) -> _CheckMembershipReturn :
411
- """Check whether the user can see the event due to their membership
412
-
413
- Returns:
414
- True if they can, False if they can't, plus the membership of the user
415
- at the event.
416
- """
459
+ """Check whether the user can see the event due to their membership"""
417
460
# If the event is the user's own membership event, use the 'most joined'
418
461
# membership
419
462
membership = None
@@ -435,7 +478,7 @@ def _check_membership(
435
478
if membership == "leave" and (
436
479
prev_membership == "join" or prev_membership == "invite"
437
480
):
438
- return _CheckMembershipReturn (True , membership == Membership . JOIN )
481
+ return _CheckMembershipReturn (True , False )
439
482
440
483
new_priority = MEMBERSHIP_PRIORITY .index (membership )
441
484
old_priority = MEMBERSHIP_PRIORITY .index (prev_membership )
0 commit comments