@@ -43,6 +43,7 @@ def __init__(self, url, **kwargs):
43
43
import os
44
44
import re
45
45
import sqlalchemy
46
+ import sqlalchemy .orm
46
47
import sqlite3
47
48
48
49
# Get logger
@@ -59,6 +60,11 @@ def __init__(self, url, **kwargs):
59
60
# Create engine, disabling SQLAlchemy's own autocommit mode, raising exception if back end's module not installed
60
61
self ._engine = sqlalchemy .create_engine (url , ** kwargs ).execution_options (autocommit = False )
61
62
63
+ # Create a variable to hold the session. If None, autocommit is on.
64
+ self ._Session = sqlalchemy .orm .session .sessionmaker (bind = self ._engine )
65
+ self ._session = None
66
+ self ._in_transaction = False
67
+
62
68
# Listener for connections
63
69
def connect (dbapi_connection , connection_record ):
64
70
@@ -90,9 +96,8 @@ def connect(dbapi_connection, connection_record):
90
96
self ._logger .disabled = disabled
91
97
92
98
def __del__ (self ):
93
- """Close database connection."""
94
- if hasattr (self , "_connection" ):
95
- self ._connection .close ()
99
+ """Close database session and connection."""
100
+ self ._close_session ()
96
101
97
102
@_enable_logging
98
103
def execute (self , sql , * args , ** kwargs ):
@@ -125,6 +130,12 @@ def execute(self, sql, *args, **kwargs):
125
130
if token .ttype in [sqlparse .tokens .Keyword .DDL , sqlparse .tokens .Keyword .DML ]:
126
131
command = token .value .upper ()
127
132
break
133
+
134
+ # Begin a new session, if transaction opened by caller (not using autocommit)
135
+ elif token .value .upper () in ["BEGIN" , "START" ]:
136
+ if self ._in_transaction :
137
+ raise RuntimeError ("transaction already open" )
138
+ self ._in_transaction = True
128
139
else :
129
140
command = None
130
141
@@ -272,6 +283,10 @@ def execute(self, sql, *args, **kwargs):
272
283
statement = "" .join ([str (token ) for token in tokens ])
273
284
274
285
# Connect to database (for transactions' sake)
286
+ if self ._session is None :
287
+ self ._session = self ._Session ()
288
+
289
+ # Set up a Flask app teardown function to close session at teardown
275
290
try :
276
291
277
292
# Infer whether Flask is installed
@@ -280,29 +295,17 @@ def execute(self, sql, *args, **kwargs):
280
295
# Infer whether app is defined
281
296
assert flask .current_app
282
297
283
- # If no connection for app's current request yet
284
- if not hasattr (flask .g , "_connection" ):
298
+ # Disconnect later - but only once
299
+ if not hasattr (self , "_teardown_appcontext_added" ):
300
+ self ._teardown_appcontext_added = True
285
301
286
- # Connect now
287
- flask .g ._connection = self ._engine .connect ()
288
-
289
- # Disconnect later
290
302
@flask .current_app .teardown_appcontext
291
303
def shutdown_session (exception = None ):
292
- if hasattr (flask .g , "_connection" ):
293
- flask .g ._connection .close ()
294
-
295
- # Use this connection
296
- connection = flask .g ._connection
304
+ """Close any existing session on app context teardown."""
305
+ self ._close_session ()
297
306
298
307
except (ModuleNotFoundError , AssertionError ):
299
-
300
- # If no connection yet
301
- if not hasattr (self , "_connection" ):
302
- self ._connection = self ._engine .connect ()
303
-
304
- # Use this connection
305
- connection = self ._connection
308
+ pass
306
309
307
310
# Catch SQLAlchemy warnings
308
311
with warnings .catch_warnings ():
@@ -316,8 +319,14 @@ def shutdown_session(exception=None):
316
319
# Join tokens into statement, abbreviating binary data as <class 'bytes'>
317
320
_statement = "" .join ([str (bytes ) if token .ttype == sqlparse .tokens .Other else str (token ) for token in tokens ])
318
321
322
+ # If COMMIT or ROLLBACK, turn on autocommit mode
323
+ if command in ["COMMIT" , "ROLLBACK" ] and "TO" not in (token .value for token in tokens ):
324
+ if not self ._in_transaction :
325
+ raise RuntimeError ("transactions must be opened with BEGIN or START TRANSACTION" )
326
+ self ._in_transaction = False
327
+
319
328
# Execute statement
320
- result = connection .execute (sqlalchemy .text (statement ))
329
+ result = self . _session .execute (sqlalchemy .text (statement ))
321
330
322
331
# Return value
323
332
ret = True
@@ -346,7 +355,7 @@ def shutdown_session(exception=None):
346
355
elif command == "INSERT" :
347
356
if self ._engine .url .get_backend_name () in ["postgres" , "postgresql" ]:
348
357
try :
349
- result = connection .execute ("SELECT LASTVAL()" )
358
+ result = self . _session .execute ("SELECT LASTVAL()" )
350
359
ret = result .first ()[0 ]
351
360
except sqlalchemy .exc .OperationalError : # If lastval is not yet defined in this session
352
361
ret = None
@@ -357,6 +366,10 @@ def shutdown_session(exception=None):
357
366
elif command in ["DELETE" , "UPDATE" ]:
358
367
ret = result .rowcount
359
368
369
+ # If autocommit is on, commit
370
+ if not self ._in_transaction :
371
+ self ._session .commit ()
372
+
360
373
# If constraint violated, return None
361
374
except sqlalchemy .exc .IntegrityError as e :
362
375
self ._logger .debug (termcolor .colored (statement , "yellow" ))
@@ -376,6 +389,13 @@ def shutdown_session(exception=None):
376
389
self ._logger .debug (termcolor .colored (_statement , "green" ))
377
390
return ret
378
391
392
+ def _close_session (self ):
393
+ """Closes any existing session and resets instance variables."""
394
+ if self ._session is not None :
395
+ self ._session .close ()
396
+ self ._session = None
397
+ self ._in_transaction = False
398
+
379
399
def _escape (self , value ):
380
400
"""
381
401
Escapes value using engine's conversion function.
0 commit comments