Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit b94f097

Browse files
committedMay 26, 2023
fix: remove connection inheritance, add more tests, update docs
Connections are once again stored as state on the Database instance, keyed by the current asyncio.Task. Each task acquires it's own connection, and a WeakKeyDictionary allows the connection to be discarded if the owning task is garbage collected. TransactionBackends are still stored as contextvars, and a connection must be explicitly provided to descendant tasks if active transaction state is to be inherited.
1 parent 6de4f60 commit b94f097

File tree

3 files changed

+203
-90
lines changed

3 files changed

+203
-90
lines changed
 

‎databases/core.py

+24-20
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,9 @@
3636
logger = logging.getLogger("databases")
3737

3838

39-
_ACTIVE_CONNECTIONS: ContextVar[
40-
typing.Optional["weakref.WeakKeyDictionary['Database', 'Connection']"]
41-
] = ContextVar("databases:open_connections", default=None)
4239
_ACTIVE_TRANSACTIONS: ContextVar[
4340
typing.Optional["weakref.WeakKeyDictionary['Transaction', 'TransactionBackend']"]
44-
] = ContextVar("databases:open_transactions", default=None)
41+
] = ContextVar("databases:active_transactions", default=None)
4542

4643

4744
class Database:
@@ -54,6 +51,8 @@ class Database:
5451
"sqlite": "databases.backends.sqlite:SQLiteBackend",
5552
}
5653

