1
1
import asyncio
2
2
import contextlib
3
- from contextvars import ContextVar
4
3
import functools
5
4
import logging
6
5
import typing
6
+ import weakref
7
+ from contextvars import ContextVar
7
8
from types import TracebackType
8
9
from urllib .parse import SplitResult , parse_qsl , unquote , urlsplit
9
- import weakref
10
+
10
11
from sqlalchemy import text
11
12
from sqlalchemy .sql import ClauseElement
12
13
@@ -93,8 +94,12 @@ def _connection(
93
94
connections = weakref .WeakKeyDictionary ()
94
95
_ACTIVE_CONNECTIONS .set (connections )
95
96
96
- connections [self ] = connection
97
- return connections [self ]
97
+ if connection is None :
98
+ connections .pop (self , None )
99
+ else :
100
+ connections [self ] = connection
101
+
102
+ return connections .get (self , None )
98
103
99
104
async def connect (self ) -> None :
100
105
"""
@@ -390,8 +395,12 @@ def _transaction(
390
395
transactions = weakref .WeakKeyDictionary ()
391
396
_ACTIVE_TRANSACTIONS .set (transactions )
392
397
393
- transactions [self ] = transaction
394
- return transactions [self ]
398
+ if transaction is None :
399
+ transactions .pop (self , None )
400
+ else :
401
+ transactions [self ] = transaction
402
+
403
+ return transactions .get (self , None )
395
404
396
405
async def __aenter__ (self ) -> "Transaction" :
397
406
"""
@@ -448,6 +457,7 @@ async def commit(self) -> None:
448
457
async with self ._connection ._transaction_lock :
449
458
assert self ._connection ._transaction_stack [- 1 ] is self
450
459
self ._connection ._transaction_stack .pop ()
460
+ assert self ._transaction is not None
451
461
await self ._transaction .commit ()
452
462
await self ._connection .__aexit__ ()
453
463
self ._transaction = None
@@ -456,6 +466,7 @@ async def rollback(self) -> None:
456
466
async with self ._connection ._transaction_lock :
457
467
assert self ._connection ._transaction_stack [- 1 ] is self
458
468
self ._connection ._transaction_stack .pop ()
469
+ assert self ._transaction is not None
459
470
await self ._transaction .rollback ()
460
471
await self ._connection .__aexit__ ()
461
472
self ._transaction = None
0 commit comments