Skip to content

Commit 1ef42ad

Browse files
committed
rate_limit: Stop wrapping rate limited functions.
This refactors `rate_limit` so that we no longer use it as a decorator. This is a workaround to python/mypy#12909 as `rate_limit` previous expects different parameters than its callers. Our approach to test logging handlers also needs to be updated because the view function is not decorated by `rate_limit`. Signed-off-by: Zixuan James Li <[email protected]>
1 parent a805972 commit 1ef42ad

File tree

3 files changed

+41
-62
lines changed

3 files changed

+41
-62
lines changed

zerver/decorator.py

+23-41
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,8 @@ def _wrapped_view_func(
532532
request: HttpRequest, /, *args: ParamT.args, **kwargs: ParamT.kwargs
533533
) -> HttpResponse:
534534
process_client(request, request.user, is_browser_view=True, query=view_func.__name__)
535-
return rate_limit(view_func)(request, *args, **kwargs)
535+
rate_limit(request)
536+
return view_func(request, *args, **kwargs)
536537

537538
return _wrapped_view_func
538539

@@ -723,10 +724,8 @@ def _wrapped_func_arguments(
723724
) -> HttpResponse:
724725
user_profile = validate_api_key(request, None, api_key, False)
725726
if not skip_rate_limiting:
726-
limited_func = rate_limit(view_func)
727-
else:
728-
limited_func = view_func
729-
return limited_func(request, user_profile, *args, **kwargs)
727+
rate_limit(request)
728+
return view_func(request, user_profile, *args, **kwargs)
730729

731730
return _wrapped_func_arguments
732731