54+
_connection_map: "weakref.WeakKeyDictionary[asyncio.Task, 'Connection']"
55+
5756
def __init__(
5857
self,
5958
url: typing.Union[str, "DatabaseURL"],
@@ -64,6 +63,7 @@ def __init__(
6463
self.url = DatabaseURL(url)
6564
self.options = options
6665
self.is_connected = False
66+
self._connection_map = weakref.WeakKeyDictionary()
6767

6868
self._force_rollback = force_rollback
6969

@@ -78,28 +78,28 @@ def __init__(
7878
self._global_transaction: typing.Optional[Transaction] = None
7979

8080
@property
81-
def _connection(self) -> typing.Optional["Connection"]:
82-
connections = _ACTIVE_CONNECTIONS.get()
83-
if connections is None:
84-
return None
81+
def _current_task(self) -> asyncio.Task:
82+
task = asyncio.current_task()
83+
if not task:
84+
raise RuntimeError("No currently active asyncio.Task found")
85+
return task
8586

86-
return connections.get(self, None)
87+
@property
88+
def _connection(self) -> typing.Optional["Connection"]:
89+
return self._connection_map.get(self._current_task)
8790

8891
@_connection.setter
8992
def _connection(
9093
self, connection: typing.Optional["Connection"]
9194
) -> typing.Optional["Connection"]:
92-
connections = _ACTIVE_CONNECTIONS.get()
93-
if connections is None:
94-
connections = weakref.WeakKeyDictionary()
95-
_ACTIVE_CONNECTIONS.set(connections)
95+
task = self._current_task
9696

9797
if connection is None:
98-
connections.pop(self, None)
98+
self._connection_map.pop(task, None)
9999
else:
100-
connections[self] = connection
100+
self._connection_map[task] = connection
101101

102-
return connections.get(self, None)
102+
return self._connection
103103

104104
async def connect(self) -> None:
105105
"""
@@ -119,7 +119,7 @@ async def connect(self) -> None:
119119
assert self._global_connection is None
120120
assert self._global_transaction is None
121121

122-
self._global_connection = Connection(self._backend)
122+
self._global_connection = Connection(self, self._backend)
123123
self._global_transaction = self._global_connection.transaction(
124124
force_rollback=True
125125
)
@@ -218,7 +218,7 @@ def connection(self) -> "Connection":
218218
return self._global_connection
219219

220220
if not self._connection:
221-
self._connection = Connection(self._backend)
221+
self._connection = Connection(self, self._backend)
222222

223223
return self._connection
224224

@@ -243,7 +243,8 @@ def _get_backend(self) -> str:
243243

244244

245245
class Connection:
246-
def __init__(self, backend: DatabaseBackend) -> None:
246+
def __init__(self, database: Database, backend: DatabaseBackend) -> None:
247+
self._database = database
247248
self._backend = backend
248249

249250
self._connection_lock = asyncio.Lock()
@@ -277,6 +278,7 @@ async def __aexit__(
277278
self._connection_counter -= 1
278279
if self._connection_counter == 0:
279280
await self._connection.release()
281+
self._database._connection = None
280282

281283
async def fetch_all(
282284
self,
@@ -393,13 +395,15 @@ def _transaction(
393395
transactions = _ACTIVE_TRANSACTIONS.get()
394396
if transactions is None:
395397
transactions = weakref.WeakKeyDictionary()
396-
_ACTIVE_TRANSACTIONS.set(transactions)
398+
else:
399+
transactions = transactions.copy()
397400

398401
if transaction is None:
399402
transactions.pop(self, None)
400403
else:
401404
transactions[self] = transaction
402405

406+
_ACTIVE_TRANSACTIONS.set(transactions)
403407
return transactions.get(self, None)
404408

405409
async def __aenter__(self) -> "Transaction":

‎docs/connections_and_transactions.md

+11-12
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@ that transparently handles the use of either transactions or savepoints.
77

88
## Connecting and disconnecting
99

10-
You can control the database connect/disconnect, by using it as a async context manager.
10+
You can control the database connection pool with an async context manager:
1111

1212
```python
1313
async with Database(DATABASE_URL) as database:
1414
...
1515
```
1616

17-
Or by using explicit connection and disconnection:
17+
Or by using the explicit `.connect()` and `.disconnect()` methods:
1818

1919
```python
2020
database = Database(DATABASE_URL)
@@ -23,6 +23,8 @@ await database.connect()
2323
await database.disconnect()
2424
```
2525

26+
Connections within this connection pool are acquired for each new `asyncio.Task`.
27+
2628
If you're integrating against a web framework, then you'll probably want
2729
to hook into framework startup or shutdown events. For example, with
2830
[Starlette][starlette] you would use the following:
@@ -96,12 +98,13 @@ async def create_users(request):
9698
...
9799
```
98100

99-
Transaction state is stored in the context of the currently executing asynchronous task.
100-
This state is _inherited_ by tasks that are started from within an active transaction:
101+
Transaction state is tied to the connection used in the currently executing asynchronous task.
102+
If you would like to influence an active transaction from another task, the connection must be
103+
shared. This state is _inherited_ by tasks that are share the same connection:
101104

102105
```python
103-
async def add_excitement(database: Database, id: int):
104-
await database.execute(
106+
async def add_excitement(connnection: databases.core.Connection, id: int):
107+
await connection.execute(
105108
"UPDATE notes SET text = CONCAT(text, '!!!') WHERE id = :id",
106109
{"id": id}
107110
)
@@ -113,17 +116,13 @@ async with Database(database_url) as database:
113116
await database.execute(
114117
"INSERT INTO notes(id, text) values (1, 'databases is cool')"
115118
)
116-
# ...but child tasks inherit transaction state!
117-
await asyncio.create_task(add_excitement(database, id=1))
119+
# ...but child tasks can use this connection now!
120+
await asyncio.create_task(add_excitement(database.connection(), id=1))
118121

119122
await database.fetch_val("SELECT text FROM notes WHERE id=1")
120123
# ^ returns: "databases is cool!!!"
121124
```
122125

123-
!!! note
124-
In python 3.11, you can opt-out of context propagation by providing a new context to
125-
[`asyncio.create_task`](https://docs.python.org/3.11/library/asyncio-task.html#creating-tasks).
126-
127126
Nested transactions are fully supported, and are implemented using database savepoints:
128127

129128
```python

‎tests/test_databases.py

+168-58
Original file line numberDiff line numberDiff line change
@@ -482,11 +482,29 @@ async def test_transaction_commit(database_url):
482482

483483
@pytest.mark.parametrize("database_url", DATABASE_URLS)
484484
@async_adapter
485-
async def test_transaction_context_child_task_interaction(database_url):
485+
async def test_transaction_context_child_task_inheritance(database_url):
486+
"""
487+
Ensure that transactions are inherited by child tasks.
488+
"""
489+
async with Database(database_url) as database:
490+
491+
async def check_transaction(transaction, active_transaction):
492+
# Should have inherited the same transaction backend from the parent task
493+
assert transaction._transaction is active_transaction
494+
495+
async with database.transaction() as transaction:
496+
await asyncio.create_task(
497+
check_transaction(transaction, transaction._transaction)
498+
)
499+
500+
501+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
502+
@async_adapter
503+
async def test_transaction_context_child_task_inheritance_example(database_url):
486504
"""
487505
Ensure that child tasks may influence inherited transactions.
488506
"""
489-
# This is an practical example of the next test.
507+
# This is an practical example of the above test.
490508
async with Database(database_url) as database:
491509
async with database.transaction():
492510
# Create a note
@@ -503,37 +521,19 @@ async def test_transaction_context_child_task_interaction(database_url):
503521
result = await database.fetch_one(notes.select().where(notes.c.id == 1))
504522
assert result.text == "prior"
505523

506-
async def run_update_from_child_task():
507-
# Chage the note from a child task
508-
await database.execute(
524+
async def run_update_from_child_task(connection):
525+
# Change the note from a child task
526+
await connection.execute(
509527
notes.update().where(notes.c.id == 1).values(text="test")
510528
)
511529

512-
await asyncio.create_task(run_update_from_child_task())
530+
await asyncio.create_task(run_update_from_child_task(database.connection()))
513531

514532
# Confirm the child's change
515533
result = await database.fetch_one(notes.select().where(notes.c.id == 1))
516534
assert result.text == "test"
517535

518536

519-
@pytest.mark.parametrize("database_url", DATABASE_URLS)
520-
@async_adapter
521-
async def test_transaction_context_child_task_inheritance(database_url):
522-
"""
523-
Ensure that transactions are inherited by child tasks.
524-
"""
525-
async with Database(database_url) as database:
526-
527-
async def check_transaction(transaction, active_transaction):
528-
# Should have inherited the same transaction backend from the parent task
529-
assert transaction._transaction is active_transaction
530-
531-
async with database.transaction() as transaction:
532-
await asyncio.create_task(
533-
check_transaction(transaction, transaction._transaction)
534-
)
535-
536-
537537
@pytest.mark.parametrize("database_url", DATABASE_URLS)
538538
@async_adapter
539539
async def test_transaction_context_sibling_task_isolation(database_url):
@@ -568,56 +568,99 @@ async def check_transaction(transaction):
568568

569569
@pytest.mark.parametrize("database_url", DATABASE_URLS)
570570
@async_adapter
571-
async def test_connection_context_cleanup_contextmanager(database_url):
571+
async def test_transaction_context_sibling_task_isolation_example(database_url):
572+
"""
573+
Ensure that transactions are running in sibling tasks are isolated from eachother.
574+
"""
575+
# This is an practical example of the above test.
576+
setup = asyncio.Event()
577+
done = asyncio.Event()
578+
579+
async def tx1(connection):
580+
async with connection.transaction():
581+
await db.execute(
582+
notes.insert(), values={"id": 1, "text": "tx1", "completed": False}
583+
)
584+
setup.set()
585+
await done.wait()
586+
587+
async def tx2(connection):
588+
async with connection.transaction():
589+
await setup.wait()
590+
result = await db.fetch_all(notes.select())
591+
assert result == [], result
592+
done.set()
593+
594+
async with Database(database_url) as db:
595+
await asyncio.gather(tx1(db), tx2(db))
596+
597+
598+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
599+
@async_adapter
600+
async def test_connection_cleanup_contextmanager(database_url):
572601
"""
573-
Ensure that contextvar connections are not persisted unecessarily.
602+
Ensure that task connections are not persisted unecessarily.
574603
"""
575-
from databases.core import _ACTIVE_CONNECTIONS
576604

577-
assert _ACTIVE_CONNECTIONS.get() is None
605+
ready = asyncio.Event()
606+
done = asyncio.Event()
607+
608+
async def check_child_connection(database: Database):
609+
async with database.connection():
610+
ready.set()
611+
await done.wait()
578612

579613
async with Database(database_url) as database:
614+
# Should have a connection in this task
580615
# .connect is lazy, it doesn't create a Connection, but .connection does
581616
connection = database.connection()
617+
assert isinstance(database._connection_map, MutableMapping)
618+
assert database._connection_map.get(asyncio.current_task()) is connection
582619

583-
open_connections = _ACTIVE_CONNECTIONS.get()
584-
assert isinstance(open_connections, MutableMapping)
585-
assert open_connections.get(database) is connection
620+
# Create a child task and see if it registers a connection
621+
task = asyncio.create_task(check_child_connection(database))
622+
await ready.wait()
623+
assert database._connection_map.get(task) is not None
624+
assert database._connection_map.get(task) is not connection
586625

587-
# Context manager closes, open_connections is cleaned up
588-
open_connections = _ACTIVE_CONNECTIONS.get()
589-
assert isinstance(open_connections, MutableMapping)
590-
assert open_connections.get(database, None) is None
626+
# Let the child task finish, and see if it cleaned up
627+
done.set()
628+
await task
629+
# This is normal exit logic cleanup, the WeakKeyDictionary
630+
# shouldn't have cleaned up yet since the task is still referenced
631+
assert task not in database._connection_map
632+
633+
# Context manager closes, all open connections are removed
634+
assert isinstance(database._connection_map, MutableMapping)
635+
assert len(database._connection_map) == 0
591636

592637

593638
@pytest.mark.parametrize("database_url", DATABASE_URLS)
594639
@async_adapter
595-
async def test_connection_context_cleanup_garbagecollector(database_url):
640+
async def test_connection_cleanup_garbagecollector(database_url):
596641
"""
597-
Ensure that contextvar connections are not persisted unecessarily, even
642+
Ensure that connections for tasks are not persisted unecessarily, even
598643
if exit handlers are not called.
599644
"""
600-
from databases.core import _ACTIVE_CONNECTIONS
601-
602-
assert _ACTIVE_CONNECTIONS.get() is None
603-
604645
database = Database(database_url)
605646
await database.connect()
606-
connection = database.connection()
607647

608-
# Should be tracking the connection
609-
open_connections = _ACTIVE_CONNECTIONS.get()
610-
assert isinstance(open_connections, MutableMapping)
611-
assert open_connections.get(database) is connection
648+
created = asyncio.Event()
649+
650+
async def check_child_connection(database: Database):
651+
# neither .disconnect nor .__aexit__ are called before deleting this task
652+
database.connection()
653+
created.set()
612654

613-
# neither .disconnect nor .__aexit__ are called before deleting the reference
614-
del database
655+
task = asyncio.create_task(check_child_connection(database))
656+
await created.wait()
657+
assert task in database._connection_map
658+
await task
659+
del task
615660
gc.collect()
616661

617-
# Should have dropped reference to connection, even without proper cleanup
618-
open_connections = _ACTIVE_CONNECTIONS.get()
619-
assert isinstance(open_connections, MutableMapping)
620-
assert len(open_connections) == 0
662+
# Should not have a connection for the task anymore
663+
assert len(database._connection_map) == 0
621664

622665

623666
@pytest.mark.parametrize("database_url", DATABASE_URLS)
@@ -632,7 +675,6 @@ async def test_transaction_context_cleanup_contextmanager(database_url):
632675

633676
async with Database(database_url) as database:
634677
async with database.transaction() as transaction:
635-
636678
open_transactions = _ACTIVE_TRANSACTIONS.get()
637679
assert isinstance(open_transactions, MutableMapping)
638680
assert open_transactions.get(transaction) is transaction._transaction
@@ -818,17 +860,44 @@ async def insert_data(raise_exception):
818860
with pytest.raises(RuntimeError):
819861
await insert_data(raise_exception=True)
820862

821-
query = notes.select()
822-
results = await database.fetch_all(query=query)
863+
results = await database.fetch_all(query=notes.select())
823864
assert len(results) == 0
824865

825866
await insert_data(raise_exception=False)
826867

827-
query = notes.select()
828-
results = await database.fetch_all(query=query)
868+
results = await database.fetch_all(query=notes.select())
829869
assert len(results) == 1
830870

831871

872+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
873+
@async_adapter
874+
async def test_transaction_decorator_concurrent(database_url):
875+
"""
876+
Ensure that @database.transaction() can be called concurrently.
877+
"""
878+
879+
database = Database(database_url)
880+
881+
@database.transaction()
882+
async def insert_data():
883+
await database.execute(
884+
query=notes.insert().values(text="example", completed=True)
885+
)
886+
887+
async with database:
888+
await asyncio.gather(
889+
insert_data(),
890+
insert_data(),
891+
insert_data(),
892+
insert_data(),
893+
insert_data(),
894+
insert_data(),
895+
)
896+
897+
results = await database.fetch_all(query=notes.select())
898+
assert len(results) == 6
899+
900+
832901
@pytest.mark.parametrize("database_url", DATABASE_URLS)
833902
@async_adapter
834903
async def test_datetime_field(database_url):
@@ -1007,7 +1076,7 @@ async def test_connection_context_same_task(database_url):
10071076

10081077
@pytest.mark.parametrize("database_url", DATABASE_URLS)
10091078
@async_adapter
1010-
async def test_connection_context_multiple_tasks(database_url):
1079+
async def test_connection_context_multiple_sibling_tasks(database_url):
10111080
async with Database(database_url) as database:
10121081
connection_1 = None
10131082
connection_2 = None
@@ -1037,6 +1106,47 @@ async def get_connection_2():
10371106
await task_2
10381107

10391108

1109+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
1110+
@async_adapter
1111+
async def test_connection_context_multiple_tasks(database_url):
1112+
async with Database(database_url) as database:
1113+
parent_connection = database.connection()
1114+
connection_1 = None
1115+
connection_2 = None
1116+
task_1_ready = asyncio.Event()
1117+
task_2_ready = asyncio.Event()
1118+
test_complete = asyncio.Event()
1119+
1120+
async def get_connection_1():
1121+
nonlocal connection_1
1122+
1123+
async with database.connection() as connection:
1124+
connection_1 = connection
1125+
task_1_ready.set()
1126+
await test_complete.wait()
1127+
1128+
async def get_connection_2():
1129+
nonlocal connection_2
1130+
1131+
async with database.connection() as connection:
1132+
connection_2 = connection
1133+
task_2_ready.set()
1134+
await test_complete.wait()
1135+
1136+
task_1 = asyncio.create_task(get_connection_1())
1137+
task_2 = asyncio.create_task(get_connection_2())
1138+
await task_1_ready.wait()
1139+
await task_2_ready.wait()
1140+
1141+
assert connection_1 is not parent_connection
1142+
assert connection_2 is not parent_connection
1143+
assert connection_1 is not connection_2
1144+
1145+
test_complete.set()
1146+
await task_1
1147+
await task_2
1148+
1149+
10401150
@pytest.mark.parametrize(
10411151
"database_url1,database_url2",
10421152
(

0 commit comments

Comments
 (0)
Please sign in to comment.