@@ -56,13 +56,14 @@ def __init__(self, url, **kwargs):
56
56
if not os .path .isfile (matches .group (1 )):
57
57
raise RuntimeError ("not a file: {}" .format (matches .group (1 )))
58
58
59
- # Create engine, raising exception if back end's module not installed
60
- self ._engine = sqlalchemy .create_engine (url , ** kwargs ).execution_options (autocommit = True )
59
+ # Create engine, disabling SQLAlchemy's own autocommit mode, raising exception if back end's module not installed
60
+ self ._engine = sqlalchemy .create_engine (url , ** kwargs ).execution_options (autocommit = False )
61
61
62
62
# Listener for connections
63
63
def connect (dbapi_connection , connection_record ):
64
64
65
- # Disable underlying API's own emitting of BEGIN and COMMIT
65
+ # Disable underlying API's own emitting of BEGIN and COMMIT so we can ourselves
66
+ # https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl
66
67
dbapi_connection .isolation_level = None
67
68
68
69
# Enable foreign key constraints
@@ -71,6 +72,9 @@ def connect(dbapi_connection, connection_record):
71
72
cursor .execute ("PRAGMA foreign_keys=ON" )
72
73
cursor .close ()
73
74
75
+ # Autocommit by default
76
+ self ._autocommit = True
77
+
74
78
# Register listener
75
79
sqlalchemy .event .listen (self ._engine , "connect" , connect )
76
80
@@ -90,9 +94,14 @@ def connect(dbapi_connection, connection_record):
90
94
self ._logger .disabled = disabled
91
95
92
96
def __del__ (self ):
97
+ """Disconnect from database."""
98
+ self ._disconnect ()
99
+
100
+ def _disconnect (self ):
93
101
"""Close database connection."""
94
102
if hasattr (self , "_connection" ):
95
103
self ._connection .close ()
104
+ delattr (self , "_connection" )
96
105
97
106
@_enable_logging
98
107
def execute (self , sql , * args , ** kwargs ):
@@ -107,7 +116,7 @@ def execute(self, sql, *args, **kwargs):
107
116
import warnings
108
117
109
118
# Parse statement, stripping comments and then leading/trailing whitespace
110
- statements = sqlparse .parse (sqlparse .format (sql , strip_comments = True ).strip ())
119
+ statements = sqlparse .parse (sqlparse .format (sql , keyword_case = "upper" , strip_comments = True ).strip ())
111
120
112
121
# Allow only one statement at a time, since SQLite doesn't support multiple
113
122
# https://docs.python.org/3/library/sqlite3.html#sqlite3.Cursor.execute
@@ -122,9 +131,10 @@ def execute(self, sql, *args, **kwargs):
122
131
123
132
# Infer command from (unflattened) statement
124
133
for token in statements [0 ]:
125
- if token .ttype in [sqlparse .tokens .Keyword .DDL , sqlparse .tokens .Keyword .DML ]:
126
- command = token .value .upper ()
127
- break
134
+ if token .ttype in [sqlparse .tokens .Keyword , sqlparse .tokens .Keyword .DDL , sqlparse .tokens .Keyword .DML ]:
135
+ if token .value in ["BEGIN" , "DELETE" , "INSERT" , "SELECT" , "START" , "UPDATE" ]:
136
+ command = token .value
137
+ break
128
138
else :
129
139
command = None
130
140
@@ -316,8 +326,21 @@ def shutdown_session(exception=None):
316
326
# Join tokens into statement, abbreviating binary data as <class 'bytes'>
317
327
_statement = "" .join ([str (bytes ) if token .ttype == sqlparse .tokens .Other else str (token ) for token in tokens ])
318
328
329
+ # Check for start of transaction
330
+ if command in ["BEGIN" , "START" ]:
331
+ self ._autocommit = False
332
+
319
333
# Execute statement
320
- result = connection .execute (sqlalchemy .text (statement ))
334
+ if self ._autocommit :
335
+ connection .execute (sqlalchemy .text ("BEGIN" ))
336
+ result = connection .execute (sqlalchemy .text (statement ))
337
+ connection .execute (sqlalchemy .text ("COMMIT" ))
338
+ else :
339
+ result = connection .execute (sqlalchemy .text (statement ))
340
+
341
+ # Check for end of transaction
342
+ if command in ["COMMIT" , "ROLLBACK" ]:
343
+ self ._autocommit = True
321
344
322
345
# Return value
323
346
ret = True
@@ -359,13 +382,15 @@ def shutdown_session(exception=None):
359
382
360
383
# If constraint violated, return None
361
384
except sqlalchemy .exc .IntegrityError as e :
385
+ self ._disconnect ()
362
386
self ._logger .debug (termcolor .colored (statement , "yellow" ))
363
387
e = RuntimeError (e .orig )
364
388
e .__cause__ = None
365
389
raise e
366
390
367
391
# If user errror
368
392
except sqlalchemy .exc .OperationalError as e :
393
+ self ._disconnect ()
369
394
self ._logger .debug (termcolor .colored (statement , "red" ))
370
395
e = RuntimeError (e .orig )
371
396
e .__cause__ = None
0 commit comments