Skip to content

Commit 25fa295

Browse files
zevisertzanieb
andauthored
fix: incorrect concurrent usage of connection and transaction (#546)
* fix: incorrect concurrent usage of connection and transaction * refactor: rename contextvar class attributes, add some explaination comments * fix: contextvar.get takes no keyword arguments * test: add concurrent task tests * feat: use ContextVar[dict] to track connections and transactions per task * test: check multiple databases in the same task use independant connections * chore: changes for linting and typechecking * chore: use typing.Tuple for lower python version compatibility * docs: update comment on _connection_contextmap * Update `Connection` and `Transaction` to be robust to concurrent use * chore: remove optional annotation on asyncio.Task * test: add new tests for upcoming contextvar inheritance/isolation and weakref cleanup * feat: reimplement concurrency system with contextvar and weakmap * chore: apply corrections from linters * fix: quote WeakKeyDictionary typing for python<=3.7 * docs: add examples for async transaction context and nested transactions * fix: remove connection inheritance, add more tests, update docs Connections are once again stored as state on the Database instance, keyed by the current asyncio.Task. Each task acquires it's own connection, and a WeakKeyDictionary allows the connection to be discarded if the owning task is garbage collected. TransactionBackends are still stored as contextvars, and a connection must be explicitly provided to descendant tasks if active transaction state is to be inherited. --------- Co-authored-by: Zanie <[email protected]>
1 parent c095428 commit 25fa295

File tree

3 files changed

+521
-35
lines changed

3 files changed

+521
-35
lines changed

databases/core.py

+78-14
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import functools
44
import logging
55
import typing
6+
import weakref
67
from contextvars import ContextVar
78
from types import TracebackType
89
from urllib.parse import SplitResult, parse_qsl, unquote, urlsplit
@@ -11,7 +12,7 @@
1112
from sqlalchemy.sql import ClauseElement
1213

1314
from databases.importer import import_from_string
14-
from databases.interfaces import DatabaseBackend, Record
15+
from databases.interfaces import DatabaseBackend, Record, TransactionBackend
1516

1617
try: # pragma: no cover
1718
import click
@@ -35,6 +36,11 @@
3536
logger = logging.getLogger("databases")
3637

3738

39+
_ACTIVE_TRANSACTIONS: ContextVar[
40+
typing.Optional["weakref.WeakKeyDictionary['Transaction', 'TransactionBackend']"]
41+
] = ContextVar("databases:active_transactions", default=None)
42+
43+
3844
class Database:
3945
SUPPORTED_BACKENDS = {
4046
"postgresql": "databases.backends.postgres:PostgresBackend",
@@ -45,6 +51,8 @@ class Database:
4551
"sqlite": "databases.backends.sqlite:SQLiteBackend",
4652
}
4753

54+
_connection_map: "weakref.WeakKeyDictionary[asyncio.Task, 'Connection']"
55+
4856
def __init__(
4957
self,
5058
url: typing.Union[str, "DatabaseURL"],
@@ -55,6 +63,7 @@ def __init__(
5563
self.url = DatabaseURL(url)
5664
self.options = options
5765
self.is_connected = False
66+
self._connection_map = weakref.WeakKeyDictionary()
5867

5968
self._force_rollback = force_rollback
6069

@@ -63,14 +72,35 @@ def __init__(
6372
assert issubclass(backend_cls, DatabaseBackend)
6473
self._backend = backend_cls(self.url, **self.options)
6574

66-
# Connections are stored as task-local state.
67-
self._connection_context: ContextVar = ContextVar("connection_context")
68-
6975
# When `force_rollback=True` is used, we use a single global
7076
# connection, within a transaction that always rolls back.
7177
self._global_connection: typing.Optional[Connection] = None
7278
self._global_transaction: typing.Optional[Transaction] = None
7379

80+
@property
81+
def _current_task(self) -> asyncio.Task:
82+
task = asyncio.current_task()
83+
if not task:
84+
raise RuntimeError("No currently active asyncio.Task found")
85+
return task
86+
87+
@property
88+
def _connection(self) -> typing.Optional["Connection"]:
89+
return self._connection_map.get(self._current_task)
90+
91+
@_connection.setter
92+
def _connection(
93+
self, connection: typing.Optional["Connection"]
94+
) -> typing.Optional["Connection"]:
95+
task = self._current_task
96+
97+
if connection is None:
98+
self._connection_map.pop(task, None)
99+
else:
100+
self._connection_map[task] = connection
101+
102+
return self._connection
103+
74104
async def connect(self) -> None:
75105
"""
76106
Establish the connection pool.
@@ -89,7 +119,7 @@ async def connect(self) -> None:
89119
assert self._global_connection is None
90120
assert self._global_transaction is None
91121

92-
self._global_connection = Connection(self._backend)
122+
self._global_connection = Connection(self, self._backend)
93123
self._global_transaction = self._global_connection.transaction(
94124
force_rollback=True
95125
)
@@ -113,7 +143,7 @@ async def disconnect(self) -> None:
113143
self._global_transaction = None
114144
self._global_connection = None
115145
else:
116-
self._connection_context = ContextVar("connection_context")
146+
self._connection = None
117147

118148
await self._backend.disconnect()
119149
logger.info(
@@ -187,12 +217,10 @@ def connection(self) -> "Connection":
187217
if self._global_connection is not None:
188218
return self._global_connection
189219

190-
try:
191-
return self._connection_context.get()
192-
except LookupError:
193-
connection = Connection(self._backend)
194-
self._connection_context.set(connection)
195-
return connection
220+
if not self._connection:
221+
self._connection = Connection(self, self._backend)
222+
223+
return self._connection
196224

197225
def transaction(
198226
self, *, force_rollback: bool = False, **kwargs: typing.Any
@@ -215,7 +243,8 @@ def _get_backend(self) -> str:
215243

216244

217245
class Connection:
218-
def __init__(self, backend: DatabaseBackend) -> None:
246+
def __init__(self, database: Database, backend: DatabaseBackend) -> None:
247+
self._database = database
219248
self._backend = backend
220249

221250
self._connection_lock = asyncio.Lock()
@@ -249,6 +278,7 @@ async def __aexit__(
249278
self._connection_counter -= 1
250279
if self._connection_counter == 0:
251280
await self._connection.release()
281+
self._database._connection = None
252282

253283
async def fetch_all(
254284
self,
@@ -345,6 +375,37 @@ def __init__(
345375
self._force_rollback = force_rollback
346376
self._extra_options = kwargs
347377

378+
@property
379+
def _connection(self) -> "Connection":
380+
# Returns the same connection if called multiple times
381+
return self._connection_callable()
382+
383+
@property
384+
def _transaction(self) -> typing.Optional["TransactionBackend"]:
385+
transactions = _ACTIVE_TRANSACTIONS.get()
386+
if transactions is None:
387+
return None
388+
389+
return transactions.get(self, None)
390+
391+
@_transaction.setter
392+
def _transaction(
393+
self, transaction: typing.Optional["TransactionBackend"]
394+
) -> typing.Optional["TransactionBackend"]:
395+
transactions = _ACTIVE_TRANSACTIONS.get()
396+
if transactions is None:
397+
transactions = weakref.WeakKeyDictionary()
398+
else:
399+
transactions = transactions.copy()
400+
401+
if transaction is None:
402+
transactions.pop(self, None)
403+
else:
404+
transactions[self] = transaction
405+
406+
_ACTIVE_TRANSACTIONS.set(transactions)
407+
return transactions.get(self, None)
408+
348409
async def __aenter__(self) -> "Transaction":
349410
"""
350411
Called when entering `async with database.transaction()`
@@ -385,7 +446,6 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
385446
return wrapper # type: ignore
386447

387448
async def start(self) -> "Transaction":
388-
self._connection = self._connection_callable()
389449
self._transaction = self._connection._connection.transaction()
390450

391451
async with self._connection._transaction_lock:
@@ -401,15 +461,19 @@ async def commit(self) -> None:
401461
async with self._connection._transaction_lock:
402462
assert self._connection._transaction_stack[-1] is self
403463
self._connection._transaction_stack.pop()
464+
assert self._transaction is not None
404465
await self._transaction.commit()
405466
await self._connection.__aexit__()
467+
self._transaction = None
406468

407469
async def rollback(self) -> None:
408470
async with self._connection._transaction_lock:
409471
assert self._connection._transaction_stack[-1] is self
410472
self._connection._transaction_stack.pop()
473+
assert self._transaction is not None
411474
await self._transaction.rollback()
412475
await self._connection.__aexit__()
476+
self._transaction = None
413477

414478

415479
class _EmptyNetloc(str):

docs/connections_and_transactions.md

+50-4
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@ that transparently handles the use of either transactions or savepoints.
77

88
## Connecting and disconnecting
99

10-
You can control the database connect/disconnect, by using it as a async context manager.
10+
You can control the database connection pool with an async context manager:
1111

1212
```python
1313
async with Database(DATABASE_URL) as database:
1414
...
1515
```
1616

17-
Or by using explicit connection and disconnection:
17+
Or by using the explicit `.connect()` and `.disconnect()` methods:
1818

1919
```python
2020
database = Database(DATABASE_URL)
@@ -23,6 +23,8 @@ await database.connect()
2323
await database.disconnect()
2424
```
2525

26+
Connections within this connection pool are acquired for each new `asyncio.Task`.
27+
2628
If you're integrating against a web framework, then you'll probably want
2729
to hook into framework startup or shutdown events. For example, with
2830
[Starlette][starlette] you would use the following:
@@ -67,6 +69,7 @@ A transaction can be acquired from the database connection pool:
6769
async with database.transaction():
6870
...
6971
```
72+
7073
It can also be acquired from a specific database connection:
7174

7275
```python
@@ -95,8 +98,51 @@ async def create_users(request):
9598
...
9699
```
97100

98-
Transaction blocks are managed as task-local state. Nested transactions
99-
are fully supported, and are implemented using database savepoints.
101+
Transaction state is tied to the connection used in the currently executing asynchronous task.
102+
If you would like to influence an active transaction from another task, the connection must be
103+
shared. This state is _inherited_ by tasks that are share the same connection:
104+
105+
```python
106+
async def add_excitement(connnection: databases.core.Connection, id: int):
107+
await connection.execute(
108+
"UPDATE notes SET text = CONCAT(text, '!!!') WHERE id = :id",
109+
{"id": id}
110+
)
111+
112+
113+
async with Database(database_url) as database:
114+
async with database.transaction():
115+
# This note won't exist until the transaction closes...
116+
await database.execute(
117+
"INSERT INTO notes(id, text) values (1, 'databases is cool')"
118+
)
119+
# ...but child tasks can use this connection now!
120+
await asyncio.create_task(add_excitement(database.connection(), id=1))
121+
122+
await database.fetch_val("SELECT text FROM notes WHERE id=1")
123+
# ^ returns: "databases is cool!!!"
124+
```
125+
126+
Nested transactions are fully supported, and are implemented using database savepoints:
127+
128+
```python
129+
async with databases.Database(database_url) as db:
130+
async with db.transaction() as outer:
131+
# Do something in the outer transaction
132+
...
133+
134+
# Suppress to prevent influence on the outer transaction
135+
with contextlib.suppress(ValueError):
136+
async with db.transaction():
137+
# Do something in the inner transaction
138+
...
139+
140+
raise ValueError('Abort the inner transaction')
141+
142+
# Observe the results of the outer transaction,
143+
# without effects from the inner transaction.
144+
await db.fetch_all('SELECT * FROM ...')
145+
```
100146

101147
Transaction isolation-level can be specified if the driver backend supports that:
102148

0 commit comments

Comments
 (0)