Skip to content

Commit c523e7f

Browse files
PIG208andersk
authored andcommitted
rate_limit: Stop wrapping rate limited functions.
This refactors `rate_limit` so that we no longer use it as a decorator. This is a workaround to python/mypy#12909 as `rate_limit` previous expects different parameters than its callers. Our approach to test logging handlers also needs to be updated because the view function is not decorated by `rate_limit`. Signed-off-by: Zixuan James Li <[email protected]>
1 parent 1f9573d commit c523e7f

File tree

3 files changed

+41
-65
lines changed

3 files changed

+41
-65
lines changed

zerver/decorator.py

+23-44
Original file line numberDiff line numberDiff line change
@@ -533,7 +533,8 @@ def _wrapped_view_func(
533533
request: HttpRequest, /, *args: ParamT.args, **kwargs: ParamT.kwargs
534534
) -> HttpResponse:
535535
process_client(request, request.user, is_browser_view=True, query=view_func.__name__)
536-
return rate_limit()(view_func)(request, *args, **kwargs)
536+
rate_limit(request)
537+
return view_func(request, *args, **kwargs)
537538

538539
return _wrapped_view_func
539540

@@ -724,10 +725,8 @@ def _wrapped_func_arguments(
724725
) -> HttpResponse:
725726
user_profile = validate_api_key(request, None, api_key, False)
726727
if not skip_rate_limiting:
727-
limited_func = rate_limit()(view_func)
728-
else:
729-
limited_func = view_func
730-
return limited_func(request, user_profile, *args, **kwargs)
728+
rate_limit(request)
729+
return view_func(request, user_profile, *args, **kwargs)
731730

732731
return _wrapped_func_arguments
733732

