2
2
import datetime
3
3
import decimal
4
4
import functools
5
+ import gc
6
+ import itertools
5
7
import os
6
8
import re
9
+ from typing import MutableMapping
7
10
from unittest .mock import MagicMock , patch
8
11
9
12
import pytest
@@ -477,6 +480,254 @@ async def test_transaction_commit(database_url):
477
480
assert len (results ) == 1
478
481
479
482
483
+ @pytest .mark .parametrize ("database_url" , DATABASE_URLS )
484
+ @async_adapter
485
+ async def test_transaction_context_child_task_inheritance (database_url ):
486
+ """
487
+ Ensure that transactions are inherited by child tasks.
488
+ """
489
+ async with Database (database_url ) as database :
490
+
491
+ async def check_transaction (transaction , active_transaction ):
492
+ # Should have inherited the same transaction backend from the parent task
493
+ assert transaction ._transaction is active_transaction
494
+
495
+ async with database .transaction () as transaction :
496
+ await asyncio .create_task (
497
+ check_transaction (transaction , transaction ._transaction )
498
+ )
499
+
500
+
501
+ @pytest .mark .parametrize ("database_url" , DATABASE_URLS )
502
+ @async_adapter
503
+ async def test_transaction_context_child_task_inheritance_example (database_url ):
504
+ """
505
+ Ensure that child tasks may influence inherited transactions.
506
+ """
507
+ # This is an practical example of the above test.
508
+ async with Database (database_url ) as database :
509
+ async with database .transaction ():
510
+ # Create a note
511
+ await database .execute (
512
+ notes .insert ().values (id = 1 , text = "setup" , completed = True )
513
+ )
514
+
515
+ # Change the note from the same task
516
+ await database .execute (
517
+ notes .update ().where (notes .c .id == 1 ).values (text = "prior" )
518
+ )
519
+
520
+ # Confirm the change
521
+ result = await database .fetch_one (notes .select ().where (notes .c .id == 1 ))
522
+ assert result .text == "prior"
523
+
524
+ async def run_update_from_child_task (connection ):
525
+ # Change the note from a child task
526
+ await connection .execute (
527
+ notes .update ().where (notes .c .id == 1 ).values (text = "test" )
528
+ )
529
+
530
+ await asyncio .create_task (run_update_from_child_task (database .connection ()))
531
+
532
+ # Confirm the child's change
533
+ result = await database .fetch_one (notes .select ().where (notes .c .id == 1 ))
534
+ assert result .text == "test"
535
+
536
+
537
+ @pytest .mark .parametrize ("database_url" , DATABASE_URLS )
538
+ @async_adapter
539
+ async def test_transaction_context_sibling_task_isolation (database_url ):
540
+ """
541
+ Ensure that transactions are isolated between sibling tasks.
542
+ """
543
+ start = asyncio .Event ()
544
+ end = asyncio .Event ()
545
+
546
+ async with Database (database_url ) as database :
547
+
548
+ async def check_transaction (transaction ):
549
+ await start .wait ()
550
+ # Parent task is now in a transaction, we should not
551
+ # see its transaction backend since this task was
552
+ # _started_ in a context where no transaction was active.
553
+ assert transaction ._transaction is None
554
+ end .set ()
555
+
556
+ transaction = database .transaction ()
557
+ assert transaction ._transaction is None
558
+ task = asyncio .create_task (check_transaction (transaction ))
559
+
560
+ async with transaction :
561
+ start .set ()
562
+ assert transaction ._transaction is not None
563
+ await end .wait ()
564
+
565
+ # Cleanup for "Task not awaited" warning
566
+ await task
567
+
568
+
569
+ @pytest .mark .parametrize ("database_url" , DATABASE_URLS )
570
+ @async_adapter
571
+ async def test_transaction_context_sibling_task_isolation_example (database_url ):
572
+ """
573
+ Ensure that transactions are running in sibling tasks are isolated from eachother.
574
+ """
575
+ # This is an practical example of the above test.
576
+ setup = asyncio .Event ()
577
+ done = asyncio .Event ()
578
+
579
+ async def tx1 (connection ):
580
+ async with connection .transaction ():
581
+ await db .execute (
582
+ notes .insert (), values = {"id" : 1 , "text" : "tx1" , "completed" : False }
583
+ )
584
+ setup .set ()
585
+ await done .wait ()
586
+
587
+ async def tx2 (connection ):
588
+ async with connection .transaction ():
589
+ await setup .wait ()
590
+ result = await db .fetch_all (notes .select ())
591
+ assert result == [], result
592
+ done .set ()
593
+
594
+ async with Database (database_url ) as db :
595
+ await asyncio .gather (tx1 (db ), tx2 (db ))
596
+
597
+
598
+ @pytest .mark .parametrize ("database_url" , DATABASE_URLS )
599
+ @async_adapter
600
+ async def test_connection_cleanup_contextmanager (database_url ):
601
+ """
602
+ Ensure that task connections are not persisted unecessarily.
603
+ """
604
+
605
+ ready = asyncio .Event ()
606
+ done = asyncio .Event ()
607
+
608
+ async def check_child_connection (database : Database ):
609
+ async with database .connection ():
610
+ ready .set ()
611
+ await done .wait ()
612
+
613
+ async with Database (database_url ) as database :
614
+ # Should have a connection in this task
615
+ # .connect is lazy, it doesn't create a Connection, but .connection does
616
+ connection = database .connection ()
617
+ assert isinstance (database ._connection_map , MutableMapping )
618
+ assert database ._connection_map .get (asyncio .current_task ()) is connection
619
+
620
+ # Create a child task and see if it registers a connection
621
+ task = asyncio .create_task (check_child_connection (database ))
622
+ await ready .wait ()
623
+ assert database ._connection_map .get (task ) is not None
624
+ assert database ._connection_map .get (task ) is not connection
625
+
626
+ # Let the child task finish, and see if it cleaned up
627
+ done .set ()
628
+ await task
629
+ # This is normal exit logic cleanup, the WeakKeyDictionary
630
+ # shouldn't have cleaned up yet since the task is still referenced
631
+ assert task not in database ._connection_map
632
+
633
+ # Context manager closes, all open connections are removed
634
+ assert isinstance (database ._connection_map , MutableMapping )
635
+ assert len (database ._connection_map ) == 0
636
+
637
+
638
+ @pytest .mark .parametrize ("database_url" , DATABASE_URLS )
639
+ @async_adapter
640
+ async def test_connection_cleanup_garbagecollector (database_url ):
641
+ """
642
+ Ensure that connections for tasks are not persisted unecessarily, even
643
+ if exit handlers are not called.
644
+ """
645
+ database = Database (database_url )
646
+ await database .connect ()
647
+
648
+ created = asyncio .Event ()
649
+
650
+ async def check_child_connection (database : Database ):
651
+ # neither .disconnect nor .__aexit__ are called before deleting this task
652
+ database .connection ()
653
+ created .set ()
654
+
655
+ task = asyncio .create_task (check_child_connection (database ))
656
+ await created .wait ()
657
+ assert task in database ._connection_map
658
+ await task
659
+ del task
660
+ gc .collect ()
661
+
662
+ # Should not have a connection for the task anymore
663
+ assert len (database ._connection_map ) == 0
664
+
665
+
666
+ @pytest .mark .parametrize ("database_url" , DATABASE_URLS )
667
+ @async_adapter
668
+ async def test_transaction_context_cleanup_contextmanager (database_url ):
669
+ """
670
+ Ensure that contextvar transactions are not persisted unecessarily.
671
+ """
672
+ from databases .core import _ACTIVE_TRANSACTIONS
673
+
674
+ assert _ACTIVE_TRANSACTIONS .get () is None
675
+
676
+ async with Database (database_url ) as database :
677
+ async with database .transaction () as transaction :
678
+ open_transactions = _ACTIVE_TRANSACTIONS .get ()
679
+ assert isinstance (open_transactions , MutableMapping )
680
+ assert open_transactions .get (transaction ) is transaction ._transaction
681
+
682
+ # Context manager closes, open_transactions is cleaned up
683
+ open_transactions = _ACTIVE_TRANSACTIONS .get ()
684
+ assert isinstance (open_transactions , MutableMapping )
685
+ assert open_transactions .get (transaction , None ) is None
686
+
687
+
688
+ @pytest .mark .parametrize ("database_url" , DATABASE_URLS )
689
+ @async_adapter
690
+ async def test_transaction_context_cleanup_garbagecollector (database_url ):
691
+ """
692
+ Ensure that contextvar transactions are not persisted unecessarily, even
693
+ if exit handlers are not called.
694
+
695
+ This test should be an XFAIL, but cannot be due to the way that is hangs
696
+ during teardown.
697
+ """
698
+ from databases .core import _ACTIVE_TRANSACTIONS
699
+
700
+ assert _ACTIVE_TRANSACTIONS .get () is None
701
+
702
+ async with Database (database_url ) as database :
703
+ transaction = database .transaction ()
704
+ await transaction .start ()
705
+
706
+ # Should be tracking the transaction
707
+ open_transactions = _ACTIVE_TRANSACTIONS .get ()
708
+ assert isinstance (open_transactions , MutableMapping )
709
+ assert open_transactions .get (transaction ) is transaction ._transaction
710
+
711
+ # neither .commit, .rollback, nor .__aexit__ are called
712
+ del transaction
713
+ gc .collect ()
714
+
715
+ # TODO(zevisert,review): Could skip instead of using the logic below
716
+ # A strong reference to the transaction is kept alive by the connection's
717
+ # ._transaction_stack, so it is still be tracked at this point.
718
+ assert len (open_transactions ) == 1
719
+
720
+ # If that were magically cleared, the transaction would be cleaned up,
721
+ # but as it stands this always causes a hang during teardown at
722
+ # `Database(...).disconnect()` if the transaction is not closed.
723
+ transaction = database .connection ()._transaction_stack [- 1 ]
724
+ await transaction .rollback ()
725
+ del transaction
726
+
727
+ # Now with the transaction rolled-back, it should be cleaned up.
728
+ assert len (open_transactions ) == 0
729
+
730
+
480
731
@pytest .mark .parametrize ("database_url" , DATABASE_URLS )
481
732
@async_adapter
482
733
async def test_transaction_commit_serializable (database_url ):
@@ -609,17 +860,44 @@ async def insert_data(raise_exception):
609
860
with pytest .raises (RuntimeError ):
610
861
await insert_data (raise_exception = True )
611
862
612
- query = notes .select ()
613
- results = await database .fetch_all (query = query )
863
+ results = await database .fetch_all (query = notes .select ())
614
864
assert len (results ) == 0
615
865
616
866
await insert_data (raise_exception = False )
617
867
618
- query = notes .select ()
619
- results = await database .fetch_all (query = query )
868
+ results = await database .fetch_all (query = notes .select ())
620
869
assert len (results ) == 1
621
870
622
871
872
+ @pytest .mark .parametrize ("database_url" , DATABASE_URLS )
873
+ @async_adapter
874
+ async def test_transaction_decorator_concurrent (database_url ):
875
+ """
876
+ Ensure that @database.transaction() can be called concurrently.
877
+ """
878
+
879
+ database = Database (database_url )
880
+
881
+ @database .transaction ()
882
+ async def insert_data ():
883
+ await database .execute (
884
+ query = notes .insert ().values (text = "example" , completed = True )
885
+ )
886
+
887
+ async with database :
888
+ await asyncio .gather (
889
+ insert_data (),
890
+ insert_data (),
891
+ insert_data (),
892
+ insert_data (),
893
+ insert_data (),
894
+ insert_data (),
895
+ )
896
+
897
+ results = await database .fetch_all (query = notes .select ())
898
+ assert len (results ) == 6
899
+
900
+
623
901
@pytest .mark .parametrize ("database_url" , DATABASE_URLS )
624
902
@async_adapter
625
903
async def test_datetime_field (database_url ):
@@ -789,15 +1067,16 @@ async def test_connect_and_disconnect(database_url):
789
1067
790
1068
@pytest .mark .parametrize ("database_url" , DATABASE_URLS )
791
1069
@async_adapter
792
- async def test_connection_context (database_url ):
793
- """
794
- Test connection contexts are task-local.
795
- """
1070
+ async def test_connection_context_same_task (database_url ):
796
1071
async with Database (database_url ) as database :
797
1072
async with database .connection () as connection_1 :
798
1073
async with database .connection () as connection_2 :
799
1074
assert connection_1 is connection_2
800
1075
1076
+
1077
+ @pytest .mark .parametrize ("database_url" , DATABASE_URLS )
1078
+ @async_adapter
1079
+ async def test_connection_context_multiple_sibling_tasks (database_url ):
801
1080
async with Database (database_url ) as database :
802
1081
connection_1 = None
803
1082
connection_2 = None
@@ -817,9 +1096,8 @@ async def get_connection_2():
817
1096
connection_2 = connection
818
1097
await test_complete .wait ()
819
1098
820
- loop = asyncio .get_event_loop ()
821
- task_1 = loop .create_task (get_connection_1 ())
822
- task_2 = loop .create_task (get_connection_2 ())
1099
+ task_1 = asyncio .create_task (get_connection_1 ())
1100
+ task_2 = asyncio .create_task (get_connection_2 ())
823
1101
while connection_1 is None or connection_2 is None :
824
1102
await asyncio .sleep (0.000001 )
825
1103
assert connection_1 is not connection_2
@@ -828,6 +1106,61 @@ async def get_connection_2():
828
1106
await task_2
829
1107
830
1108
1109
+ @pytest .mark .parametrize ("database_url" , DATABASE_URLS )
1110
+ @async_adapter
1111
+ async def test_connection_context_multiple_tasks (database_url ):
1112
+ async with Database (database_url ) as database :
1113
+ parent_connection = database .connection ()
1114
+ connection_1 = None
1115
+ connection_2 = None
1116
+ task_1_ready = asyncio .Event ()
1117
+ task_2_ready = asyncio .Event ()
1118
+ test_complete = asyncio .Event ()
1119
+
1120
+ async def get_connection_1 ():
1121
+ nonlocal connection_1
1122
+
1123
+ async with database .connection () as connection :
1124
+ connection_1 = connection
1125
+ task_1_ready .set ()
1126
+ await test_complete .wait ()
1127
+
1128
+ async def get_connection_2 ():
1129
+ nonlocal connection_2
1130
+
1131
+ async with database .connection () as connection :
1132
+ connection_2 = connection
1133
+ task_2_ready .set ()
1134
+ await test_complete .wait ()
1135
+
1136
+ task_1 = asyncio .create_task (get_connection_1 ())
1137
+ task_2 = asyncio .create_task (get_connection_2 ())
1138
+ await task_1_ready .wait ()
1139
+ await task_2_ready .wait ()
1140
+
1141
+ assert connection_1 is not parent_connection
1142
+ assert connection_2 is not parent_connection
1143
+ assert connection_1 is not connection_2
1144
+
1145
+ test_complete .set ()
1146
+ await task_1
1147
+ await task_2
1148
+
1149
+
1150
+ @pytest .mark .parametrize (
1151
+ "database_url1,database_url2" ,
1152
+ (
1153
+ pytest .param (db1 , db2 , id = f"{ db1 } | { db2 } " )
1154
+ for (db1 , db2 ) in itertools .combinations (DATABASE_URLS , 2 )
1155
+ ),
1156
+ )
1157
+ @async_adapter
1158
+ async def test_connection_context_multiple_databases (database_url1 , database_url2 ):
1159
+ async with Database (database_url1 ) as database1 :
1160
+ async with Database (database_url2 ) as database2 :
1161
+ assert database1 .connection () is not database2 .connection ()
1162
+
1163
+
831
1164
@pytest .mark .parametrize ("database_url" , DATABASE_URLS )
832
1165
@async_adapter
833
1166
async def test_connection_context_with_raw_connection (database_url ):
@@ -961,16 +1294,59 @@ async def test_database_url_interface(database_url):
961
1294
@pytest .mark .parametrize ("database_url" , DATABASE_URLS )
962
1295
@async_adapter
963
1296
async def test_concurrent_access_on_single_connection (database_url ):
964
- database_url = DatabaseURL (database_url )
965
- if database_url .dialect != "postgresql" :
966
- pytest .skip ("Test requires `pg_sleep()`" )
967
-
968
1297
async with Database (database_url , force_rollback = True ) as database :
969
1298
970
1299
async def db_lookup ():
971
- await database .fetch_one ("SELECT pg_sleep(1)" )
1300
+ await database .fetch_one ("SELECT 1 AS value" )
1301
+
1302
+ await asyncio .gather (
1303
+ db_lookup (),
1304
+ db_lookup (),
1305
+ )
1306
+
1307
+
1308
+ @pytest .mark .parametrize ("database_url" , DATABASE_URLS )
1309
+ @async_adapter
1310
+ async def test_concurrent_transactions_on_single_connection (database_url : str ):
1311
+ async with Database (database_url ) as database :
1312
+
1313
+ @database .transaction ()
1314
+ async def db_lookup ():
1315
+ await database .fetch_one (query = "SELECT 1 AS value" )
1316
+
1317
+ await asyncio .gather (
1318
+ db_lookup (),
1319
+ db_lookup (),
1320
+ )
1321
+
1322
+
1323
+ @pytest .mark .parametrize ("database_url" , DATABASE_URLS )
1324
+ @async_adapter
1325
+ async def test_concurrent_tasks_on_single_connection (database_url : str ):
1326
+ async with Database (database_url ) as database :
1327
+
1328
+ async def db_lookup ():
1329
+ await database .fetch_one (query = "SELECT 1 AS value" )
1330
+
1331
+ await asyncio .gather (
1332
+ asyncio .create_task (db_lookup ()),
1333
+ asyncio .create_task (db_lookup ()),
1334
+ )
1335
+
1336
+
1337
+ @pytest .mark .parametrize ("database_url" , DATABASE_URLS )
1338
+ @async_adapter
1339
+ async def test_concurrent_task_transactions_on_single_connection (database_url : str ):
1340
+ async with Database (database_url ) as database :
1341
+
1342
+ @database .transaction ()
1343
+ async def db_lookup ():
1344
+ await database .fetch_one (query = "SELECT 1 AS value" )
972
1345
973
- await asyncio .gather (db_lookup (), db_lookup ())
1346
+ await asyncio .gather (
1347
+ asyncio .create_task (db_lookup ()),
1348
+ asyncio .create_task (db_lookup ()),
1349
+ )
974
1350
975
1351
976
1352
@pytest .mark .parametrize ("database_url" , DATABASE_URLS )
0 commit comments