From 6c585094ea62c1e7a7a00d5652c882772f9e4cb9 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 9 Mar 2022 08:14:40 -0500 Subject: [PATCH 1/9] Allow for ignoring some arguments when caching. --- synapse/util/caches/descriptors.py | 29 +++++++++++++++++--- tests/util/caches/test_descriptors.py | 39 +++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 3 deletions(-) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 1cdead02f14b..b4cc5067a862 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -20,6 +20,7 @@ Any, Awaitable, Callable, + Collection, Dict, Generic, Hashable, @@ -69,6 +70,7 @@ def __init__( self, orig: Callable[..., Any], num_args: Optional[int], + uncached_args: Optional[Collection[str]] = None, cache_context: bool = False, ): self.orig = orig @@ -88,6 +90,9 @@ def __init__( " named `cache_context`" ) + if num_args is not None and uncached_args is not None: + raise ValueError("Cannot provide both num_args and uncached_args") + if num_args is None: num_args = len(all_args) - 1 if cache_context: @@ -105,6 +110,11 @@ def __init__( # list of the names of the args used as the cache key self.arg_names = all_args[1 : num_args + 1] + # If there are args to not cache on, filter them out (and fix the size of num_args). 
+ if uncached_args is not None: + self.num_args -= len(uncached_args) + self.arg_names = [n for n in self.arg_names if n not in uncached_args] + # self.arg_defaults is a map of arg name to its default value for each # argument that has a default value if arg_spec.defaults: @@ -186,7 +196,9 @@ def __init__( max_entries: int = 1000, cache_context: bool = False, ): - super().__init__(orig, num_args=None, cache_context=cache_context) + super().__init__( + orig, num_args=None, uncached_args=None, cache_context=cache_context + ) self.max_entries = max_entries def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]: @@ -260,6 +272,9 @@ def foo(self, key, cache_context): num_args: number of positional arguments (excluding ``self`` and ``cache_context``) to use as cache keys. Defaults to all named args of the function. + uncached_args: a list of argument names to not use as the cache key. + (``self`` and ``cache_context`` are always ignored.) Cannot be used + with num_args. tree: cache_context: iterable: @@ -273,12 +288,18 @@ def __init__( orig: Callable[..., Any], max_entries: int = 1000, num_args: Optional[int] = None, + uncached_args: Optional[Collection[str]] = None, tree: bool = False, cache_context: bool = False, iterable: bool = False, prune_unread_entries: bool = True, ): - super().__init__(orig, num_args=num_args, cache_context=cache_context) + super().__init__( + orig, + num_args=num_args, + uncached_args=uncached_args, + cache_context=cache_context, + ) if tree and self.num_args < 2: raise RuntimeError( @@ -369,7 +390,7 @@ def __init__( but including list_name) to use as cache keys. Defaults to all named args of the function. 
""" - super().__init__(orig, num_args=num_args) + super().__init__(orig, num_args=num_args, uncached_args=None) self.list_name = list_name @@ -532,6 +553,7 @@ def get_instance( def cached( max_entries: int = 1000, num_args: Optional[int] = None, + uncached_args: Optional[Collection[str]] = None, tree: bool = False, cache_context: bool = False, iterable: bool = False, @@ -541,6 +563,7 @@ def cached( orig, max_entries=max_entries, num_args=num_args, + uncached_args=uncached_args, tree=tree, cache_context=cache_context, iterable=iterable, diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 19741ffcdaf1..b95742f0e432 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -141,6 +141,45 @@ def fn(self, arg1, arg2): self.assertEqual(r, "chips") obj.mock.assert_not_called() + @defer.inlineCallbacks + def test_cache_uncached_args(self): + """ + Only the arguments not named in uncached_args should matter to the cache + + Note that this is identical to test_cache_num_args, but provides the + arguments differently. + """ + + class Cls: + @descriptors.cached(uncached_args=("arg2",)) + def fn(self, arg1, arg2): + return self.mock(arg1, arg2) + + def __init__(self): + self.mock = mock.Mock() + + obj = Cls() + obj.mock.return_value = "fish" + r = yield obj.fn(1, 2) + self.assertEqual(r, "fish") + obj.mock.assert_called_once_with(1, 2) + obj.mock.reset_mock() + + # a call with different params should call the mock again + obj.mock.return_value = "chips" + r = yield obj.fn(2, 3) + self.assertEqual(r, "chips") + obj.mock.assert_called_once_with(2, 3) + obj.mock.reset_mock() + + # the two values should now be cached; we should be able to vary + # the second argument and still get the cached result. 
+ r = yield obj.fn(1, 4) + self.assertEqual(r, "fish") + r = yield obj.fn(2, 5) + self.assertEqual(r, "chips") + obj.mock.assert_not_called() + def test_cache_with_sync_exception(self): """If the wrapped function throws synchronously, things should continue to work""" From a4a1c315d90ee29f538389ec2ddfeeaf07e5cb6f Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 9 Mar 2022 08:22:18 -0500 Subject: [PATCH 2/9] Require caches to use kwargs. --- synapse/storage/databases/main/events_worker.py | 4 ++-- synapse/util/caches/descriptors.py | 6 +++--- tests/util/caches/test_descriptors.py | 6 +++--- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 26784f755e40..59454a47dfdd 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1286,7 +1286,7 @@ async def have_seen_events( ) return {eid for ((_rid, eid), have_event) in res.items() if have_event} - @cachedList("have_seen_event", "keys") + @cachedList(cached_method_name="have_seen_event", list_name="keys") async def _have_seen_events_dict( self, keys: Iterable[Tuple[str, str]] ) -> Dict[Tuple[str, str], bool]: @@ -1954,7 +1954,7 @@ def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> Optional[str]: get_event_id_for_timestamp_txn, ) - @cachedList("is_partial_state_event", list_name="event_ids") + @cachedList(cached_method_name="is_partial_state_event", list_name="event_ids") async def get_partial_state_events( self, event_ids: Collection[str] ) -> Dict[str, bool]: diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index b4cc5067a862..8edde490a42f 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -140,8 +140,7 @@ class _LruCachedFunction(Generic[F]): def lru_cache( - max_entries: int = 1000, - cache_context: bool = False, + *, max_entries: int = 1000, 
cache_context: bool = False ) -> Callable[[F], _LruCachedFunction[F]]: """A method decorator that applies a memoizing cache around the function. @@ -551,6 +550,7 @@ def get_instance( def cached( + *, max_entries: int = 1000, num_args: Optional[int] = None, uncached_args: Optional[Collection[str]] = None, @@ -574,7 +574,7 @@ def cached( def cachedList( - cached_method_name: str, list_name: str, num_args: Optional[int] = None + *, cached_method_name: str, list_name: str, num_args: Optional[int] = None ) -> Callable[[F], _CachedFunction[F]]: """Creates a descriptor that wraps a function in a `CacheListDescriptor`. diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index b95742f0e432..a79c875ec53b 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -695,7 +695,7 @@ def __init__(self): def fn(self, arg1, arg2): pass - @descriptors.cachedList("fn", "args1") + @descriptors.cachedList(cached_method_name="fn", list_name="args1") async def list_fn(self, args1, arg2): assert current_context().name == "c1" # we want this to behave like an asynchronous function @@ -754,7 +754,7 @@ def __init__(self): def fn(self, arg1): pass - @descriptors.cachedList("fn", "args1") + @descriptors.cachedList(cached_method_name="fn", list_name="args1") def list_fn(self, args1) -> "Deferred[dict]": return self.mock(args1) @@ -797,7 +797,7 @@ def __init__(self): def fn(self, arg1, arg2): pass - @descriptors.cachedList("fn", "args1") + @descriptors.cachedList(cached_method_name="fn", list_name="args1") async def list_fn(self, args1, arg2): # we want this to behave like an asynchronous function await run_on_reactor() From 18805d02ba09beec941f2df40ab8cea56b043d37 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 9 Mar 2022 08:23:15 -0500 Subject: [PATCH 3/9] Make a function private. 
--- synapse/util/caches/descriptors.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 8edde490a42f..ef7a36b21bcd 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -129,7 +129,7 @@ def __init__( self.add_cache_context = cache_context - self.cache_key_builder = get_cache_key_builder( + self.cache_key_builder = _get_cache_key_builder( self.arg_names, self.arg_defaults ) @@ -613,7 +613,7 @@ def batch_do_something(self, first_arg, second_args): return cast(Callable[[F], _CachedFunction[F]], func) -def get_cache_key_builder( +def _get_cache_key_builder( param_names: Sequence[str], param_defaults: Mapping[str, Any] ) -> Callable[[Sequence[Any], Mapping[str, Any]], CacheKey]: """Construct a function which will build cache keys suitable for a cached function From 020f0c60b6cca5b9c4a04bb80142cb8a8d1201a1 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 9 Mar 2022 09:23:42 -0500 Subject: [PATCH 4/9] Add a test for keyword arguments. 
--- tests/util/caches/test_descriptors.py | 37 +++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index a79c875ec53b..141afe245ce0 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -180,6 +180,43 @@ def __init__(self): self.assertEqual(r, "chips") obj.mock.assert_not_called() + @defer.inlineCallbacks + def test_cache_kwargs(self): + """Test that keyword arguments are treated properly""" + + class Cls: + def __init__(self): + self.mock = mock.Mock() + + @descriptors.cached() + def fn(self, arg1, kwarg1=2): + return self.mock(arg1, kwarg1=kwarg1) + + obj = Cls() + obj.mock.return_value = "fish" + r = yield obj.fn(1, kwarg1=2) + self.assertEqual(r, "fish") + obj.mock.assert_called_once_with(1, kwarg1=2) + obj.mock.reset_mock() + + # a call with different params should call the mock again + obj.mock.return_value = "chips" + r = yield obj.fn(1, kwarg1=3) + self.assertEqual(r, "chips") + obj.mock.assert_called_once_with(1, kwarg1=3) + obj.mock.reset_mock() + + # the values should now be cached. + r = yield obj.fn(1, kwarg1=2) + self.assertEqual(r, "fish") + # We should be able to not provide kwarg1 and get the cached value back. + r = yield obj.fn(1) + self.assertEqual(r, "fish") + # Keyword arguments can be in any order. + r = yield obj.fn(kwarg1=2, arg1=1) + self.assertEqual(r, "fish") + obj.mock.assert_not_called() + def test_cache_with_sync_exception(self): """If the wrapped function throws synchronously, things should continue to work""" From f69b643d903d8e3226070065af1f43bf250fadc1 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 9 Mar 2022 09:23:54 -0500 Subject: [PATCH 5/9] Keyword-only arguments are not supported. 
--- synapse/util/caches/descriptors.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index ef7a36b21bcd..db377a629538 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -78,6 +78,11 @@ def __init__( arg_spec = inspect.getfullargspec(orig) all_args = arg_spec.args + if arg_spec.kwonlyargs: + raise ValueError( + "_CacheDescriptorBase does not support keyword-only arguments." + ) + if "cache_context" in all_args: if not cache_context: raise ValueError( From 8fa311e1f23d0e9770b820c738abdfb9b38383bc Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 9 Mar 2022 09:26:25 -0500 Subject: [PATCH 6/9] Newsfragment --- changelog.d/12189.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/12189.misc diff --git a/changelog.d/12189.misc b/changelog.d/12189.misc new file mode 100644 index 000000000000..015e808e63c7 --- /dev/null +++ b/changelog.d/12189.misc @@ -0,0 +1 @@ +Support skipping some arguments when generating cache keys. From 2d41a6cc9fed68dc914108b601192961095b0a43 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 9 Mar 2022 09:45:20 -0500 Subject: [PATCH 7/9] Add a comment. --- synapse/util/caches/descriptors.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index db377a629538..4ff424481d81 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -78,6 +78,8 @@ def __init__( arg_spec = inspect.getfullargspec(orig) all_args = arg_spec.args + # There's no reason that keyword-only arguments couldn't be supported, + # but right now they're buggy so do not allow them. if arg_spec.kwonlyargs: raise ValueError( "_CacheDescriptorBase does not support keyword-only arguments." 
From 4460e921397c48e6d1dfd2fdc6517343d7d2980b Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 9 Mar 2022 10:35:21 -0500 Subject: [PATCH 8/9] Properly handle passing an uncached parameter as an arg. --- synapse/util/caches/descriptors.py | 33 +++++++++++++++++++-------- tests/util/caches/test_descriptors.py | 18 ++++++++------- 2 files changed, 34 insertions(+), 17 deletions(-) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 4ff424481d81..f6cf896d1d1f 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -119,8 +119,9 @@ def __init__( # If there are args to not cache on, filter them out (and fix the size of num_args). if uncached_args is not None: - self.num_args -= len(uncached_args) - self.arg_names = [n for n in self.arg_names if n not in uncached_args] + include_arg_in_cache_key = [n not in uncached_args for n in self.arg_names] + else: + include_arg_in_cache_key = [True] * len(self.arg_names) # self.arg_defaults is a map of arg name to its default value for each # argument that has a default value @@ -137,7 +138,7 @@ def __init__( self.add_cache_context = cache_context self.cache_key_builder = _get_cache_key_builder( - self.arg_names, self.arg_defaults + self.arg_names, include_arg_in_cache_key, self.arg_defaults ) @@ -621,12 +622,15 @@ def batch_do_something(self, first_arg, second_args): def _get_cache_key_builder( - param_names: Sequence[str], param_defaults: Mapping[str, Any] + param_names: Sequence[str], + include_params: Sequence[bool], + param_defaults: Mapping[str, Any], ) -> Callable[[Sequence[Any], Mapping[str, Any]], CacheKey]: """Construct a function which will build cache keys suitable for a cached function Args: param_names: list of formal parameter names for the cached function + include_params: list of bools of whether to include each parameter's value in the cache key param_defaults: a mapping from parameter name to default value for that param Returns: @@ 
-638,6 +642,7 @@ def _get_cache_key_builder( if len(param_names) == 1: nm = param_names[0] + assert include_params[0] is True def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey: if nm in kwargs: @@ -650,13 +655,18 @@ def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey: else: def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey: - return tuple(_get_cache_key_gen(param_names, param_defaults, args, kwargs)) + return tuple( + _get_cache_key_gen( + param_names, include_params, param_defaults, args, kwargs + ) + ) return get_cache_key def _get_cache_key_gen( param_names: Iterable[str], + include_params: Iterable[bool], param_defaults: Mapping[str, Any], args: Sequence[Any], kwargs: Mapping[str, Any], @@ -667,16 +677,21 @@ def _get_cache_key_gen( This is essentially the same operation as `inspect.getcallargs`, but optimised so that we don't need to inspect the target function for each call. """ + if param_names == (): + pass # We loop through each arg name, looking up if its in the `kwargs`, # otherwise using the next argument in `args`. If there are no more # args then we try looking the arg name up in the defaults. pos = 0 - for nm in param_names: + for nm, inc in zip(param_names, include_params): if nm in kwargs: - yield kwargs[nm] + if inc: + yield kwargs[nm] elif pos < len(args): - yield args[pos] + if inc: + yield args[pos] pos += 1 else: - yield param_defaults[nm] + if inc: + yield param_defaults[nm] diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 141afe245ce0..6a4b17527a7f 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -151,32 +151,34 @@ def test_cache_uncached_args(self): """ class Cls: + # Note that it is important that this is not the last argument to + # test behaviour of skipping arguments properly. 
@descriptors.cached(uncached_args=("arg2",)) - def fn(self, arg1, arg2): - return self.mock(arg1, arg2) + def fn(self, arg1, arg2, arg3): + return self.mock(arg1, arg2, arg3) def __init__(self): self.mock = mock.Mock() obj = Cls() obj.mock.return_value = "fish" - r = yield obj.fn(1, 2) + r = yield obj.fn(1, 2, 3) self.assertEqual(r, "fish") - obj.mock.assert_called_once_with(1, 2) + obj.mock.assert_called_once_with(1, 2, 3) obj.mock.reset_mock() # a call with different params should call the mock again obj.mock.return_value = "chips" - r = yield obj.fn(2, 3) + r = yield obj.fn(2, 3, 4) self.assertEqual(r, "chips") - obj.mock.assert_called_once_with(2, 3) + obj.mock.assert_called_once_with(2, 3, 4) obj.mock.reset_mock() # the two values should now be cached; we should be able to vary # the second argument and still get the cached result. - r = yield obj.fn(1, 4) + r = yield obj.fn(1, 4, 3) self.assertEqual(r, "fish") - r = yield obj.fn(2, 5) + r = yield obj.fn(2, 5, 4) self.assertEqual(r, "chips") obj.mock.assert_not_called() From 02fc051fac90bba8453a8de7dbfa850f430f980c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 9 Mar 2022 11:08:25 -0500 Subject: [PATCH 9/9] Remove debugging code. --- synapse/util/caches/descriptors.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index f6cf896d1d1f..c3c5c16db96e 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -677,9 +677,6 @@ def _get_cache_key_gen( This is essentially the same operation as `inspect.getcallargs`, but optimised so that we don't need to inspect the target function for each call. """ - if param_names == (): - pass - # We loop through each arg name, looking up if its in the `kwargs`, # otherwise using the next argument in `args`. If there are no more # args then we try looking the arg name up in the defaults.