|
15 | 15 | import logging
|
16 | 16 | from typing import (
|
17 | 17 | TYPE_CHECKING,
|
18 |
| - Callable, |
19 | 18 | Collection,
|
20 | 19 | Dict,
|
21 | 20 | FrozenSet,
|
|
52 | 51 | from synapse.util.async_helpers import Linearizer
|
53 | 52 | from synapse.util.caches import intern_string
|
54 | 53 | from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
|
55 |
| -from synapse.util.cancellation import cancellable |
56 | 54 | from synapse.util.iterutils import batch_iter
|
57 | 55 | from synapse.util.metrics import Measure
|
58 | 56 |
|
@@ -670,19 +668,86 @@ def _get_users_server_still_shares_room_with_txn(
|
670 | 668 | _get_users_server_still_shares_room_with_txn,
|
671 | 669 | )
|
672 | 670 |
|
673 |
| - @cancellable |
| 671 | + @cached(max_entries=500000, iterable=True) |
674 | 672 | async def get_rooms_for_user(
|
675 |
| - self, user_id: str, on_invalidate: Optional[Callable[[], None]] = None |
| 673 | + self, user_id: str |
676 | 674 | ) -> FrozenSet[str]:
|
677 | 675 | """Returns a set of room_ids the user is currently joined to.
|
678 | 676 |
|
679 | 677 | If a remote user only returns rooms this server is currently
|
680 | 678 | participating in.
|
681 | 679 | """
|
682 |
| - rooms = await self.get_rooms_for_user_with_stream_ordering( |
683 |
| - user_id, on_invalidate=on_invalidate |
| 680 | + rooms = self.get_rooms_for_user_with_stream_ordering.cache.get_immediate(user_id) |
| 681 | + if rooms: |
| 682 | + return frozenset(r.room_id for r in rooms) |
| 683 | + |
| 684 | + return await self.db_pool.runInteraction( |
| 685 | + "get_rooms_for_user", |
| 686 | + self._get_rooms_for_user_txn, |
| 687 | + user_id, |
684 | 688 | )
|
685 |
| - return frozenset(r.room_id for r in rooms) |
| 689 | + |
| 690 | + def _get_rooms_for_user_txn( |
| 691 | + self, txn: LoggingTransaction, user_id: str |
| 692 | + ) -> FrozenSet[str]: |
| 693 | + sql = """ |
| 694 | + SELECT room_id |
| 695 | + FROM current_state_events AS c |
| 696 | + WHERE |
| 697 | + c.type = 'm.room.member' |
| 698 | + AND c.state_key = ? |
| 699 | + AND c.membership = ? |
| 700 | + """ |
| 701 | + |
| 702 | + txn.execute(sql, (user_id, Membership.JOIN)) |
| 703 | + return frozenset(txn) |
| 704 | + |
| 705 | + @cachedList( |
| 706 | + cached_method_name="get_rooms_for_user", |
| 707 | + list_name="user_ids", |
| 708 | + ) |
| 709 | + async def get_rooms_for_users( |
| 710 | + self, user_ids: Collection[str] |
| 711 | + ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]: |
| 712 | + """A batched version of `get_rooms_for_user`. |
| 713 | +
|
| 714 | + Returns: |
| 715 | + Map from user_id to set of rooms that is currently in. |
| 716 | + """ |
| 717 | + return await self.db_pool.runInteraction( |
| 718 | + "get_rooms_for_users", |
| 719 | + self._get_rooms_for_users_txn, |
| 720 | + user_ids, |
| 721 | + ) |
| 722 | + |
| 723 | + def _get_rooms_for_users_txn( |
| 724 | + self, txn: LoggingTransaction, user_ids: Collection[str] |
| 725 | + ) -> Dict[str, FrozenSet[str]]: |
| 726 | + |
| 727 | + clause, args = make_in_list_sql_clause( |
| 728 | + self.database_engine, |
| 729 | + "c.state_key", |
| 730 | + user_ids, |
| 731 | + ) |
| 732 | + |
| 733 | + sql = f""" |
| 734 | + SELECT c.state_key, room_id |
| 735 | + FROM current_state_events AS c |
| 736 | + WHERE |
| 737 | + c.type = 'm.room.member' |
| 738 | + AND c.membership = ? |
| 739 | + AND {clause} |
| 740 | + """ |
| 741 | + |
| 742 | + txn.execute(sql, [Membership.JOIN] + args) |
| 743 | + |
| 744 | + result: Dict[str, Set[str]] = { |
| 745 | + user_id: set() for user_id in user_ids |
| 746 | + } |
| 747 | + for user_id, room_id in txn: |
| 748 | + result[user_id].add(room_id) |
| 749 | + |
| 750 | + return {user_id: frozenset(v) for user_id, v in result.items()} |
686 | 751 |
|
687 | 752 | @cached(max_entries=10000)
|
688 | 753 | async def does_pair_of_users_share_a_room(
|
|
0 commit comments