@@ -787,10 +786,8 @@ def _wrapped_func_arguments(
787786
try:
788787
if not skip_rate_limiting:
789788
# Apply rate limiting
790-
target_view_func = rate_limit(view_func)
791-
else:
792-
target_view_func = view_func
793-
return target_view_func(request, profile, *args, **kwargs)
789+
rate_limit(request)
790+
return view_func(request, profile, *args, **kwargs)
794791
except Exception as err:
795792
if not webhook_client_name:
796793
raise err
@@ -864,9 +861,7 @@ def authenticate_log_and_execute_json(
864861
**kwargs: object,
865862
) -> HttpResponse:
866863
if not skip_rate_limiting:
867-
limited_view_func = rate_limit(view_func)
868-
else:
869-
limited_view_func = view_func
864+
rate_limit(request)
870865

871866
if not request.user.is_authenticated:
872867
if not allow_unauthenticated:
@@ -877,7 +872,7 @@ def authenticate_log_and_execute_json(
877872
is_browser_view=True,
878873
query=view_func.__name__,
879874
)
880-
return limited_view_func(request, request.user, *args, **kwargs)
875+
return view_func(request, request.user, *args, **kwargs)
881876

882877
user_profile = request.user
883878
validate_account_and_subdomain(request, user_profile)
@@ -886,7 +881,7 @@ def authenticate_log_and_execute_json(
886881
raise JsonableError(_("Webhook bots can only access webhooks"))
887882

888883
process_client(request, user_profile, is_browser_view=True, query=view_func.__name__)
889-
return limited_view_func(request, user_profile, *args, **kwargs)
884+
return view_func(request, user_profile, *args, **kwargs)
890885

891886

892887
# Checks if the user is logged in. If not, return an error (the
@@ -1069,36 +1064,23 @@ def rate_limit_remote_server(
10691064
raise e
10701065

10711066

1072-
def rate_limit(func: ViewFuncT) -> ViewFuncT:
1073-
"""Rate-limits a view."""
1074-
1075-
@wraps(func)
1076-
def wrapped_func(request: HttpRequest, *args: object, **kwargs: object) -> HttpResponse:
1077-
1078-
# It is really tempting to not even wrap our original function
1079-
# when settings.RATE_LIMITING is False, but it would make
1080-
# for awkward unit testing in some situations.
1081-
if not settings.RATE_LIMITING:
1082-
return func(request, *args, **kwargs)
1083-
1084-
if client_is_exempt_from_rate_limiting(request):
1085-
return func(request, *args, **kwargs)
1086-
1087-
user = request.user
1088-
remote_server = RequestNotes.get_notes(request).remote_server
1067+
def rate_limit(request: HttpRequest) -> None:
1068+
if not settings.RATE_LIMITING:
1069+
return
10891070

1090-
if settings.ZILENCER_ENABLED and remote_server is not None:
1091-
rate_limit_remote_server(request, remote_server, domain="api_by_remote_server")
1092-
elif not user.is_authenticated:
1093-
rate_limit_request_by_ip(request, domain="api_by_ip")
1094-
return func(request, *args, **kwargs)
1095-
else:
1096-
assert isinstance(user, UserProfile)
1097-
rate_limit_user(request, user, domain="api_by_user")
1071+
if client_is_exempt_from_rate_limiting(request):
1072+
return
10981073

1099-
return func(request, *args, **kwargs)
1074+
user = request.user
1075+
remote_server = RequestNotes.get_notes(request).remote_server
11001076

1101-
return cast(ViewFuncT, wrapped_func) # https://github.com/python/mypy/issues/1927
1077+
if settings.ZILENCER_ENABLED and remote_server is not None:
1078+
rate_limit_remote_server(request, remote_server, domain="api_by_remote_server")
1079+
elif not user.is_authenticated:
1080+
rate_limit_request_by_ip(request, domain="api_by_ip")
1081+
else:
1082+
assert isinstance(user, UserProfile)
1083+
rate_limit_user(request, user, domain="api_by_user")
11021084

11031085

11041086
def return_success_on_head_request(

zerver/tests/test_decorators.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -630,10 +630,9 @@ def test_authenticated_rest_api_view_errors(self) -> None:
630630
class RateLimitTestCase(ZulipTestCase):
631631
def get_ratelimited_view(self) -> Callable[..., HttpResponse]:
632632
def f(req: Any) -> HttpResponse:
633+
rate_limit(req)
633634
return json_response(msg="some value")
634635

635-
f = rate_limit(f)
636-
637636
return f
638637

639638
def errors_disallowed(self) -> Any:

zerver/tests/test_logging_handlers.py

+17-19
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import sys
33
from functools import wraps
44
from types import TracebackType
5-
from typing import Callable, Dict, Iterator, NoReturn, Optional, Tuple, Type, Union, cast
5+
from typing import Dict, Iterator, NoReturn, Optional, Tuple, Type, Union, cast
66
from unittest import mock
77
from unittest.mock import MagicMock, patch
88

@@ -22,22 +22,19 @@
2222
] = None
2323

2424

25-
def capture_and_throw(domain: Optional[str] = None) -> Callable[[ViewFuncT], ViewFuncT]:
26-
def wrapper(view_func: ViewFuncT) -> ViewFuncT:
27-
@wraps(view_func)
28-
def wrapped_view(request: HttpRequest, *args: object, **kwargs: object) -> NoReturn:
29-
global captured_request
30-
captured_request = request
31-
try:
32-
raise Exception("Request error")
33-
except Exception as e:
34-
global captured_exc_info
35-
captured_exc_info = sys.exc_info()
36-
raise e
37-
38-
return cast(ViewFuncT, wrapped_view) # https://github.com/python/mypy/issues/1927
25+
def capture_and_throw(view_func: ViewFuncT) -> ViewFuncT:
26+
@wraps(view_func)
27+
def wrapped_view(request: HttpRequest, *args: object, **kwargs: object) -> NoReturn:
28+
global captured_request
29+
captured_request = request
30+
try:
31+
raise Exception("Request error")
32+
except Exception as e:
33+
global captured_exc_info
34+
captured_exc_info = sys.exc_info()
35+
raise e
3936

40-
return wrapper
37+
return cast(ViewFuncT, wrapped_view) # https://github.com/python/mypy/issues/1927
4138

4239

4340
class AdminNotifyHandlerTest(ZulipTestCase):
@@ -78,17 +75,18 @@ def test_basic(self, mock_function: MagicMock) -> None:
7875

7976
def simulate_error(self) -> logging.LogRecord:
8077
self.login("hamlet")
81-
with patch("zerver.decorator.rate_limit") as rate_limit_patch, self.assertLogs(
78+
with patch(
79+
"zerver.lib.rest.authenticated_json_view", side_effect=capture_and_throw
80+
) as view_decorator_patch, self.assertLogs(
8281
"django.request", level="ERROR"
8382
) as request_error_log, self.assertLogs(
8483
"zerver.middleware.json_error_handler", level="ERROR"
8584
) as json_error_handler_log, self.settings(
8685
TEST_SUITE=False
8786
):
88-
rate_limit_patch.side_effect = capture_and_throw
8987
result = self.client_get("/json/users")
9088
self.assert_json_error(result, "Internal server error", status_code=500)
91-
rate_limit_patch.assert_called_once()
89+
view_decorator_patch.assert_called_once()
9290
self.assertEqual(
9391
request_error_log.output, ["ERROR:django.request:Internal Server Error: /json/users"]
9492
)

0 commit comments

Comments
 (0)