@@ -482,11 +482,29 @@ async def test_transaction_commit(database_url):
482
482
483
483
@pytest .mark .parametrize ("database_url" , DATABASE_URLS )
484
484
@async_adapter
485
- async def test_transaction_context_child_task_interaction (database_url ):
485
+ async def test_transaction_context_child_task_inheritance (database_url ):
486
+ """
487
+ Ensure that transactions are inherited by child tasks.
488
+ """
489
+ async with Database (database_url ) as database :
490
+
491
+ async def check_transaction (transaction , active_transaction ):
492
+ # Should have inherited the same transaction backend from the parent task
493
+ assert transaction ._transaction is active_transaction
494
+
495
+ async with database .transaction () as transaction :
496
+ await asyncio .create_task (
497
+ check_transaction (transaction , transaction ._transaction )
498
+ )
499
+
500
+
501
+ @pytest .mark .parametrize ("database_url" , DATABASE_URLS )
502
+ @async_adapter
503
+ async def test_transaction_context_child_task_inheritance_example (database_url ):
486
504
"""
487
505
Ensure that child tasks may influence inherited transactions.
488
506
"""
489
- # This is an practical example of the next test.
507
+ # This is an practical example of the above test.
490
508
async with Database (database_url ) as database :
491
509
async with database .transaction ():
492
510
# Create a note
@@ -503,37 +521,19 @@ async def test_transaction_context_child_task_interaction(database_url):
503
521
result = await database .fetch_one (notes .select ().where (notes .c .id == 1 ))
504
522
assert result .text == "prior"
505
523
506
- async def run_update_from_child_task ():
507
- # Chage the note from a child task
508
- await database .execute (
524
+ async def run_update_from_child_task (connection ):
525
+ # Change the note from a child task
526
+ await connection .execute (
509
527
notes .update ().where (notes .c .id == 1 ).values (text = "test" )
510
528
)
511
529
512
- await asyncio .create_task (run_update_from_child_task ())
530
+ await asyncio .create_task (run_update_from_child_task (database . connection () ))
513
531
514
532
# Confirm the child's change
515
533
result = await database .fetch_one (notes .select ().where (notes .c .id == 1 ))
516
534
assert result .text == "test"
517
535
518
536
519
- @pytest .mark .parametrize ("database_url" , DATABASE_URLS )
520
- @async_adapter
521
- async def test_transaction_context_child_task_inheritance (database_url ):
522
- """
523
- Ensure that transactions are inherited by child tasks.
524
- """
525
- async with Database (database_url ) as database :
526
-
527
- async def check_transaction (transaction , active_transaction ):
528
- # Should have inherited the same transaction backend from the parent task
529
- assert transaction ._transaction is active_transaction
530
-
531
- async with database .transaction () as transaction :
532
- await asyncio .create_task (
533
- check_transaction (transaction , transaction ._transaction )
534
- )
535
-
536
-
537
537
@pytest .mark .parametrize ("database_url" , DATABASE_URLS )
538
538
@async_adapter
539
539
async def test_transaction_context_sibling_task_isolation (database_url ):
@@ -568,56 +568,99 @@ async def check_transaction(transaction):
568
568
569
569
@pytest .mark .parametrize ("database_url" , DATABASE_URLS )
570
570
@async_adapter
571
- async def test_connection_context_cleanup_contextmanager (database_url ):
571
+ async def test_transaction_context_sibling_task_isolation_example (database_url ):
572
+ """
573
+ Ensure that transactions are running in sibling tasks are isolated from eachother.
574
+ """
575
+ # This is an practical example of the above test.
576
+ setup = asyncio .Event ()
577
+ done = asyncio .Event ()
578
+
579
+ async def tx1 (connection ):
580
+ async with connection .transaction ():
581
+ await db .execute (
582
+ notes .insert (), values = {"id" : 1 , "text" : "tx1" , "completed" : False }
583
+ )
584
+ setup .set ()
585
+ await done .wait ()
586
+
587
+ async def tx2 (connection ):
588
+ async with connection .transaction ():
589
+ await setup .wait ()
590
+ result = await db .fetch_all (notes .select ())
591
+ assert result == [], result
592
+ done .set ()
593
+
594
+ async with Database (database_url ) as db :
595
+ await asyncio .gather (tx1 (db ), tx2 (db ))
596
+
597
+
598
+ @pytest .mark .parametrize ("database_url" , DATABASE_URLS )
599
+ @async_adapter
600
+ async def test_connection_cleanup_contextmanager (database_url ):
572
601
"""
573
- Ensure that contextvar connections are not persisted unecessarily.
602
+ Ensure that task connections are not persisted unecessarily.
574
603
"""
575
- from databases .core import _ACTIVE_CONNECTIONS
576
604
577
- assert _ACTIVE_CONNECTIONS .get () is None
605
+ ready = asyncio .Event ()
606
+ done = asyncio .Event ()
607
+
608
+ async def check_child_connection (database : Database ):
609
+ async with database .connection ():
610
+ ready .set ()
611
+ await done .wait ()
578
612
579
613
async with Database (database_url ) as database :
614
+ # Should have a connection in this task
580
615
# .connect is lazy, it doesn't create a Connection, but .connection does
581
616
connection = database .connection ()
617
+ assert isinstance (database ._connection_map , MutableMapping )
618
+ assert database ._connection_map .get (asyncio .current_task ()) is connection
582
619
583
- open_connections = _ACTIVE_CONNECTIONS .get ()
584
- assert isinstance (open_connections , MutableMapping )
585
- assert open_connections .get (database ) is connection
620
+ # Create a child task and see if it registers a connection
621
+ task = asyncio .create_task (check_child_connection (database ))
622
+ await ready .wait ()
623
+ assert database ._connection_map .get (task ) is not None
624
+ assert database ._connection_map .get (task ) is not connection
586
625
587
- # Context manager closes, open_connections is cleaned up
588
- open_connections = _ACTIVE_CONNECTIONS .get ()
589
- assert isinstance (open_connections , MutableMapping )
590
- assert open_connections .get (database , None ) is None
626
+ # Let the child task finish, and see if it cleaned up
627
+ done .set ()
628
+ await task
629
+ # This is normal exit logic cleanup, the WeakKeyDictionary
630
+ # shouldn't have cleaned up yet since the task is still referenced
631
+ assert task not in database ._connection_map
632
+
633
+ # Context manager closes, all open connections are removed
634
+ assert isinstance (database ._connection_map , MutableMapping )
635
+ assert len (database ._connection_map ) == 0
591
636
592
637
593
638
@pytest .mark .parametrize ("database_url" , DATABASE_URLS )
594
639
@async_adapter
595
- async def test_connection_context_cleanup_garbagecollector (database_url ):
640
+ async def test_connection_cleanup_garbagecollector (database_url ):
596
641
"""
597
- Ensure that contextvar connections are not persisted unecessarily, even
642
+ Ensure that connections for tasks are not persisted unecessarily, even
598
643
if exit handlers are not called.
599
644
"""
600
- from databases .core import _ACTIVE_CONNECTIONS
601
-
602
- assert _ACTIVE_CONNECTIONS .get () is None
603
-
604
645
database = Database (database_url )
605
646
await database .connect ()
606
- connection = database .connection ()
607
647
608
- # Should be tracking the connection
609
- open_connections = _ACTIVE_CONNECTIONS .get ()
610
- assert isinstance (open_connections , MutableMapping )
611
- assert open_connections .get (database ) is connection
648
+ created = asyncio .Event ()
649
+
650
+ async def check_child_connection (database : Database ):
651
+ # neither .disconnect nor .__aexit__ are called before deleting this task
652
+ database .connection ()
653
+ created .set ()
612
654
613
- # neither .disconnect nor .__aexit__ are called before deleting the reference
614
- del database
655
+ task = asyncio .create_task (check_child_connection (database ))
656
+ await created .wait ()
657
+ assert task in database ._connection_map
658
+ await task
659
+ del task
615
660
gc .collect ()
616
661
617
- # Should have dropped reference to connection, even without proper cleanup
618
- open_connections = _ACTIVE_CONNECTIONS .get ()
619
- assert isinstance (open_connections , MutableMapping )
620
- assert len (open_connections ) == 0
662
+ # Should not have a connection for the task anymore
663
+ assert len (database ._connection_map ) == 0
621
664
622
665
623
666
@pytest .mark .parametrize ("database_url" , DATABASE_URLS )
@@ -632,7 +675,6 @@ async def test_transaction_context_cleanup_contextmanager(database_url):
632
675
633
676
async with Database (database_url ) as database :
634
677
async with database .transaction () as transaction :
635
-
636
678
open_transactions = _ACTIVE_TRANSACTIONS .get ()
637
679
assert isinstance (open_transactions , MutableMapping )
638
680
assert open_transactions .get (transaction ) is transaction ._transaction
@@ -818,17 +860,44 @@ async def insert_data(raise_exception):
818
860
with pytest .raises (RuntimeError ):
819
861
await insert_data (raise_exception = True )
820
862
821
- query = notes .select ()
822
- results = await database .fetch_all (query = query )
863
+ results = await database .fetch_all (query = notes .select ())
823
864
assert len (results ) == 0
824
865
825
866
await insert_data (raise_exception = False )
826
867
827
- query = notes .select ()
828
- results = await database .fetch_all (query = query )
868
+ results = await database .fetch_all (query = notes .select ())
829
869
assert len (results ) == 1
830
870
831
871
872
+ @pytest .mark .parametrize ("database_url" , DATABASE_URLS )
873
+ @async_adapter
874
+ async def test_transaction_decorator_concurrent (database_url ):
875
+ """
876
+ Ensure that @database.transaction() can be called concurrently.
877
+ """
878
+
879
+ database = Database (database_url )
880
+
881
+ @database .transaction ()
882
+ async def insert_data ():
883
+ await database .execute (
884
+ query = notes .insert ().values (text = "example" , completed = True )
885
+ )
886
+
887
+ async with database :
888
+ await asyncio .gather (
889
+ insert_data (),
890
+ insert_data (),
891
+ insert_data (),
892
+ insert_data (),
893
+ insert_data (),
894
+ insert_data (),
895
+ )
896
+
897
+ results = await database .fetch_all (query = notes .select ())
898
+ assert len (results ) == 6
899
+
900
+
832
901
@pytest .mark .parametrize ("database_url" , DATABASE_URLS )
833
902
@async_adapter
834
903
async def test_datetime_field (database_url ):
@@ -1007,7 +1076,7 @@ async def test_connection_context_same_task(database_url):
1007
1076
1008
1077
@pytest .mark .parametrize ("database_url" , DATABASE_URLS )
1009
1078
@async_adapter
1010
- async def test_connection_context_multiple_tasks (database_url ):
1079
+ async def test_connection_context_multiple_sibling_tasks (database_url ):
1011
1080
async with Database (database_url ) as database :
1012
1081
connection_1 = None
1013
1082
connection_2 = None
@@ -1037,6 +1106,47 @@ async def get_connection_2():
1037
1106
await task_2
1038
1107
1039
1108
1109
+ @pytest .mark .parametrize ("database_url" , DATABASE_URLS )
1110
+ @async_adapter
1111
+ async def test_connection_context_multiple_tasks (database_url ):
1112
+ async with Database (database_url ) as database :
1113
+ parent_connection = database .connection ()
1114
+ connection_1 = None
1115
+ connection_2 = None
1116
+ task_1_ready = asyncio .Event ()
1117
+ task_2_ready = asyncio .Event ()
1118
+ test_complete = asyncio .Event ()
1119
+
1120
+ async def get_connection_1 ():
1121
+ nonlocal connection_1
1122
+
1123
+ async with database .connection () as connection :
1124
+ connection_1 = connection
1125
+ task_1_ready .set ()
1126
+ await test_complete .wait ()
1127
+
1128
+ async def get_connection_2 ():
1129
+ nonlocal connection_2
1130
+
1131
+ async with database .connection () as connection :
1132
+ connection_2 = connection
1133
+ task_2_ready .set ()
1134
+ await test_complete .wait ()
1135
+
1136
+ task_1 = asyncio .create_task (get_connection_1 ())
1137
+ task_2 = asyncio .create_task (get_connection_2 ())
1138
+ await task_1_ready .wait ()
1139
+ await task_2_ready .wait ()
1140
+
1141
+ assert connection_1 is not parent_connection
1142
+ assert connection_2 is not parent_connection
1143
+ assert connection_1 is not connection_2
1144
+
1145
+ test_complete .set ()
1146
+ await task_1
1147
+ await task_2
1148
+
1149
+
1040
1150
@pytest .mark .parametrize (
1041
1151
"database_url1,database_url2" ,
1042
1152
(
0 commit comments