@@ -788,10 +787,8 @@ def _wrapped_func_arguments(
788787
try:
789788
if not skip_rate_limiting:
790789
# Apply rate limiting
791-
target_view_func = rate_limit()(view_func)
792-
else:
793-
target_view_func = view_func
794-
return target_view_func(request, profile, *args, **kwargs)
790+
rate_limit(request)
791+
return view_func(request, profile, *args, **kwargs)
795792
except Exception as err:
796793
if not webhook_client_name:
797794
raise err
@@ -865,9 +862,7 @@ def authenticate_log_and_execute_json(
865862
**kwargs: object,
866863
) -> HttpResponse:
867864
if not skip_rate_limiting:
868-
limited_view_func = rate_limit()(view_func)
869-
else:
870-
limited_view_func = view_func
865+
rate_limit(request)
871866

872867
if not request.user.is_authenticated:
873868
if not allow_unauthenticated:
@@ -878,7 +873,7 @@ def authenticate_log_and_execute_json(
878873
is_browser_view=True,
879874
query=view_func.__name__,
880875
)
881-
return limited_view_func(request, request.user, *args, **kwargs)
876+
return view_func(request, request.user, *args, **kwargs)
882877

883878
user_profile = request.user
884879
validate_account_and_subdomain(request, user_profile)
@@ -887,7 +882,7 @@ def authenticate_log_and_execute_json(
887882
raise JsonableError(_("Webhook bots can only access webhooks"))
888883

889884
process_client(request, user_profile, is_browser_view=True, query=view_func.__name__)
890-
return limited_view_func(request, user_profile, *args, **kwargs)
885+
return view_func(request, user_profile, *args, **kwargs)
891886

892887

893888
# Checks if the user is logged in. If not, return an error (the
@@ -1072,39 +1067,23 @@ def rate_limit_remote_server(
10721067
raise e
10731068

10741069

1075-
def rate_limit() -> Callable[[ViewFuncT], ViewFuncT]:
1076-
"""Rate-limits a view. Returns a decorator"""
1077-
1078-
def wrapper(func: ViewFuncT) -> ViewFuncT:
1079-
@wraps(func)
1080-
def wrapped_func(request: HttpRequest, *args: object, **kwargs: object) -> HttpResponse:
1081-
1082-
# It is really tempting to not even wrap our original function
1083-
# when settings.RATE_LIMITING is False, but it would make
1084-
# for awkward unit testing in some situations.
1085-
if not settings.RATE_LIMITING:
1086-
return func(request, *args, **kwargs)
1087-
1088-
if client_is_exempt_from_rate_limiting(request):
1089-
return func(request, *args, **kwargs)
1090-
1091-
user = request.user
1092-
remote_server = RequestNotes.get_notes(request).remote_server
1093-
1094-
if settings.ZILENCER_ENABLED and remote_server is not None:
1095-
rate_limit_remote_server(request, remote_server, domain="api_by_remote_server")
1096-
elif not user.is_authenticated:
1097-
rate_limit_request_by_ip(request, domain="api_by_ip")
1098-
return func(request, *args, **kwargs)
1099-
else:
1100-
assert isinstance(user, UserProfile)
1101-
rate_limit_user(request, user, domain="api_by_user")
1070+
def rate_limit(request: HttpRequest) -> None:
1071+
if not settings.RATE_LIMITING:
1072+
return
11021073

1103-
return func(request, *args, **kwargs)
1074+
if client_is_exempt_from_rate_limiting(request):
1075+
return
11041076

1105-
return cast(ViewFuncT, wrapped_func) # https://github.com/python/mypy/issues/1927
1077+
user = request.user
1078+
remote_server = RequestNotes.get_notes(request).remote_server
11061079

1107-
return wrapper
1080+
if settings.ZILENCER_ENABLED and remote_server is not None:
1081+
rate_limit_remote_server(request, remote_server, domain="api_by_remote_server")
1082+
elif not user.is_authenticated:
1083+
rate_limit_request_by_ip(request, domain="api_by_ip")
1084+
else:
1085+
assert isinstance(user, UserProfile)
1086+
rate_limit_user(request, user, domain="api_by_user")
11081087

11091088

11101089
def return_success_on_head_request(

zerver/tests/test_decorators.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -630,10 +630,9 @@ def test_authenticated_rest_api_view_errors(self) -> None:
630630
class RateLimitTestCase(ZulipTestCase):
631631
def get_ratelimited_view(self) -> Callable[..., HttpResponse]:
632632
def f(req: Any) -> HttpResponse:
633+
rate_limit(req)
633634
return json_response(msg="some value")
634635

635-
f = rate_limit()(f)
636-
637636
return f
638637

639638
def errors_disallowed(self) -> Any:

zerver/tests/test_logging_handlers.py

+17-19
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import sys
33
from functools import wraps
44
from types import TracebackType
5-
from typing import Callable, Dict, Iterator, NoReturn, Optional, Tuple, Type, Union, cast
5+
from typing import Dict, Iterator, NoReturn, Optional, Tuple, Type, Union, cast
66
from unittest import mock
77
from unittest.mock import MagicMock, patch
88

@@ -22,22 +22,19 @@
2222
] = None
2323

2424

25-
def capture_and_throw(domain: Optional[str] = None) -> Callable[[ViewFuncT], ViewFuncT]:
26-
def wrapper(view_func: ViewFuncT) -> ViewFuncT:
27-
@wraps(view_func)
28-
def wrapped_view(request: HttpRequest, *args: object, **kwargs: object) -> NoReturn:
29-
global captured_request
30-
captured_request = request
31-
try:
32-
raise Exception("Request error")
33-
except Exception as e:
34-
global captured_exc_info
35-
captured_exc_info = sys.exc_info()
36-
raise e
37-
38-
return cast(ViewFuncT, wrapped_view) # https://github.com/python/mypy/issues/1927
25+
def capture_and_throw(view_func: ViewFuncT) -> ViewFuncT:
26+
@wraps(view_func)
27+
def wrapped_view(request: HttpRequest, *args: object, **kwargs: object) -> NoReturn:
28+
global captured_request
29+
captured_request = request
30+
try:
31+
raise Exception("Request error")
32+
except Exception as e:
33+
global captured_exc_info
34+
captured_exc_info = sys.exc_info()
35+
raise e
3936

40-
return wrapper
37+
return cast(ViewFuncT, wrapped_view) # https://github.com/python/mypy/issues/1927
4138

4239

4340
class AdminNotifyHandlerTest(ZulipTestCase):
@@ -78,17 +75,18 @@ def test_basic(self, mock_function: MagicMock) -> None:
7875

7976
def simulate_error(self) -> logging.LogRecord:
8077
self.login("hamlet")
81-
with patch("zerver.decorator.rate_limit") as rate_limit_patch, self.assertLogs(
78+
with patch(
79+
"zerver.lib.rest.authenticated_json_view", side_effect=capture_and_throw
80+
) as view_decorator_patch, self.assertLogs(
8281
"django.request", level="ERROR"
8382
) as request_error_log, self.assertLogs(
8483
"zerver.middleware.json_error_handler", level="ERROR"
8584
) as json_error_handler_log, self.settings(
8685
TEST_SUITE=False
8786
):
88-
rate_limit_patch.side_effect = capture_and_throw
8987
result = self.client_get("/json/users")
9088
self.assert_json_error(result, "Internal server error", status_code=500)
91-
rate_limit_patch.assert_called_once()
89+
view_decorator_patch.assert_called_once()
9290
self.assertEqual(
9391
request_error_log.output, ["ERROR:django.request:Internal Server Error: /json/users"]
9492
)

0 commit comments

Comments
 (0)