Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 25fa295

Browse files
zevisertzanieb
andauthoredJul 25, 2023
fix: incorrect concurrent usage of connection and transaction (encode#546)
* fix: incorrect concurrent usage of connection and transaction * refactor: rename contextvar class attributes, add some explaination comments * fix: contextvar.get takes no keyword arguments * test: add concurrent task tests * feat: use ContextVar[dict] to track connections and transactions per task * test: check multiple databases in the same task use independant connections * chore: changes for linting and typechecking * chore: use typing.Tuple for lower python version compatibility * docs: update comment on _connection_contextmap * Update `Connection` and `Transaction` to be robust to concurrent use * chore: remove optional annotation on asyncio.Task * test: add new tests for upcoming contextvar inheritance/isolation and weakref cleanup * feat: reimplement concurrency system with contextvar and weakmap * chore: apply corrections from linters * fix: quote WeakKeyDictionary typing for python<=3.7 * docs: add examples for async transaction context and nested transactions * fix: remove connection inheritance, add more tests, update docs Connections are once again stored as state on the Database instance, keyed by the current asyncio.Task. Each task acquires it's own connection, and a WeakKeyDictionary allows the connection to be discarded if the owning task is garbage collected. TransactionBackends are still stored as contextvars, and a connection must be explicitly provided to descendant tasks if active transaction state is to be inherited. --------- Co-authored-by: Zanie <[email protected]>
1 parent c095428 commit 25fa295

File tree

3 files changed

+521
-35
lines changed

3 files changed

+521
-35
lines changed
 

‎databases/core.py

+78-14
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import functools
44
import logging
55
import typing
6+
import weakref
67
from contextvars import ContextVar
78
from types import TracebackType
89
from urllib.parse import SplitResult, parse_qsl, unquote, urlsplit
@@ -11,7 +12,7 @@
1112
from sqlalchemy.sql import ClauseElement
1213

1314
from databases.importer import import_from_string
14-
from databases.interfaces import DatabaseBackend, Record
15+
from databases.interfaces import DatabaseBackend, Record, TransactionBackend
1516

1617
try: # pragma: no cover
1718
import click
@@ -35,6 +36,11 @@
3536
logger = logging.getLogger("databases")
3637

3738

39+
_ACTIVE_TRANSACTIONS: ContextVar[
40+
typing.Optional["weakref.WeakKeyDictionary['Transaction', 'TransactionBackend']"]
41+
] = ContextVar("databases:active_transactions", default=None)
42+
43+
3844
class Database:
3945
SUPPORTED_BACKENDS = {
4046
"postgresql": "databases.backends.postgres:PostgresBackend",
@@ -45,6 +51,8 @@ class Database:
4551
"sqlite": "databases.backends.sqlite:SQLiteBackend",
4652
}
4753

54+
_connection_map: "weakref.WeakKeyDictionary[asyncio.Task, 'Connection']"
55+
4856
def __init__(
4957
self,
5058
url: typing.Union[str, "DatabaseURL"],
@@ -55,6 +63,7 @@ def __init__(
5563
self.url = DatabaseURL(url)
5664
self.options = options
5765
self.is_connected = False
66+
self._connection_map = weakref.WeakKeyDictionary()
5867

5968
self._force_rollback = force_rollback
6069

@@ -63,14 +72,35 @@ def __init__(
6372
assert issubclass(backend_cls, DatabaseBackend)
6473
self._backend = backend_cls(self.url, **self.options)
6574

66-
# Connections are stored as task-local state.
67-
self._connection_context: ContextVar = ContextVar("connection_context")
68-
6975
# When `force_rollback=True` is used, we use a single global
7076
# connection, within a transaction that always rolls back.
7177
self._global_connection: typing.Optional[Connection] = None
7278
self._global_transaction: typing.Optional[Transaction] = None
7379

80+
@property
81+
def _current_task(self) -> asyncio.Task:
82+
task = asyncio.current_task()
83+
if not task:
84+
raise RuntimeError("No currently active asyncio.Task found")
85+
return task
86+
87+
@property
88+
def _connection(self) -> typing.Optional["Connection"]:
89+
return self._connection_map.get(self._current_task)
90+
91+
@_connection.setter
92+
def _connection(
93+
self, connection: typing.Optional["Connection"]
94+
) -> typing.Optional["Connection"]:
95+
task = self._current_task
96+
97+
if connection is None:
98+
self._connection_map.pop(task, None)
99+
else:
100+
self._connection_map[task] = connection
101+
102+
return self._connection
103+
74104
async def connect(self) -> None:
75105
"""
76106
Establish the connection pool.
@@ -89,7 +119,7 @@ async def connect(self) -> None:
89119
assert self._global_connection is None
90120
assert self._global_transaction is None
91121

92-
self._global_connection = Connection(self._backend)
122+
self._global_connection = Connection(self, self._backend)
93123
self._global_transaction = self._global_connection.transaction(
94124
force_rollback=True
95125
)
@@ -113,7 +143,7 @@ async def disconnect(self) -> None:
113143
self._global_transaction = None
114144
self._global_connection = None
115145
else:
116-
self._connection_context = ContextVar("connection_context")
146+
self._connection = None
117147

118148
await self._backend.disconnect()
119149
logger.info(
@@ -187,12 +217,10 @@ def connection(self) -> "Connection":
187217
if self._global_connection is not None:
188218
return self._global_connection
189219

190-
try:
191-
return self._connection_context.get()
192-
except LookupError:
193-
connection = Connection(self._backend)
194-
self._connection_context.set(connection)
195-
return connection
220+
if not self._connection:
221+
self._connection = Connection(self, self._backend)
222+
223+
return self._connection
196224

197225
def transaction(
198226
self, *, force_rollback: bool = False, **kwargs: typing.Any
@@ -215,7 +243,8 @@ def _get_backend(self) -> str:
215243

216244

217245
class Connection:
218-
def __init__(self, backend: DatabaseBackend) -> None:
246+
def __init__(self, database: Database, backend: DatabaseBackend) -> None:
247+
self._database = database
219248
self._backend = backend
220249

221250
self._connection_lock = asyncio.Lock()
@@ -249,6 +278,7 @@ async def __aexit__(
249278
self._connection_counter -= 1
250279
if self._connection_counter == 0:
251280
await self._connection.release()
281+
self._database._connection = None
252282

253283
async def fetch_all(
254284
self,
@@ -345,6 +375,37 @@ def __init__(
345375
self._force_rollback = force_rollback
346376
self._extra_options = kwargs
347377

378+
@property
379+
def _connection(self) -> "Connection":
380+
# Returns the same connection if called multiple times
381+
return self._connection_callable()
382+
383+
@property
384+
def _transaction(self) -> typing.Optional["TransactionBackend"]:
385+
transactions = _ACTIVE_TRANSACTIONS.get()
386+
if transactions is None:
387+
return None
388+
389+
return transactions.get(self, None)
390+
391+
@_transaction.setter
392+
def _transaction(
393+
self, transaction: typing.Optional["TransactionBackend"]
394+
) -> typing.Optional["TransactionBackend"]:
395+
transactions = _ACTIVE_TRANSACTIONS.get()
396+
if transactions is None:
397+
transactions = weakref.WeakKeyDictionary()
398+
else:
399+
transactions = transactions.copy()
400+
401+
if transaction is None:
402+
transactions.pop(self, None)
403+
else:
404+
transactions[self] = transaction
405+
406+
_ACTIVE_TRANSACTIONS.set(transactions)
407+
return transactions.get(self, None)
408+
348409
async def __aenter__(self) -> "Transaction":
349410
"""
350411
Called when entering `async with database.transaction()`
@@ -385,7 +446,6 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
385446
return wrapper # type: ignore
386447

387448
async def start(self) -> "Transaction":
388-
self._connection = self._connection_callable()
389449
self._transaction = self._connection._connection.transaction()
390450

391451
async with self._connection._transaction_lock:
@@ -401,15 +461,19 @@ async def commit(self) -> None:
401461
async with self._connection._transaction_lock:
402462
assert self._connection._transaction_stack[-1] is self
403463
self._connection._transaction_stack.pop()
464+
assert self._transaction is not None
404465
await self._transaction.commit()
405466
await self._connection.__aexit__()
467+
self._transaction = None
406468

407469
async def rollback(self) -> None:
408470
async with self._connection._transaction_lock:
409471
assert self._connection._transaction_stack[-1] is self
410472
self._connection._transaction_stack.pop()
473+
assert self._transaction is not None
411474
await self._transaction.rollback()
412475
await self._connection.__aexit__()
476+
self._transaction = None
413477

414478

415479
class _EmptyNetloc(str):

‎docs/connections_and_transactions.md

+50-4
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@ that transparently handles the use of either transactions or savepoints.
77

88
## Connecting and disconnecting
99

10-
You can control the database connect/disconnect, by using it as a async context manager.
10+
You can control the database connection pool with an async context manager:
1111

1212
```python
1313
async with Database(DATABASE_URL) as database:
1414
...
1515
```
1616

17-
Or by using explicit connection and disconnection:
17+
Or by using the explicit `.connect()` and `.disconnect()` methods:
1818

1919
```python
2020
database = Database(DATABASE_URL)
@@ -23,6 +23,8 @@ await database.connect()
2323
await database.disconnect()
2424
```
2525

26+
Connections within this connection pool are acquired for each new `asyncio.Task`.
27+
2628
If you're integrating against a web framework, then you'll probably want
2729
to hook into framework startup or shutdown events. For example, with
2830
[Starlette][starlette] you would use the following:
@@ -67,6 +69,7 @@ A transaction can be acquired from the database connection pool:
6769
async with database.transaction():
6870
...
6971
```
72+
7073
It can also be acquired from a specific database connection:
7174

7275
```python
@@ -95,8 +98,51 @@ async def create_users(request):
9598
...
9699
```
97100

98-
Transaction blocks are managed as task-local state. Nested transactions
99-
are fully supported, and are implemented using database savepoints.
101+
Transaction state is tied to the connection used in the currently executing asynchronous task.
102+
If you would like to influence an active transaction from another task, the connection must be
103+
shared. This state is _inherited_ by tasks that are share the same connection:
104+
105+
```python
106+
async def add_excitement(connnection: databases.core.Connection, id: int):
107+
await connection.execute(
108+
"UPDATE notes SET text = CONCAT(text, '!!!') WHERE id = :id",
109+
{"id": id}
110+
)
111+
112+
113+
async with Database(database_url) as database:
114+
async with database.transaction():
115+
# This note won't exist until the transaction closes...
116+
await database.execute(
117+
"INSERT INTO notes(id, text) values (1, 'databases is cool')"
118+
)
119+
# ...but child tasks can use this connection now!
120+
await asyncio.create_task(add_excitement(database.connection(), id=1))
121+
122+
await database.fetch_val("SELECT text FROM notes WHERE id=1")
123+
# ^ returns: "databases is cool!!!"
124+
```
125+
126+
Nested transactions are fully supported, and are implemented using database savepoints:
127+
128+
```python
129+
async with databases.Database(database_url) as db:
130+
async with db.transaction() as outer:
131+
# Do something in the outer transaction
132+
...
133+
134+
# Suppress to prevent influence on the outer transaction
135+
with contextlib.suppress(ValueError):
136+
async with db.transaction():
137+
# Do something in the inner transaction
138+
...
139+
140+
raise ValueError('Abort the inner transaction')
141+
142+
# Observe the results of the outer transaction,
143+
# without effects from the inner transaction.
144+
await db.fetch_all('SELECT * FROM ...')
145+
```
100146

101147
Transaction isolation-level can be specified if the driver backend supports that:
102148

‎tests/test_databases.py

+393-17
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22
import datetime
33
import decimal
44
import functools
5+
import gc
6+
import itertools
57
import os
68
import re
9+
from typing import MutableMapping
710
from unittest.mock import MagicMock, patch
811

912
import pytest
@@ -477,6 +480,254 @@ async def test_transaction_commit(database_url):
477480
assert len(results) == 1
478481

479482

483+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
484+
@async_adapter
485+
async def test_transaction_context_child_task_inheritance(database_url):
486+
"""
487+
Ensure that transactions are inherited by child tasks.
488+
"""
489+
async with Database(database_url) as database:
490+
491+
async def check_transaction(transaction, active_transaction):
492+
# Should have inherited the same transaction backend from the parent task
493+
assert transaction._transaction is active_transaction
494+
495+
async with database.transaction() as transaction:
496+
await asyncio.create_task(
497+
check_transaction(transaction, transaction._transaction)
498+
)
499+
500+
501+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
502+
@async_adapter
503+
async def test_transaction_context_child_task_inheritance_example(database_url):
504+
"""
505+
Ensure that child tasks may influence inherited transactions.
506+
"""
507+
# This is an practical example of the above test.
508+
async with Database(database_url) as database:
509+
async with database.transaction():
510+
# Create a note
511+
await database.execute(
512+
notes.insert().values(id=1, text="setup", completed=True)
513+
)
514+
515+
# Change the note from the same task
516+
await database.execute(
517+
notes.update().where(notes.c.id == 1).values(text="prior")
518+
)
519+
520+
# Confirm the change
521+
result = await database.fetch_one(notes.select().where(notes.c.id == 1))
522+
assert result.text == "prior"
523+
524+
async def run_update_from_child_task(connection):
525+
# Change the note from a child task
526+
await connection.execute(
527+
notes.update().where(notes.c.id == 1).values(text="test")
528+
)
529+
530+
await asyncio.create_task(run_update_from_child_task(database.connection()))
531+
532+
# Confirm the child's change
533+
result = await database.fetch_one(notes.select().where(notes.c.id == 1))
534+
assert result.text == "test"
535+
536+
537+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
538+
@async_adapter
539+
async def test_transaction_context_sibling_task_isolation(database_url):
540+
"""
541+
Ensure that transactions are isolated between sibling tasks.
542+
"""
543+
start = asyncio.Event()
544+
end = asyncio.Event()
545+
546+
async with Database(database_url) as database:
547+
548+
async def check_transaction(transaction):
549+
await start.wait()
550+
# Parent task is now in a transaction, we should not
551+
# see its transaction backend since this task was
552+
# _started_ in a context where no transaction was active.
553+
assert transaction._transaction is None
554+
end.set()
555+
556+
transaction = database.transaction()
557+
assert transaction._transaction is None
558+
task = asyncio.create_task(check_transaction(transaction))
559+
560+
async with transaction:
561+
start.set()
562+
assert transaction._transaction is not None
563+
await end.wait()
564+
565+
# Cleanup for "Task not awaited" warning
566+
await task
567+
568+
569+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
570+
@async_adapter
571+
async def test_transaction_context_sibling_task_isolation_example(database_url):
572+
"""
573+
Ensure that transactions are running in sibling tasks are isolated from eachother.
574+
"""
575+
# This is an practical example of the above test.
576+
setup = asyncio.Event()
577+
done = asyncio.Event()
578+
579+
async def tx1(connection):
580+
async with connection.transaction():
581+
await db.execute(
582+
notes.insert(), values={"id": 1, "text": "tx1", "completed": False}
583+
)
584+
setup.set()
585+
await done.wait()
586+
587+
async def tx2(connection):
588+
async with connection.transaction():
589+
await setup.wait()
590+
result = await db.fetch_all(notes.select())
591+
assert result == [], result
592+
done.set()
593+
594+
async with Database(database_url) as db:
595+
await asyncio.gather(tx1(db), tx2(db))
596+
597+
598+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
599+
@async_adapter
600+
async def test_connection_cleanup_contextmanager(database_url):
601+
"""
602+
Ensure that task connections are not persisted unecessarily.
603+
"""
604+
605+
ready = asyncio.Event()
606+
done = asyncio.Event()
607+
608+
async def check_child_connection(database: Database):
609+
async with database.connection():
610+
ready.set()
611+
await done.wait()
612+
613+
async with Database(database_url) as database:
614+
# Should have a connection in this task
615+
# .connect is lazy, it doesn't create a Connection, but .connection does
616+
connection = database.connection()
617+
assert isinstance(database._connection_map, MutableMapping)
618+
assert database._connection_map.get(asyncio.current_task()) is connection
619+
620+
# Create a child task and see if it registers a connection
621+
task = asyncio.create_task(check_child_connection(database))
622+
await ready.wait()
623+
assert database._connection_map.get(task) is not None
624+
assert database._connection_map.get(task) is not connection
625+
626+
# Let the child task finish, and see if it cleaned up
627+
done.set()
628+
await task
629+
# This is normal exit logic cleanup, the WeakKeyDictionary
630+
# shouldn't have cleaned up yet since the task is still referenced
631+
assert task not in database._connection_map
632+
633+
# Context manager closes, all open connections are removed
634+
assert isinstance(database._connection_map, MutableMapping)
635+
assert len(database._connection_map) == 0
636+
637+
638+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
639+
@async_adapter
640+
async def test_connection_cleanup_garbagecollector(database_url):
641+
"""
642+
Ensure that connections for tasks are not persisted unecessarily, even
643+
if exit handlers are not called.
644+
"""
645+
database = Database(database_url)
646+
await database.connect()
647+
648+
created = asyncio.Event()
649+
650+
async def check_child_connection(database: Database):
651+
# neither .disconnect nor .__aexit__ are called before deleting this task
652+
database.connection()
653+
created.set()
654+
655+
task = asyncio.create_task(check_child_connection(database))
656+
await created.wait()
657+
assert task in database._connection_map
658+
await task
659+
del task
660+
gc.collect()
661+
662+
# Should not have a connection for the task anymore
663+
assert len(database._connection_map) == 0
664+
665+
666+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
667+
@async_adapter
668+
async def test_transaction_context_cleanup_contextmanager(database_url):
669+
"""
670+
Ensure that contextvar transactions are not persisted unecessarily.
671+
"""
672+
from databases.core import _ACTIVE_TRANSACTIONS
673+
674+
assert _ACTIVE_TRANSACTIONS.get() is None
675+
676+
async with Database(database_url) as database:
677+
async with database.transaction() as transaction:
678+
open_transactions = _ACTIVE_TRANSACTIONS.get()
679+
assert isinstance(open_transactions, MutableMapping)
680+
assert open_transactions.get(transaction) is transaction._transaction
681+
682+
# Context manager closes, open_transactions is cleaned up
683+
open_transactions = _ACTIVE_TRANSACTIONS.get()
684+
assert isinstance(open_transactions, MutableMapping)
685+
assert open_transactions.get(transaction, None) is None
686+
687+
688+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
689+
@async_adapter
690+
async def test_transaction_context_cleanup_garbagecollector(database_url):
691+
"""
692+
Ensure that contextvar transactions are not persisted unecessarily, even
693+
if exit handlers are not called.
694+
695+
This test should be an XFAIL, but cannot be due to the way that is hangs
696+
during teardown.
697+
"""
698+
from databases.core import _ACTIVE_TRANSACTIONS
699+
700+
assert _ACTIVE_TRANSACTIONS.get() is None
701+
702+
async with Database(database_url) as database:
703+
transaction = database.transaction()
704+
await transaction.start()
705+
706+
# Should be tracking the transaction
707+
open_transactions = _ACTIVE_TRANSACTIONS.get()
708+
assert isinstance(open_transactions, MutableMapping)
709+
assert open_transactions.get(transaction) is transaction._transaction
710+
711+
# neither .commit, .rollback, nor .__aexit__ are called
712+
del transaction
713+
gc.collect()
714+
715+
# TODO(zevisert,review): Could skip instead of using the logic below
716+
# A strong reference to the transaction is kept alive by the connection's
717+
# ._transaction_stack, so it is still be tracked at this point.
718+
assert len(open_transactions) == 1
719+
720+
# If that were magically cleared, the transaction would be cleaned up,
721+
# but as it stands this always causes a hang during teardown at
722+
# `Database(...).disconnect()` if the transaction is not closed.
723+
transaction = database.connection()._transaction_stack[-1]
724+
await transaction.rollback()
725+
del transaction
726+
727+
# Now with the transaction rolled-back, it should be cleaned up.
728+
assert len(open_transactions) == 0
729+
730+
480731
@pytest.mark.parametrize("database_url", DATABASE_URLS)
481732
@async_adapter
482733
async def test_transaction_commit_serializable(database_url):
@@ -609,17 +860,44 @@ async def insert_data(raise_exception):
609860
with pytest.raises(RuntimeError):
610861
await insert_data(raise_exception=True)
611862

612-
query = notes.select()
613-
results = await database.fetch_all(query=query)
863+
results = await database.fetch_all(query=notes.select())
614864
assert len(results) == 0
615865

616866
await insert_data(raise_exception=False)
617867

618-
query = notes.select()
619-
results = await database.fetch_all(query=query)
868+
results = await database.fetch_all(query=notes.select())
620869
assert len(results) == 1
621870

622871

872+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
873+
@async_adapter
874+
async def test_transaction_decorator_concurrent(database_url):
875+
"""
876+
Ensure that @database.transaction() can be called concurrently.
877+
"""
878+
879+
database = Database(database_url)
880+
881+
@database.transaction()
882+
async def insert_data():
883+
await database.execute(
884+
query=notes.insert().values(text="example", completed=True)
885+
)
886+
887+
async with database:
888+
await asyncio.gather(
889+
insert_data(),
890+
insert_data(),
891+
insert_data(),
892+
insert_data(),
893+
insert_data(),
894+
insert_data(),
895+
)
896+
897+
results = await database.fetch_all(query=notes.select())
898+
assert len(results) == 6
899+
900+
623901
@pytest.mark.parametrize("database_url", DATABASE_URLS)
624902
@async_adapter
625903
async def test_datetime_field(database_url):
@@ -789,15 +1067,16 @@ async def test_connect_and_disconnect(database_url):
7891067

7901068
@pytest.mark.parametrize("database_url", DATABASE_URLS)
7911069
@async_adapter
792-
async def test_connection_context(database_url):
793-
"""
794-
Test connection contexts are task-local.
795-
"""
1070+
async def test_connection_context_same_task(database_url):
7961071
async with Database(database_url) as database:
7971072
async with database.connection() as connection_1:
7981073
async with database.connection() as connection_2:
7991074
assert connection_1 is connection_2
8001075

1076+
1077+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
1078+
@async_adapter
1079+
async def test_connection_context_multiple_sibling_tasks(database_url):
8011080
async with Database(database_url) as database:
8021081
connection_1 = None
8031082
connection_2 = None
@@ -817,9 +1096,8 @@ async def get_connection_2():
8171096
connection_2 = connection
8181097
await test_complete.wait()
8191098

820-
loop = asyncio.get_event_loop()
821-
task_1 = loop.create_task(get_connection_1())
822-
task_2 = loop.create_task(get_connection_2())
1099+
task_1 = asyncio.create_task(get_connection_1())
1100+
task_2 = asyncio.create_task(get_connection_2())
8231101
while connection_1 is None or connection_2 is None:
8241102
await asyncio.sleep(0.000001)
8251103
assert connection_1 is not connection_2
@@ -828,6 +1106,61 @@ async def get_connection_2():
8281106
await task_2
8291107

8301108

1109+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
1110+
@async_adapter
1111+
async def test_connection_context_multiple_tasks(database_url):
1112+
async with Database(database_url) as database:
1113+
parent_connection = database.connection()
1114+
connection_1 = None
1115+
connection_2 = None
1116+
task_1_ready = asyncio.Event()
1117+
task_2_ready = asyncio.Event()
1118+
test_complete = asyncio.Event()
1119+
1120+
async def get_connection_1():
1121+
nonlocal connection_1
1122+
1123+
async with database.connection() as connection:
1124+
connection_1 = connection
1125+
task_1_ready.set()
1126+
await test_complete.wait()
1127+
1128+
async def get_connection_2():
1129+
nonlocal connection_2
1130+
1131+
async with database.connection() as connection:
1132+
connection_2 = connection
1133+
task_2_ready.set()
1134+
await test_complete.wait()
1135+
1136+
task_1 = asyncio.create_task(get_connection_1())
1137+
task_2 = asyncio.create_task(get_connection_2())
1138+
await task_1_ready.wait()
1139+
await task_2_ready.wait()
1140+
1141+
assert connection_1 is not parent_connection
1142+
assert connection_2 is not parent_connection
1143+
assert connection_1 is not connection_2
1144+
1145+
test_complete.set()
1146+
await task_1
1147+
await task_2
1148+
1149+
1150+
@pytest.mark.parametrize(
1151+
"database_url1,database_url2",
1152+
(
1153+
pytest.param(db1, db2, id=f"{db1} | {db2}")
1154+
for (db1, db2) in itertools.combinations(DATABASE_URLS, 2)
1155+
),
1156+
)
1157+
@async_adapter
1158+
async def test_connection_context_multiple_databases(database_url1, database_url2):
1159+
async with Database(database_url1) as database1:
1160+
async with Database(database_url2) as database2:
1161+
assert database1.connection() is not database2.connection()
1162+
1163+
8311164
@pytest.mark.parametrize("database_url", DATABASE_URLS)
8321165
@async_adapter
8331166
async def test_connection_context_with_raw_connection(database_url):
@@ -961,16 +1294,59 @@ async def test_database_url_interface(database_url):
9611294
@pytest.mark.parametrize("database_url", DATABASE_URLS)
9621295
@async_adapter
9631296
async def test_concurrent_access_on_single_connection(database_url):
964-
database_url = DatabaseURL(database_url)
965-
if database_url.dialect != "postgresql":
966-
pytest.skip("Test requires `pg_sleep()`")
967-
9681297
async with Database(database_url, force_rollback=True) as database:
9691298

9701299
async def db_lookup():
971-
await database.fetch_one("SELECT pg_sleep(1)")
1300+
await database.fetch_one("SELECT 1 AS value")
1301+
1302+
await asyncio.gather(
1303+
db_lookup(),
1304+
db_lookup(),
1305+
)
1306+
1307+
1308+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
1309+
@async_adapter
1310+
async def test_concurrent_transactions_on_single_connection(database_url: str):
1311+
async with Database(database_url) as database:
1312+
1313+
@database.transaction()
1314+
async def db_lookup():
1315+
await database.fetch_one(query="SELECT 1 AS value")
1316+
1317+
await asyncio.gather(
1318+
db_lookup(),
1319+
db_lookup(),
1320+
)
1321+
1322+
1323+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
1324+
@async_adapter
1325+
async def test_concurrent_tasks_on_single_connection(database_url: str):
1326+
async with Database(database_url) as database:
1327+
1328+
async def db_lookup():
1329+
await database.fetch_one(query="SELECT 1 AS value")
1330+
1331+
await asyncio.gather(
1332+
asyncio.create_task(db_lookup()),
1333+
asyncio.create_task(db_lookup()),
1334+
)
1335+
1336+
1337+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
1338+
@async_adapter
1339+
async def test_concurrent_task_transactions_on_single_connection(database_url: str):
1340+
async with Database(database_url) as database:
1341+
1342+
@database.transaction()
1343+
async def db_lookup():
1344+
await database.fetch_one(query="SELECT 1 AS value")
9721345

973-
await asyncio.gather(db_lookup(), db_lookup())
1346+
await asyncio.gather(
1347+
asyncio.create_task(db_lookup()),
1348+
asyncio.create_task(db_lookup()),
1349+
)
9741350

9751351

9761352
@pytest.mark.parametrize("database_url", DATABASE_URLS)

0 commit comments

Comments
 (0)
Please sign in to comment.