@@ -755,81 +755,145 @@ async def claim_e2e_one_time_keys(
755
755
"""
756
756
757
757
@trace
758
- def _claim_e2e_one_time_keys (txn ):
759
- sql = (
760
- "SELECT key_id, key_json FROM e2e_one_time_keys_json"
761
- " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
762
- " LIMIT 1"
758
+ def _claim_e2e_one_time_key_simple (
759
+ txn , user_id : str , device_id : str , algorithm : str
760
+ ) -> Optional [Tuple [str , str ]]:
761
+ """Claim OTK for device for DBs that don't support RETURNING.
762
+
763
+ Returns:
764
+ A tuple of key name (algorithm + key ID) and key JSON, if an
765
+ OTK was found.
766
+ """
767
+
768
+ sql = """
769
+ SELECT key_id, key_json FROM e2e_one_time_keys_json
770
+ WHERE user_id = ? AND device_id = ? AND algorithm = ?
771
+ LIMIT 1
772
+ """
773
+
774
+ txn .execute (sql , (user_id , device_id , algorithm ))
775
+ otk_row = txn .fetchone ()
776
+ if otk_row is None :
777
+ return None
778
+
779
+ key_id , key_json = otk_row
780
+
781
+ self .db_pool .simple_delete_one_txn (
782
+ txn ,
783
+ table = "e2e_one_time_keys_json" ,
784
+ keyvalues = {
785
+ "user_id" : user_id ,
786
+ "device_id" : device_id ,
787
+ "algorithm" : algorithm ,
788
+ "key_id" : key_id ,
789
+ },
763
790
)
764
- fallback_sql = (
765
- "SELECT key_id, key_json, used FROM e2e_fallback_keys_json"
766
- " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
767
- " LIMIT 1"
791
+ self ._invalidate_cache_and_stream (
792
+ txn , self .count_e2e_one_time_keys , (user_id , device_id )
768
793
)
769
- result = {}
770
- delete = []
771
- used_fallbacks = []
772
- for user_id , device_id , algorithm in query_list :
773
- user_result = result .setdefault (user_id , {})
774
- device_result = user_result .setdefault (device_id , {})
775
- txn .execute (sql , (user_id , device_id , algorithm ))
776
- otk_row = txn .fetchone ()
777
- if otk_row is not None :
778
- key_id , key_json = otk_row
779
- device_result [algorithm + ":" + key_id ] = key_json
780
- delete .append ((user_id , device_id , algorithm , key_id ))
781
- else :
782
- # no one-time key available, so see if there's a fallback
783
- # key
784
- txn .execute (fallback_sql , (user_id , device_id , algorithm ))
785
- fallback_row = txn .fetchone ()
786
- if fallback_row is not None :
787
- key_id , key_json , used = fallback_row
788
- device_result [algorithm + ":" + key_id ] = key_json
789
- if not used :
790
- used_fallbacks .append (
791
- (user_id , device_id , algorithm , key_id )
792
- )
793
-
794
- # drop any one-time keys that were claimed
795
- sql = (
796
- "DELETE FROM e2e_one_time_keys_json"
797
- " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
798
- " AND key_id = ?"
794
+
795
+ return f"{ algorithm } :{ key_id } " , key_json
796
+
797
+ @trace
798
+ def _claim_e2e_one_time_key_returning (
799
+ txn , user_id : str , device_id : str , algorithm : str
800
+ ) -> Optional [Tuple [str , str ]]:
801
+ """Claim OTK for device for DBs that support RETURNING.
802
+
803
+ Returns:
804
+ A tuple of key name (algorithm + key ID) and key JSON, if an
805
+ OTK was found.
806
+ """
807
+
808
+ # We can use RETURNING to do the fetch and DELETE in once step.
809
+ sql = """
810
+ DELETE FROM e2e_one_time_keys_json
811
+ WHERE user_id = ? AND device_id = ? AND algorithm = ?
812
+ AND key_id IN (
813
+ SELECT key_id FROM e2e_one_time_keys_json
814
+ WHERE user_id = ? AND device_id = ? AND algorithm = ?
815
+ LIMIT 1
816
+ )
817
+ RETURNING key_id, key_json
818
+ """
819
+
820
+ txn .execute (
821
+ sql , (user_id , device_id , algorithm , user_id , device_id , algorithm )
799
822
)
800
- for user_id , device_id , algorithm , key_id in delete :
801
- log_kv (
802
- {
803
- "message" : "Executing claim e2e_one_time_keys transaction on database."
804
- }
805
- )
806
- txn .execute (sql , (user_id , device_id , algorithm , key_id ))
807
- log_kv ({"message" : "finished executing and invalidating cache" })
808
- self ._invalidate_cache_and_stream (
809
- txn , self .count_e2e_one_time_keys , (user_id , device_id )
823
+ otk_row = txn .fetchone ()
824
+ if otk_row is None :
825
+ return None
826
+
827
+ key_id , key_json = otk_row
828
+ return f"{ algorithm } :{ key_id } " , key_json
829
+
830
+ results = {}
831
+ for user_id , device_id , algorithm in query_list :
832
+ if self .database_engine .supports_returning :
833
+ # If we support RETURNING clause we can use a single query that
834
+ # allows us to use autocommit mode.
835
+ _claim_e2e_one_time_key = _claim_e2e_one_time_key_returning
836
+ db_autocommit = True
837
+ else :
838
+ _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple
839
+ db_autocommit = False
840
+
841
+ row = await self .db_pool .runInteraction (
842
+ "claim_e2e_one_time_keys" ,
843
+ _claim_e2e_one_time_key ,
844
+ user_id ,
845
+ device_id ,
846
+ algorithm ,
847
+ db_autocommit = db_autocommit ,
848
+ )
849
+ if row :
850
+ device_results = results .setdefault (user_id , {}).setdefault (
851
+ device_id , {}
810
852
)
811
- # mark fallback keys as used
812
- for user_id , device_id , algorithm , key_id in used_fallbacks :
813
- self .db_pool .simple_update_txn (
814
- txn ,
815
- "e2e_fallback_keys_json" ,
816
- {
853
+ device_results [row [0 ]] = row [1 ]
854
+ continue
855
+
856
+ # No one-time key available, so see if there's a fallback
857
+ # key
858
+ row = await self .db_pool .simple_select_one (
859
+ table = "e2e_fallback_keys_json" ,
860
+ keyvalues = {
861
+ "user_id" : user_id ,
862
+ "device_id" : device_id ,
863
+ "algorithm" : algorithm ,
864
+ },
865
+ retcols = ("key_id" , "key_json" , "used" ),
866
+ desc = "_get_fallback_key" ,
867
+ allow_none = True ,
868
+ )
869
+ if row is None :
870
+ continue
871
+
872
+ key_id = row ["key_id" ]
873
+ key_json = row ["key_json" ]
874
+ used = row ["used" ]
875
+
876
+ # Mark fallback key as used if not already.
877
+ if not used :
878
+ await self .db_pool .simple_update_one (
879
+ table = "e2e_fallback_keys_json" ,
880
+ keyvalues = {
817
881
"user_id" : user_id ,
818
882
"device_id" : device_id ,
819
883
"algorithm" : algorithm ,
820
884
"key_id" : key_id ,
821
885
},
822
- {"used" : True },
886
+ updatevalues = {"used" : True },
887
+ desc = "_get_fallback_key_set_used" ,
823
888
)
824
- self ._invalidate_cache_and_stream (
825
- txn , self . get_e2e_unused_fallback_key_types , (user_id , device_id )
889
+ await self .invalidate_cache_and_stream (
890
+ " get_e2e_unused_fallback_key_types" , (user_id , device_id )
826
891
)
827
892
828
- return result
893
+ device_results = results .setdefault (user_id , {}).setdefault (device_id , {})
894
+ device_results [f"{ algorithm } :{ key_id } " ] = key_json
829
895
830
- return await self .db_pool .runInteraction (
831
- "claim_e2e_one_time_keys" , _claim_e2e_one_time_keys
832
- )
896
+ return results
833
897
834
898
835
899
class EndToEndKeyStore (EndToEndKeyWorkerStore , SQLBaseStore ):
0 commit comments