|
15 | 15 | import logging
|
16 | 16 | from typing import (
|
17 | 17 | TYPE_CHECKING,
|
18 |
| - Callable, |
19 | 18 | Collection,
|
20 | 19 | Dict,
|
21 | 20 | FrozenSet,
|
@@ -690,117 +689,109 @@ def _get_rooms_for_user_with_stream_ordering_txn(
|
690 | 689 | for room_id, instance, stream_id in txn
|
691 | 690 | )
|
692 | 691 |
|
693 |
| - @cachedList( |
694 |
| - cached_method_name="get_rooms_for_user_with_stream_ordering", |
695 |
| - list_name="user_ids", |
696 |
| - ) |
697 |
| - async def get_rooms_for_users_with_stream_ordering( |
| 692 | + async def get_users_server_still_shares_room_with( |
698 | 693 | self, user_ids: Collection[str]
|
699 |
| - ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]: |
700 |
| - """A batched version of `get_rooms_for_user_with_stream_ordering`. |
701 |
| -
|
702 |
| - Returns: |
703 |
| - Map from user_id to set of rooms that is currently in. |
| 694 | + ) -> Set[str]: |
| 695 | + """Given a list of users return the set that the server still share a |
| 696 | + room with. |
704 | 697 | """
|
| 698 | + |
| 699 | + if not user_ids: |
| 700 | + return set() |
| 701 | + |
705 | 702 | return await self.db_pool.runInteraction(
|
706 |
| - "get_rooms_for_users_with_stream_ordering", |
707 |
| - self._get_rooms_for_users_with_stream_ordering_txn, |
| 703 | + "get_users_server_still_shares_room_with", |
| 704 | + self.get_users_server_still_shares_room_with_txn, |
708 | 705 | user_ids,
|
709 | 706 | )
|
710 | 707 |
|
711 |
| - def _get_rooms_for_users_with_stream_ordering_txn( |
712 |
| - self, txn: LoggingTransaction, user_ids: Collection[str] |
713 |
| - ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]: |
| 708 | + def get_users_server_still_shares_room_with_txn( |
| 709 | + self, |
| 710 | + txn: LoggingTransaction, |
| 711 | + user_ids: Collection[str], |
| 712 | + ) -> Set[str]: |
| 713 | + if not user_ids: |
| 714 | + return set() |
| 715 | + |
| 716 | + sql = """ |
| 717 | + SELECT state_key FROM current_state_events |
| 718 | + WHERE |
| 719 | + type = 'm.room.member' |
| 720 | + AND membership = 'join' |
| 721 | + AND %s |
| 722 | + GROUP BY state_key |
| 723 | + """ |
714 | 724 |
|
715 | 725 | clause, args = make_in_list_sql_clause(
|
716 |
| - self.database_engine, |
717 |
| - "c.state_key", |
718 |
| - user_ids, |
| 726 | + self.database_engine, "state_key", user_ids |
719 | 727 | )
|
720 | 728 |
|
721 |
| - if self._current_state_events_membership_up_to_date: |
722 |
| - sql = f""" |
723 |
| - SELECT c.state_key, room_id, e.instance_name, e.stream_ordering |
724 |
| - FROM current_state_events AS c |
725 |
| - INNER JOIN events AS e USING (room_id, event_id) |
726 |
| - WHERE |
727 |
| - c.type = 'm.room.member' |
728 |
| - AND c.membership = ? |
729 |
| - AND {clause} |
730 |
| - """ |
731 |
| - else: |
732 |
| - sql = f""" |
733 |
| - SELECT c.state_key, room_id, e.instance_name, e.stream_ordering |
734 |
| - FROM current_state_events AS c |
735 |
| - INNER JOIN room_memberships AS m USING (room_id, event_id) |
736 |
| - INNER JOIN events AS e USING (room_id, event_id) |
737 |
| - WHERE |
738 |
| - c.type = 'm.room.member' |
739 |
| - AND m.membership = ? |
740 |
| - AND {clause} |
741 |
| - """ |
| 729 | + txn.execute(sql % (clause,), args) |
742 | 730 |
|
743 |
| - txn.execute(sql, [Membership.JOIN] + args) |
| 731 | + return {row[0] for row in txn} |
744 | 732 |
|
745 |
| - result: Dict[str, Set[GetRoomsForUserWithStreamOrdering]] = { |
746 |
| - user_id: set() for user_id in user_ids |
747 |
| - } |
748 |
| - for user_id, room_id, instance, stream_id in txn: |
749 |
| - result[user_id].add( |
750 |
| - GetRoomsForUserWithStreamOrdering( |
751 |
| - room_id, PersistedEventPosition(instance, stream_id) |
752 |
| - ) |
753 |
| - ) |
754 |
| - |
755 |
| - return {user_id: frozenset(v) for user_id, v in result.items()} |
| 733 | + @cached(max_entries=500000, iterable=True) |
| 734 | + async def get_rooms_for_user(self, user_id: str) -> FrozenSet[str]: |
| 735 | + """Returns a set of room_ids the user is currently joined to. |
756 | 736 |
|
757 |
| - async def get_users_server_still_shares_room_with( |
758 |
| - self, user_ids: Collection[str] |
759 |
| - ) -> Set[str]: |
760 |
| - """Given a list of users return the set that the server still share a |
761 |
| - room with. |
| 737 | + If a remote user only returns rooms this server is currently |
| 738 | + participating in. |
762 | 739 | """
|
| 740 | + rooms = self.get_rooms_for_user_with_stream_ordering.cache.get_immediate( |
| 741 | + (user_id,), |
| 742 | + None, |
| 743 | + update_metrics=False, |
| 744 | + ) |
| 745 | + if rooms: |
| 746 | + return frozenset(r.room_id for r in rooms) |
763 | 747 |
|
764 |
| - if not user_ids: |
765 |
| - return set() |
766 |
| - |
767 |
| - def _get_users_server_still_shares_room_with_txn( |
768 |
| - txn: LoggingTransaction, |
769 |
| - ) -> Set[str]: |
770 |
| - sql = """ |
771 |
| - SELECT state_key FROM current_state_events |
772 |
| - WHERE |
773 |
| - type = 'm.room.member' |
774 |
| - AND membership = 'join' |
775 |
| - AND %s |
776 |
| - GROUP BY state_key |
777 |
| - """ |
| 748 | + room_ids = await self.db_pool.simple_select_onecol( |
| 749 | + table="current_state_events", |
| 750 | + keyvalues={ |
| 751 | + "type": EventTypes.Member, |
| 752 | + "membership": Membership.JOIN, |
| 753 | + "state_key": user_id, |
| 754 | + }, |
| 755 | + retcol="room_id", |
| 756 | + desc="get_rooms_for_user", |
| 757 | + ) |
778 | 758 |
|
779 |
| - clause, args = make_in_list_sql_clause( |
780 |
| - self.database_engine, "state_key", user_ids |
781 |
| - ) |
| 759 | + return frozenset(room_ids) |
782 | 760 |
|
783 |
| - txn.execute(sql % (clause,), args) |
| 761 | + @cachedList( |
| 762 | + cached_method_name="get_rooms_for_user", |
| 763 | + list_name="user_ids", |
| 764 | + ) |
| 765 | + async def get_rooms_for_users( |
| 766 | + self, user_ids: Collection[str] |
| 767 | + ) -> Dict[str, FrozenSet[str]]: |
| 768 | + """A batched version of `get_rooms_for_user`. |
784 | 769 |
|
785 |
| - return {row[0] for row in txn} |
| 770 | + Returns: |
| 771 | + Map from user_id to set of rooms that is currently in. |
| 772 | + """ |
786 | 773 |
|
787 |
| - return await self.db_pool.runInteraction( |
788 |
| - "get_users_server_still_shares_room_with", |
789 |
| - _get_users_server_still_shares_room_with_txn, |
| 774 | + rows = await self.db_pool.simple_select_many_batch( |
| 775 | + table="current_state_events", |
| 776 | + column="state_key", |
| 777 | + iterable=user_ids, |
| 778 | + retcols=( |
| 779 | + "state_key", |
| 780 | + "room_id", |
| 781 | + ), |
| 782 | + keyvalues={ |
| 783 | + "type": EventTypes.Member, |
| 784 | + "membership": Membership.JOIN, |
| 785 | + }, |
| 786 | + desc="get_rooms_for_users", |
790 | 787 | )
|
791 | 788 |
|
792 |
| - async def get_rooms_for_user( |
793 |
| - self, user_id: str, on_invalidate: Optional[Callable[[], None]] = None |
794 |
| - ) -> FrozenSet[str]: |
795 |
| - """Returns a set of room_ids the user is currently joined to. |
| 789 | + user_rooms: Dict[str, Set[str]] = {user_id: set() for user_id in user_ids} |
796 | 790 |
|
797 |
| - If a remote user only returns rooms this server is currently |
798 |
| - participating in. |
799 |
| - """ |
800 |
| - rooms = await self.get_rooms_for_user_with_stream_ordering( |
801 |
| - user_id, on_invalidate=on_invalidate |
802 |
| - ) |
803 |
| - return frozenset(r.room_id for r in rooms) |
| 791 | + for row in rows: |
| 792 | + user_rooms[row["state_key"]].add(row["room_id"]) |
| 793 | + |
| 794 | + return {key: frozenset(rooms) for key, rooms in user_rooms.items()} |
804 | 795 |
|
805 | 796 | @cached(max_entries=10000)
|
806 | 797 | async def does_pair_of_users_share_a_room(
|
|
0 commit comments