|
15 | 15 | import logging
|
16 | 16 | from typing import (
|
17 | 17 | TYPE_CHECKING,
|
18 |
| - Callable, |
19 | 18 | Collection,
|
20 | 19 | Dict,
|
21 | 20 | FrozenSet,
|
|
52 | 51 | from synapse.util.async_helpers import Linearizer
|
53 | 52 | from synapse.util.caches import intern_string
|
54 | 53 | from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
|
55 |
| -from synapse.util.cancellation import cancellable |
56 | 54 | from synapse.util.iterutils import batch_iter
|
57 | 55 | from synapse.util.metrics import Measure
|
58 | 56 |
|
@@ -619,106 +617,109 @@ def _get_rooms_for_user_with_stream_ordering_txn(
|
619 | 617 | for room_id, instance, stream_id in txn
|
620 | 618 | )
|
621 | 619 |
|
622 |
| - @cachedList( |
623 |
| - cached_method_name="get_rooms_for_user_with_stream_ordering", |
624 |
| - list_name="user_ids", |
625 |
| - ) |
626 |
| - async def get_rooms_for_users_with_stream_ordering( |
| 620 | + async def get_users_server_still_shares_room_with( |
627 | 621 | self, user_ids: Collection[str]
|
628 |
| - ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]: |
629 |
| - """A batched version of `get_rooms_for_user_with_stream_ordering`. |
630 |
| -
|
631 |
| - Returns: |
632 |
| - Map from user_id to set of rooms that is currently in. |
| 622 | + ) -> Set[str]: |
| 623 | + """Given a list of users return the set that the server still share a |
| 624 | + room with. |
633 | 625 | """
|
634 |
| - return await self.db_pool.runInteraction( |
635 |
| - "get_rooms_for_users_with_stream_ordering", |
636 |
| - self._get_rooms_for_users_with_stream_ordering_txn, |
637 |
| - user_ids, |
638 |
| - ) |
639 | 626 |
|
640 |
| - def _get_rooms_for_users_with_stream_ordering_txn( |
641 |
| - self, txn: LoggingTransaction, user_ids: Collection[str] |
642 |
| - ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]: |
| 627 | + if not user_ids: |
| 628 | + return set() |
643 | 629 |
|
644 |
| - clause, args = make_in_list_sql_clause( |
645 |
| - self.database_engine, |
646 |
| - "c.state_key", |
| 630 | + return await self.db_pool.runInteraction( |
| 631 | + "get_users_server_still_shares_room_with", |
| 632 | + self.get_users_server_still_shares_room_with_txn, |
647 | 633 | user_ids,
|
648 | 634 | )
|
649 | 635 |
|
650 |
| - sql = f""" |
651 |
| - SELECT c.state_key, room_id, e.instance_name, e.stream_ordering |
652 |
| - FROM current_state_events AS c |
653 |
| - INNER JOIN events AS e USING (room_id, event_id) |
| 636 | + def get_users_server_still_shares_room_with_txn( |
| 637 | + self, |
| 638 | + txn: LoggingTransaction, |
| 639 | + user_ids: Collection[str], |
| 640 | + ) -> Set[str]: |
| 641 | + if not user_ids: |
| 642 | + return set() |
| 643 | + |
| 644 | + sql = """ |
| 645 | + SELECT state_key FROM current_state_events |
654 | 646 | WHERE
|
655 |
| - c.type = 'm.room.member' |
656 |
| - AND c.membership = ? |
657 |
| - AND {clause} |
| 647 | + type = 'm.room.member' |
| 648 | + AND membership = 'join' |
| 649 | + AND %s |
| 650 | + GROUP BY state_key |
658 | 651 | """
|
659 | 652 |
|
660 |
| - txn.execute(sql, [Membership.JOIN] + args) |
| 653 | + clause, args = make_in_list_sql_clause( |
| 654 | + self.database_engine, "state_key", user_ids |
| 655 | + ) |
661 | 656 |
|
662 |
| - result: Dict[str, Set[GetRoomsForUserWithStreamOrdering]] = { |
663 |
| - user_id: set() for user_id in user_ids |
664 |
| - } |
665 |
| - for user_id, room_id, instance, stream_id in txn: |
666 |
| - result[user_id].add( |
667 |
| - GetRoomsForUserWithStreamOrdering( |
668 |
| - room_id, PersistedEventPosition(instance, stream_id) |
669 |
| - ) |
670 |
| - ) |
| 657 | + txn.execute(sql % (clause,), args) |
671 | 658 |
|
672 |
| - return {user_id: frozenset(v) for user_id, v in result.items()} |
| 659 | + return {row[0] for row in txn} |
673 | 660 |
|
674 |
| - async def get_users_server_still_shares_room_with( |
675 |
| - self, user_ids: Collection[str] |
676 |
| - ) -> Set[str]: |
677 |
| - """Given a list of users return the set that the server still share a |
678 |
| - room with. |
| 661 | + @cached(max_entries=500000, iterable=True) |
| 662 | + async def get_rooms_for_user(self, user_id: str) -> FrozenSet[str]: |
| 663 | + """Returns a set of room_ids the user is currently joined to. |
| 664 | +
|
| 665 | + If a remote user only returns rooms this server is currently |
| 666 | + participating in. |
679 | 667 | """
|
| 668 | + rooms = self.get_rooms_for_user_with_stream_ordering.cache.get_immediate( |
| 669 | + (user_id,), |
| 670 | + None, |
| 671 | + update_metrics=False, |
| 672 | + ) |
| 673 | + if rooms: |
| 674 | + return frozenset(r.room_id for r in rooms) |
680 | 675 |
|
681 |
| - if not user_ids: |
682 |
| - return set() |
| 676 | + room_ids = await self.db_pool.simple_select_onecol( |
| 677 | + table="current_state_events", |
| 678 | + keyvalues={ |
| 679 | + "type": EventTypes.Member, |
| 680 | + "membership": Membership.JOIN, |
| 681 | + "state_key": user_id, |
| 682 | + }, |
| 683 | + retcol="room_id", |
| 684 | + desc="get_rooms_for_user", |
| 685 | + ) |
683 | 686 |
|
684 |
| - def _get_users_server_still_shares_room_with_txn( |
685 |
| - txn: LoggingTransaction, |
686 |
| - ) -> Set[str]: |
687 |
| - sql = """ |
688 |
| - SELECT state_key FROM current_state_events |
689 |
| - WHERE |
690 |
| - type = 'm.room.member' |
691 |
| - AND membership = 'join' |
692 |
| - AND %s |
693 |
| - GROUP BY state_key |
694 |
| - """ |
| 687 | + return frozenset(room_ids) |
695 | 688 |
|
696 |
| - clause, args = make_in_list_sql_clause( |
697 |
| - self.database_engine, "state_key", user_ids |
698 |
| - ) |
699 |
| - |
700 |
| - txn.execute(sql % (clause,), args) |
| 689 | + @cachedList( |
| 690 | + cached_method_name="get_rooms_for_user", |
| 691 | + list_name="user_ids", |
| 692 | + ) |
| 693 | + async def get_rooms_for_users( |
| 694 | + self, user_ids: Collection[str] |
| 695 | + ) -> Dict[str, FrozenSet[str]]: |
| 696 | + """A batched version of `get_rooms_for_user`. |
701 | 697 |
|
702 |
| - return {row[0] for row in txn} |
| 698 | + Returns: |
| 699 | + Map from user_id to set of rooms that is currently in. |
| 700 | + """ |
703 | 701 |
|
704 |
| - return await self.db_pool.runInteraction( |
705 |
| - "get_users_server_still_shares_room_with", |
706 |
| - _get_users_server_still_shares_room_with_txn, |
| 702 | + rows = await self.db_pool.simple_select_many_batch( |
| 703 | + table="current_state_events", |
| 704 | + column="state_key", |
| 705 | + iterable=user_ids, |
| 706 | + retcols=( |
| 707 | + "state_key", |
| 708 | + "room_id", |
| 709 | + ), |
| 710 | + keyvalues={ |
| 711 | + "type": EventTypes.Member, |
| 712 | + "membership": Membership.JOIN, |
| 713 | + }, |
| 714 | + desc="get_rooms_for_users", |
707 | 715 | )
|
708 | 716 |
|
709 |
| - @cancellable |
710 |
| - async def get_rooms_for_user( |
711 |
| - self, user_id: str, on_invalidate: Optional[Callable[[], None]] = None |
712 |
| - ) -> FrozenSet[str]: |
713 |
| - """Returns a set of room_ids the user is currently joined to. |
| 717 | + user_rooms: Dict[str, Set[str]] = {user_id: set() for user_id in user_ids} |
714 | 718 |
|
715 |
| - If a remote user only returns rooms this server is currently |
716 |
| - participating in. |
717 |
| - """ |
718 |
| - rooms = await self.get_rooms_for_user_with_stream_ordering( |
719 |
| - user_id, on_invalidate=on_invalidate |
720 |
| - ) |
721 |
| - return frozenset(r.room_id for r in rooms) |
| 719 | + for row in rows: |
| 720 | + user_rooms[row["state_key"]].add(row["room_id"]) |
| 721 | + |
| 722 | + return {key: frozenset(rooms) for key, rooms in user_rooms.items()} |
722 | 723 |
|
723 | 724 | @cached(max_entries=10000)
|
724 | 725 | async def does_pair_of_users_share_a_room(
|
|
0 commit comments