Skip to content

Commit b94f097

Browse files
committed
fix: remove connection inheritance, add more tests, update docs
Connections are once again stored as state on the Database instance, keyed by the current asyncio.Task. Each task acquires it's own connection, and a WeakKeyDictionary allows the connection to be discarded if the owning task is garbage collected. TransactionBackends are still stored as contextvars, and a connection must be explicitly provided to descendant tasks if active transaction state is to be inherited.
1 parent 6de4f60 commit b94f097

File tree

3 files changed

+203
-90
lines changed

3 files changed

+203
-90
lines changed

databases/core.py

+24-20
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,9 @@
3636
logger = logging.getLogger("databases")
3737

3838

39-
_ACTIVE_CONNECTIONS: ContextVar[
40-
typing.Optional["weakref.WeakKeyDictionary['Database', 'Connection']"]
41-
] = ContextVar("databases:open_connections", default=None)
4239
_ACTIVE_TRANSACTIONS: ContextVar[
4340
typing.Optional["weakref.WeakKeyDictionary['Transaction', 'TransactionBackend']"]
44-
] = ContextVar("databases:open_transactions", default=None)
41+
] = ContextVar("databases:active_transactions", default=None)
4542

4643

4744
class Database:
@@ -54,6 +51,8 @@ class Database:
5451
"sqlite": "databases.backends.sqlite:SQLiteBackend",
5552
}
5653

