21
21
from typing_extensions import NoReturn , Protocol
22
22
23
23
from twisted .web .http import Request
24
+ from twisted .web .iweb import IRequest
24
25
25
26
from synapse .api .constants import LoginType
26
27
from synapse .api .errors import Codes , NotFoundError , RedirectException , SynapseError
27
28
from synapse .handlers .ui_auth import UIAuthSessionDataConstants
28
29
from synapse .http import get_request_user_agent
29
- from synapse .http .server import respond_with_html
30
+ from synapse .http .server import respond_with_html , respond_with_redirect
30
31
from synapse .http .site import SynapseRequest
31
32
from synapse .types import JsonDict , UserID , contains_invalid_mxid_characters
32
33
from synapse .util .async_helpers import Linearizer
@@ -141,6 +142,9 @@ class UsernameMappingSession:
141
142
# expiry time for the session, in milliseconds
142
143
expiry_time_ms = attr .ib (type = int )
143
144
145
+ # choices made by the user
146
+ chosen_localpart = attr .ib (type = Optional [str ], default = None )
147
+
144
148
145
149
# the HTTP cookie used to track the mapping session id
146
150
USERNAME_MAPPING_SESSION_COOKIE_NAME = b"username_mapping_session"
@@ -647,6 +651,25 @@ async def complete_sso_ui_auth_request(
647
651
)
648
652
respond_with_html (request , 200 , html )
649
653
654
+ def get_mapping_session (self , session_id : str ) -> UsernameMappingSession :
655
+ """Look up the given username mapping session
656
+
657
+ If it is not found, raises a SynapseError with an http code of 400
658
+
659
+ Args:
660
+ session_id: session to look up
661
+ Returns:
662
+ active mapping session
663
+ Raises:
664
+ SynapseError if the session is not found/has expired
665
+ """
666
+ self ._expire_old_sessions ()
667
+ session = self ._username_mapping_sessions .get (session_id )
668
+ if session :
669
+ return session
670
+ logger .info ("Couldn't find session id %s" , session_id )
671
+ raise SynapseError (400 , "unknown session" )
672
+
650
673
async def check_username_availability (
651
674
self , localpart : str , session_id : str ,
652
675
) -> bool :
@@ -663,12 +686,7 @@ async def check_username_availability(
663
686
664
687
# make sure that there is a valid mapping session, to stop people dictionary-
665
688
# scanning for accounts
666
-
667
- self ._expire_old_sessions ()
668
- session = self ._username_mapping_sessions .get (session_id )
669
- if not session :
670
- logger .info ("Couldn't find session id %s" , session_id )
671
- raise SynapseError (400 , "unknown session" )
689
+ self .get_mapping_session (session_id )
672
690
673
691
logger .info (
674
692
"[session %s] Checking for availability of username %s" ,
@@ -696,16 +714,33 @@ async def handle_submit_username_request(
696
714
localpart: localpart requested by the user
697
715
session_id: ID of the username mapping session, extracted from a cookie
698
716
"""
699
- self ._expire_old_sessions ()
700
- session = self ._username_mapping_sessions .get (session_id )
701
- if not session :
702
- logger .info ("Couldn't find session id %s" , session_id )
703
- raise SynapseError (400 , "unknown session" )
717
+ session = self .get_mapping_session (session_id )
718
+
719
+ # update the session with the user's choices
720
+ session .chosen_localpart = localpart
721
+
722
+ # we're done; now we can register the user
723
+ respond_with_redirect (request , b"/_synapse/client/sso_register" )
724
+
725
+ async def register_sso_user (self , request : Request , session_id : str ) -> None :
726
+ """Called once we have all the info we need to register a new user.
704
727
705
- logger .info ("[session %s] Registering localpart %s" , session_id , localpart )
728
+ Does so and serves an HTTP response
729
+
730
+ Args:
731
+ request: HTTP request
732
+ session_id: ID of the username mapping session, extracted from a cookie
733
+ """
734
+ session = self .get_mapping_session (session_id )
735
+
736
+ logger .info (
737
+ "[session %s] Registering localpart %s" ,
738
+ session_id ,
739
+ session .chosen_localpart ,
740
+ )
706
741
707
742
attributes = UserAttributes (
708
- localpart = localpart ,
743
+ localpart = session . chosen_localpart ,
709
744
display_name = session .display_name ,
710
745
emails = session .emails ,
711
746
)
@@ -720,7 +755,12 @@ async def handle_submit_username_request(
720
755
request .getClientIP (),
721
756
)
722
757
723
- logger .info ("[session %s] Registered userid %s" , session_id , user_id )
758
+ logger .info (
759
+ "[session %s] Registered userid %s with attributes %s" ,
760
+ session_id ,
761
+ user_id ,
762
+ attributes ,
763
+ )
724
764
725
765
# delete the mapping session and the cookie
726
766
del self ._username_mapping_sessions [session_id ]
@@ -751,3 +791,14 @@ def _expire_old_sessions(self):
751
791
for session_id in to_expire :
752
792
logger .info ("Expiring mapping session %s" , session_id )
753
793
del self ._username_mapping_sessions [session_id ]
794
+
795
+
796
+ def get_username_mapping_session_cookie_from_request (request : IRequest ) -> str :
797
+ """Extract the session ID from the cookie
798
+
799
+ Raises a SynapseError if the cookie isn't found
800
+ """
801
+ session_id = request .getCookie (USERNAME_MAPPING_SESSION_COOKIE_NAME )
802
+ if not session_id :
803
+ raise SynapseError (code = 400 , msg = "missing session_id" )
804
+ return session_id .decode ("ascii" , errors = "replace" )
0 commit comments