31
31
from synapse .appservice import ApplicationService
32
32
from synapse .server import HomeServer
33
33
from synapse .storage .databases .main .registration import TokenLookupResult
34
- from synapse .types import Requester
34
+ from synapse .types import Requester , UserID
35
35
from synapse .util import Clock
36
36
37
37
from tests import unittest
41
41
42
42
43
43
class AuthTestCase (unittest .HomeserverTestCase ):
44
- def prepare (self , reactor : MemoryReactor , clock : Clock , hs : HomeServer ):
44
+ def prepare (self , reactor : MemoryReactor , clock : Clock , hs : HomeServer ) -> None :
45
45
self .store = Mock ()
46
46
47
- hs .datastores .main = self .store
47
+ # type-ignore: datastores is None until hs.setup() is called---but it'll
48
+ # have been called by the HomeserverTestCase machinery.
49
+ hs .datastores .main = self .store # type: ignore[union-attr]
48
50
hs .get_auth_handler ().store = self .store
49
51
self .auth = Auth (hs )
50
52
@@ -61,7 +63,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
61
63
self .store .insert_client_ip = simple_async_mock (None )
62
64
self .store .is_support_user = simple_async_mock (False )
63
65
64
- def test_get_user_by_req_user_valid_token (self ):
66
+ def test_get_user_by_req_user_valid_token (self ) -> None :
65
67
user_info = TokenLookupResult (
66
68
user_id = self .test_user , token_id = 5 , device_id = "device"
67
69
)
@@ -74,7 +76,7 @@ def test_get_user_by_req_user_valid_token(self):
74
76
requester = self .get_success (self .auth .get_user_by_req (request ))
75
77
self .assertEqual (requester .user .to_string (), self .test_user )
76
78
77
- def test_get_user_by_req_user_bad_token (self ):
79
+ def test_get_user_by_req_user_bad_token (self ) -> None :
78
80
self .store .get_user_by_access_token = simple_async_mock (None )
79
81
80
82
request = Mock (args = {})
@@ -86,7 +88,7 @@ def test_get_user_by_req_user_bad_token(self):
86
88
self .assertEqual (f .code , 401 )
87
89
self .assertEqual (f .errcode , "M_UNKNOWN_TOKEN" )
88
90
89
- def test_get_user_by_req_user_missing_token (self ):
91
+ def test_get_user_by_req_user_missing_token (self ) -> None :
90
92
user_info = TokenLookupResult (user_id = self .test_user , token_id = 5 )
91
93
self .store .get_user_by_access_token = simple_async_mock (user_info )
92
94
@@ -98,7 +100,7 @@ def test_get_user_by_req_user_missing_token(self):
98
100
self .assertEqual (f .code , 401 )
99
101
self .assertEqual (f .errcode , "M_MISSING_TOKEN" )
100
102
101
- def test_get_user_by_req_appservice_valid_token (self ):
103
+ def test_get_user_by_req_appservice_valid_token (self ) -> None :
102
104
app_service = Mock (
103
105
token = "foobar" , url = "a_url" , sender = self .test_user , ip_range_whitelist = None
104
106
)
@@ -112,7 +114,7 @@ def test_get_user_by_req_appservice_valid_token(self):
112
114
requester = self .get_success (self .auth .get_user_by_req (request ))
113
115
self .assertEqual (requester .user .to_string (), self .test_user )
114
116
115
- def test_get_user_by_req_appservice_valid_token_good_ip (self ):
117
+ def test_get_user_by_req_appservice_valid_token_good_ip (self ) -> None :
116
118
from netaddr import IPSet
117
119
118
120
app_service = Mock (
@@ -131,7 +133,7 @@ def test_get_user_by_req_appservice_valid_token_good_ip(self):
131
133
requester = self .get_success (self .auth .get_user_by_req (request ))
132
134
self .assertEqual (requester .user .to_string (), self .test_user )
133
135
134
- def test_get_user_by_req_appservice_valid_token_bad_ip (self ):
136
+ def test_get_user_by_req_appservice_valid_token_bad_ip (self ) -> None :
135
137
from netaddr import IPSet
136
138
137
139
app_service = Mock (
@@ -153,7 +155,7 @@ def test_get_user_by_req_appservice_valid_token_bad_ip(self):
153
155
self .assertEqual (f .code , 401 )
154
156
self .assertEqual (f .errcode , "M_UNKNOWN_TOKEN" )
155
157
156
- def test_get_user_by_req_appservice_bad_token (self ):
158
+ def test_get_user_by_req_appservice_bad_token (self ) -> None :
157
159
self .store .get_app_service_by_token = Mock (return_value = None )
158
160
self .store .get_user_by_access_token = simple_async_mock (None )
159
161
@@ -166,7 +168,7 @@ def test_get_user_by_req_appservice_bad_token(self):
166
168
self .assertEqual (f .code , 401 )
167
169
self .assertEqual (f .errcode , "M_UNKNOWN_TOKEN" )
168
170
169
- def test_get_user_by_req_appservice_missing_token (self ):
171
+ def test_get_user_by_req_appservice_missing_token (self ) -> None :
170
172
app_service = Mock (token = "foobar" , url = "a_url" , sender = self .test_user )
171
173
self .store .get_app_service_by_token = Mock (return_value = app_service )
172
174
self .store .get_user_by_access_token = simple_async_mock (None )
@@ -179,7 +181,7 @@ def test_get_user_by_req_appservice_missing_token(self):
179
181
self .assertEqual (f .code , 401 )
180
182
self .assertEqual (f .errcode , "M_MISSING_TOKEN" )
181
183
182
- def test_get_user_by_req_appservice_valid_token_valid_user_id (self ):
184
+ def test_get_user_by_req_appservice_valid_token_valid_user_id (self ) -> None :
183
185
masquerading_user_id = b"@doppelganger:matrix.org"
184
186
app_service = Mock (
185
187
token = "foobar" , url = "a_url" , sender = self .test_user , ip_range_whitelist = None
@@ -200,7 +202,7 @@ def test_get_user_by_req_appservice_valid_token_valid_user_id(self):
200
202
requester .user .to_string (), masquerading_user_id .decode ("utf8" )
201
203
)
202
204
203
- def test_get_user_by_req_appservice_valid_token_bad_user_id (self ):
205
+ def test_get_user_by_req_appservice_valid_token_bad_user_id (self ) -> None :
204
206
masquerading_user_id = b"@doppelganger:matrix.org"
205
207
app_service = Mock (
206
208
token = "foobar" , url = "a_url" , sender = self .test_user , ip_range_whitelist = None
@@ -217,7 +219,7 @@ def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
217
219
self .get_failure (self .auth .get_user_by_req (request ), AuthError )
218
220
219
221
@override_config ({"experimental_features" : {"msc3202_device_masquerading" : True }})
220
- def test_get_user_by_req_appservice_valid_token_valid_device_id (self ):
222
+ def test_get_user_by_req_appservice_valid_token_valid_device_id (self ) -> None :
221
223
"""
222
224
Tests that when an application service passes the device_id URL parameter
223
225
with the ID of a valid device for the user in question,
@@ -249,7 +251,7 @@ def test_get_user_by_req_appservice_valid_token_valid_device_id(self):
249
251
self .assertEqual (requester .device_id , masquerading_device_id .decode ("utf8" ))
250
252
251
253
@override_config ({"experimental_features" : {"msc3202_device_masquerading" : True }})
252
- def test_get_user_by_req_appservice_valid_token_invalid_device_id (self ):
254
+ def test_get_user_by_req_appservice_valid_token_invalid_device_id (self ) -> None :
253
255
"""
254
256
Tests that when an application service passes the device_id URL parameter
255
257
with an ID that is not a valid device ID for the user in question,
@@ -279,7 +281,7 @@ def test_get_user_by_req_appservice_valid_token_invalid_device_id(self):
279
281
self .assertEqual (failure .value .code , 400 )
280
282
self .assertEqual (failure .value .errcode , Codes .EXCLUSIVE )
281
283
282
- def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau (self ):
284
+ def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau (self ) -> None :
283
285
self .store .get_user_by_access_token = simple_async_mock (
284
286
TokenLookupResult (
285
287
user_id = "@baldrick:matrix.org" ,
@@ -298,7 +300,7 @@ def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self):
298
300
self .get_success (self .auth .get_user_by_req (request ))
299
301
self .store .insert_client_ip .assert_called_once ()
300
302
301
- def test_get_user_by_req__puppeted_token__tracking_puppeted_mau (self ):
303
+ def test_get_user_by_req__puppeted_token__tracking_puppeted_mau (self ) -> None :
302
304
self .auth ._track_puppeted_user_ips = True
303
305
self .store .get_user_by_access_token = simple_async_mock (
304
306
TokenLookupResult (
@@ -318,7 +320,7 @@ def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self):
318
320
self .get_success (self .auth .get_user_by_req (request ))
319
321
self .assertEqual (self .store .insert_client_ip .call_count , 2 )
320
322
321
- def test_get_user_from_macaroon (self ):
323
+ def test_get_user_from_macaroon (self ) -> None :
322
324
self .store .get_user_by_access_token = simple_async_mock (None )
323
325
324
326
user_id = "@baldrick:matrix.org"
@@ -336,7 +338,7 @@ def test_get_user_from_macaroon(self):
336
338
self .auth .get_user_by_access_token (serialized ), InvalidClientTokenError
337
339
)
338
340
339
- def test_get_guest_user_from_macaroon (self ):
341
+ def test_get_guest_user_from_macaroon (self ) -> None :
340
342
self .store .get_user_by_id = simple_async_mock ({"is_guest" : True })
341
343
self .store .get_user_by_access_token = simple_async_mock (None )
342
344
@@ -357,7 +359,7 @@ def test_get_guest_user_from_macaroon(self):
357
359
self .assertTrue (user_info .is_guest )
358
360
self .store .get_user_by_id .assert_called_with (user_id )
359
361
360
- def test_blocking_mau (self ):
362
+ def test_blocking_mau (self ) -> None :
361
363
self .auth_blocking ._limit_usage_by_mau = False
362
364
self .auth_blocking ._max_mau_value = 50
363
365
lots_of_users = 100
@@ -381,7 +383,7 @@ def test_blocking_mau(self):
381
383
self .store .get_monthly_active_count = simple_async_mock (small_number_of_users )
382
384
self .get_success (self .auth_blocking .check_auth_blocking ())
383
385
384
- def test_blocking_mau__depending_on_user_type (self ):
386
+ def test_blocking_mau__depending_on_user_type (self ) -> None :
385
387
self .auth_blocking ._max_mau_value = 50
386
388
self .auth_blocking ._limit_usage_by_mau = True
387
389
@@ -400,7 +402,9 @@ def test_blocking_mau__depending_on_user_type(self):
400
402
# Real users not allowed
401
403
self .get_failure (self .auth_blocking .check_auth_blocking (), ResourceLimitError )
402
404
403
- def test_blocking_mau__appservice_requester_allowed_when_not_tracking_ips (self ):
405
+ def test_blocking_mau__appservice_requester_allowed_when_not_tracking_ips (
406
+ self ,
407
+ ) -> None :
404
408
self .auth_blocking ._max_mau_value = 50
405
409
self .auth_blocking ._limit_usage_by_mau = True
406
410
self .auth_blocking ._track_appservice_user_ips = False
@@ -418,7 +422,7 @@ def test_blocking_mau__appservice_requester_allowed_when_not_tracking_ips(self):
418
422
sender = "@appservice:sender" ,
419
423
)
420
424
requester = Requester (
421
- user = "@appservice:server" ,
425
+ user = UserID . from_string ( "@appservice:server" ) ,
422
426
access_token_id = None ,
423
427
device_id = "FOOBAR" ,
424
428
is_guest = False ,
@@ -428,7 +432,9 @@ def test_blocking_mau__appservice_requester_allowed_when_not_tracking_ips(self):
428
432
)
429
433
self .get_success (self .auth_blocking .check_auth_blocking (requester = requester ))
430
434
431
- def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips (self ):
435
+ def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips (
436
+ self ,
437
+ ) -> None :
432
438
self .auth_blocking ._max_mau_value = 50
433
439
self .auth_blocking ._limit_usage_by_mau = True
434
440
self .auth_blocking ._track_appservice_user_ips = True
@@ -446,7 +452,7 @@ def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips(self):
446
452
sender = "@appservice:sender" ,
447
453
)
448
454
requester = Requester (
449
- user = "@appservice:server" ,
455
+ user = UserID . from_string ( "@appservice:server" ) ,
450
456
access_token_id = None ,
451
457
device_id = "FOOBAR" ,
452
458
is_guest = False ,
@@ -459,7 +465,7 @@ def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips(self):
459
465
ResourceLimitError ,
460
466
)
461
467
462
- def test_reserved_threepid (self ):
468
+ def test_reserved_threepid (self ) -> None :
463
469
self .auth_blocking ._limit_usage_by_mau = True
464
470
self .auth_blocking ._max_mau_value = 1
465
471
self .store .get_monthly_active_count = simple_async_mock (2 )
@@ -476,7 +482,7 @@ def test_reserved_threepid(self):
476
482
477
483
self .get_success (self .auth_blocking .check_auth_blocking (threepid = threepid ))
478
484
479
- def test_hs_disabled (self ):
485
+ def test_hs_disabled (self ) -> None :
480
486
self .auth_blocking ._hs_disabled = True
481
487
self .auth_blocking ._hs_disabled_message = "Reason for being disabled"
482
488
e = self .get_failure (
@@ -486,7 +492,7 @@ def test_hs_disabled(self):
486
492
self .assertEqual (e .value .errcode , Codes .RESOURCE_LIMIT_EXCEEDED )
487
493
self .assertEqual (e .value .code , 403 )
488
494
489
- def test_hs_disabled_no_server_notices_user (self ):
495
+ def test_hs_disabled_no_server_notices_user (self ) -> None :
490
496
"""Check that 'hs_disabled_message' works correctly when there is no
491
497
server_notices user.
492
498
"""
@@ -503,7 +509,7 @@ def test_hs_disabled_no_server_notices_user(self):
503
509
self .assertEqual (e .value .errcode , Codes .RESOURCE_LIMIT_EXCEEDED )
504
510
self .assertEqual (e .value .code , 403 )
505
511
506
- def test_server_notices_mxid_special_cased (self ):
512
+ def test_server_notices_mxid_special_cased (self ) -> None :
507
513
self .auth_blocking ._hs_disabled = True
508
514
user = "@user:server"
509
515
self .auth_blocking ._server_notices_mxid = user
0 commit comments