54+
_connection_map: "weakref.WeakKeyDictionary[asyncio.Task, 'Connection']"
55+
5756
def __init__(
5857
self,
5958
url: typing.Union[str, "DatabaseURL"],
@@ -64,6 +63,7 @@ def __init__(
6463
self.url = DatabaseURL(url)
6564
self.options = options
6665
self.is_connected = False
66+
self._connection_map = weakref.WeakKeyDictionary()
6767

6868
self._force_rollback = force_rollback
6969

@@ -78,28 +78,28 @@ def __init__(
7878
self._global_transaction: typing.Optional[Transaction] = None
7979

8080
@property
81-
def _connection(self) -> typing.Optional["Connection"]:
82-
connections = _ACTIVE_CONNECTIONS.get()
83-
if connections is None:
84-
return None
81+
def _current_task(self) -> asyncio.Task:
82+
task = asyncio.current_task()
83+
if not task:
84+
raise RuntimeError("No currently active asyncio.Task found")
85+
return task
8586

86-
return connections.get(self, None)
87+
@property
88+
def _connection(self) -> typing.Optional["Connection"]:
89+
return self._connection_map.get(self._current_task)
8790

8891
@_connection.setter
8992
def _connection(
9093
self, connection: typing.Optional["Connection"]
9194
) -> typing.Optional["Connection"]:
92-
connections = _ACTIVE_CONNECTIONS.get()
93-
if connections is None:
94-
connections = weakref.WeakKeyDictionary()
95-
_ACTIVE_CONNECTIONS.set(connections)
95+
task = self._current_task
9696

9797
if connection is None:
98-
connections.pop(self, None)
98+
self._connection_map.pop(task, None)
9999
else:
100-
connections[self] = connection
100+
self._connection_map[task] = connection
101101

102-
return connections.get(self, None)
102+
return self._connection
103103

104104
async def connect(self) -> None:
105105
"""
@@ -119,7 +119,7 @@ async def connect(self) -> None:
119119
assert self._global_connection is None
120120
assert self._global_transaction is None
121121

122-
self._global_connection = Connection(self._backend)
122+
self._global_connection = Connection(self, self._backend)
123123
self._global_transaction = self._global_connection.transaction(
124124
force_rollback=True
125125
)
@@ -218,7 +218,7 @@ def connection(self) -> "Connection":
218218
return self._global_connection
219219

220220
if not self._connection:
221-
self._connection = Connection(self._backend)
221+
self._connection = Connection(self, self._backend)
222222

223223
return self._connection
224224

@@ -243,7 +243,8 @@ def _get_backend(self) -> str:
243243

244244

245245
class Connection:
246-
def __init__(self, backend: DatabaseBackend) -> None:
246+
def __init__(self, database: Database, backend: DatabaseBackend) -> None:
247+
self._database = database
247248
self._backend = backend
248249

249250
self._connection_lock = asyncio.Lock()
@@ -277,6 +278,7 @@ async def __aexit__(
277278
self._connection_counter -= 1
278279
if self._connection_counter == 0:
279280
await self._connection.release()
281+
self._database._connection = None
280282

281283
async def fetch_all(
282284
self,
@@ -393,13 +395,15 @@ def _transaction(
393395
transactions = _ACTIVE_TRANSACTIONS.get()
394396
if transactions is None:
395397
transactions = weakref.WeakKeyDictionary()
396-
_ACTIVE_TRANSACTIONS.set(transactions)
398+
else:
399+
transactions = transactions.copy()
397400

398401
if transaction is None:
399402
transactions.pop(self, None)
400403
else:
401404
transactions[self] = transaction
402405

406+
_ACTIVE_TRANSACTIONS.set(transactions)
403407
return transactions.get(self, None)
404408

405409
async def __aenter__(self) -> "Transaction":

docs/connections_and_transactions.md

+11-12
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@ that transparently handles the use of either transactions or savepoints.
77

88
## Connecting and disconnecting
99

10-
You can control the database connect/disconnect, by using it as a async context manager.
10+
You can control the database connection pool with an async context manager:
1111

1212
```python
1313
async with Database(DATABASE_URL) as database:
1414
...
1515
```
1616

17-
Or by using explicit connection and disconnection:
17+
Or by using the explicit `.connect()` and `.disconnect()` methods:
1818

1919
```python
2020
database = Database(DATABASE_URL)
@@ -23,6 +23,8 @@ await database.connect()
2323
await database.disconnect()
2424
```
2525

26+
Connections within this connection pool are acquired for each new `asyncio.Task`.
27+
2628
If you're integrating against a web framework, then you'll probably want
2729
to hook into framework startup or shutdown events. For example, with
2830
[Starlette][starlette] you would use the following:
@@ -96,12 +98,13 @@ async def create_users(request):
9698
...
9799
```
98100

99-
Transaction state is stored in the context of the currently executing asynchronous task.
100-
This state is _inherited_ by tasks that are started from within an active transaction:
101+
Transaction state is tied to the connection used in the currently executing asynchronous task.
102+
If you would like to influence an active transaction from another task, the connection must be
103+
shared. This state is _inherited_ by tasks that are share the same connection:
101104

102105
```python
103-
async def add_excitement(database: Database, id: int):
104-
await database.execute(
106+
async def add_excitement(connnection: databases.core.Connection, id: int):
107+
await connection.execute(
105108
"UPDATE notes SET text = CONCAT(text, '!!!') WHERE id = :id",
106109
{"id": id}
107110
)
@@ -113,17 +116,13 @@ async with Database(database_url) as database:
113116
await database.execute(
114117
"INSERT INTO notes(id, text) values (1, 'databases is cool')"
115118
)
116-
# ...but child tasks inherit transaction state!
117-
await asyncio.create_task(add_excitement(database, id=1))
119+
# ...but child tasks can use this connection now!
120+
await asyncio.create_task(add_excitement(database.connection(), id=1))
118121

119122
await database.fetch_val("SELECT text FROM notes WHERE id=1")
120123
# ^ returns: "databases is cool!!!"
121124
```
122125

123-
!!! note
124-
In python 3.11, you can opt-out of context propagation by providing a new context to
125-
[`asyncio.create_task`](https://docs.python.org/3.11/library/asyncio-task.html#creating-tasks).
126-
127126
Nested transactions are fully supported, and are implemented using database savepoints:
128127

129128
```python

0 commit comments

Comments
 (0)