36
36
logger = logging .getLogger ("databases" )
37
37
38
38
39
- _ACTIVE_CONNECTIONS : ContextVar [
40
- typing .Optional ["weakref.WeakKeyDictionary['Database', 'Connection']" ]
41
- ] = ContextVar ("databases:open_connections" , default = None )
42
39
_ACTIVE_TRANSACTIONS : ContextVar [
43
40
typing .Optional ["weakref.WeakKeyDictionary['Transaction', 'TransactionBackend']" ]
44
- ] = ContextVar ("databases:open_transactions " , default = None )
41
+ ] = ContextVar ("databases:active_transactions " , default = None )
45
42
46
43
47
44
class Database :
@@ -54,6 +51,8 @@ class Database:
54
51
"sqlite" : "databases.backends.sqlite:SQLiteBackend" ,
55
52
}
56
53
54
+ _connection_map : "weakref.WeakKeyDictionary[asyncio.Task, 'Connection']"
55
+
57
56
def __init__ (
58
57
self ,
59
58
url : typing .Union [str , "DatabaseURL" ],
@@ -64,6 +63,7 @@ def __init__(
64
63
self .url = DatabaseURL (url )
65
64
self .options = options
66
65
self .is_connected = False
66
+ self ._connection_map = weakref .WeakKeyDictionary ()
67
67
68
68
self ._force_rollback = force_rollback
69
69
@@ -78,28 +78,28 @@ def __init__(
78
78
self ._global_transaction : typing .Optional [Transaction ] = None
79
79
80
80
@property
81
- def _connection (self ) -> typing .Optional ["Connection" ]:
82
- connections = _ACTIVE_CONNECTIONS .get ()
83
- if connections is None :
84
- return None
81
+ def _current_task (self ) -> asyncio .Task :
82
+ task = asyncio .current_task ()
83
+ if not task :
84
+ raise RuntimeError ("No currently active asyncio.Task found" )
85
+ return task
85
86
86
- return connections .get (self , None )
87
+ @property
88
+ def _connection (self ) -> typing .Optional ["Connection" ]:
89
+ return self ._connection_map .get (self ._current_task )
87
90
88
91
@_connection .setter
89
92
def _connection (
90
93
self , connection : typing .Optional ["Connection" ]
91
94
) -> typing .Optional ["Connection" ]:
92
- connections = _ACTIVE_CONNECTIONS .get ()
93
- if connections is None :
94
- connections = weakref .WeakKeyDictionary ()
95
- _ACTIVE_CONNECTIONS .set (connections )
95
+ task = self ._current_task
96
96
97
97
if connection is None :
98
- connections . pop (self , None )
98
+ self . _connection_map . pop (task , None )
99
99
else :
100
- connections [ self ] = connection
100
+ self . _connection_map [ task ] = connection
101
101
102
- return connections . get ( self , None )
102
+ return self . _connection
103
103
104
104
async def connect (self ) -> None :
105
105
"""
@@ -119,7 +119,7 @@ async def connect(self) -> None:
119
119
assert self ._global_connection is None
120
120
assert self ._global_transaction is None
121
121
122
- self ._global_connection = Connection (self ._backend )
122
+ self ._global_connection = Connection (self , self ._backend )
123
123
self ._global_transaction = self ._global_connection .transaction (
124
124
force_rollback = True
125
125
)
@@ -218,7 +218,7 @@ def connection(self) -> "Connection":
218
218
return self ._global_connection
219
219
220
220
if not self ._connection :
221
- self ._connection = Connection (self ._backend )
221
+ self ._connection = Connection (self , self ._backend )
222
222
223
223
return self ._connection
224
224
@@ -243,7 +243,8 @@ def _get_backend(self) -> str:
243
243
244
244
245
245
class Connection :
246
- def __init__ (self , backend : DatabaseBackend ) -> None :
246
+ def __init__ (self , database : Database , backend : DatabaseBackend ) -> None :
247
+ self ._database = database
247
248
self ._backend = backend
248
249
249
250
self ._connection_lock = asyncio .Lock ()
@@ -277,6 +278,7 @@ async def __aexit__(
277
278
self ._connection_counter -= 1
278
279
if self ._connection_counter == 0 :
279
280
await self ._connection .release ()
281
+ self ._database ._connection = None
280
282
281
283
async def fetch_all (
282
284
self ,
@@ -393,13 +395,15 @@ def _transaction(
393
395
transactions = _ACTIVE_TRANSACTIONS .get ()
394
396
if transactions is None :
395
397
transactions = weakref .WeakKeyDictionary ()
396
- _ACTIVE_TRANSACTIONS .set (transactions )
398
+ else :
399
+ transactions = transactions .copy ()
397
400
398
401
if transaction is None :
399
402
transactions .pop (self , None )
400
403
else :
401
404
transactions [self ] = transaction
402
405
406
+ _ACTIVE_TRANSACTIONS .set (transactions )
403
407
return transactions .get (self , None )
404
408
405
409
async def __aenter__ (self ) -> "Transaction" :
0 commit comments