Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 9d007a4

Browse files
committedNov 22, 2020
fixed support for transactions
1 parent b2fc969 commit 9d007a4

File tree

2 files changed

+38
-16
lines changed

2 files changed

+38
-16
lines changed
 

‎src/cs50/sql.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,14 @@ def __init__(self, url, **kwargs):
5656
if not os.path.isfile(matches.group(1)):
5757
raise RuntimeError("not a file: {}".format(matches.group(1)))
5858

59-
# Create engine, raising exception if back end's module not installed
60-
self._engine = sqlalchemy.create_engine(url, **kwargs).execution_options(autocommit=True)
59+
# Create engine, disabling SQLAlchemy's own autocommit mode, raising exception if back end's module not installed
60+
self._engine = sqlalchemy.create_engine(url, **kwargs).execution_options(autocommit=False)
6161

6262
# Listener for connections
6363
def connect(dbapi_connection, connection_record):
6464

65-
# Disable underlying API's own emitting of BEGIN and COMMIT
65+
# Disable underlying API's own emitting of BEGIN and COMMIT so we can ourselves
66+
# https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl
6667
dbapi_connection.isolation_level = None
6768

6869
# Enable foreign key constraints
@@ -71,6 +72,9 @@ def connect(dbapi_connection, connection_record):
7172
cursor.execute("PRAGMA foreign_keys=ON")
7273
cursor.close()
7374

75+
# Autocommit by default
76+
self._autocommit = True
77+
7478
# Register listener
7579
sqlalchemy.event.listen(self._engine, "connect", connect)
7680

@@ -90,9 +94,14 @@ def connect(dbapi_connection, connection_record):
9094
self._logger.disabled = disabled
9195

9296
def __del__(self):
97+
"""Disconnect from database."""
98+
self._disconnect()
99+
100+
def _disconnect(self):
93101
"""Close database connection."""
94102
if hasattr(self, "_connection"):
95103
self._connection.close()
104+
delattr(self, "_connection")
96105

97106
@_enable_logging
98107
def execute(self, sql, *args, **kwargs):
@@ -107,7 +116,7 @@ def execute(self, sql, *args, **kwargs):
107116
import warnings
108117

109118
# Parse statement, stripping comments and then leading/trailing whitespace
110-
statements = sqlparse.parse(sqlparse.format(sql, strip_comments=True).strip())
119+
statements = sqlparse.parse(sqlparse.format(sql, keyword_case="upper", strip_comments=True).strip())
111120

112121
# Allow only one statement at a time, since SQLite doesn't support multiple
113122
# https://docs.python.org/3/library/sqlite3.html#sqlite3.Cursor.execute
@@ -122,9 +131,10 @@ def execute(self, sql, *args, **kwargs):
122131

123132
# Infer command from (unflattened) statement
124133
for token in statements[0]:
125-
if token.ttype in [sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML]:
126-
command = token.value.upper()
127-
break
134+
if token.ttype in [sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML]:
135+
if token.value in ["BEGIN", "DELETE", "INSERT", "SELECT", "START", "UPDATE"]:
136+
command = token.value
137+
break
128138
else:
129139
command = None
130140

@@ -316,8 +326,21 @@ def shutdown_session(exception=None):
316326
# Join tokens into statement, abbreviating binary data as <class 'bytes'>
317327
_statement = "".join([str(bytes) if token.ttype == sqlparse.tokens.Other else str(token) for token in tokens])
318328

329+
# Check for start of transaction
330+
if command in ["BEGIN", "START"]:
331+
self._autocommit = False
332+
319333
# Execute statement
320-
result = connection.execute(sqlalchemy.text(statement))
334+
if self._autocommit:
335+
connection.execute(sqlalchemy.text("BEGIN"))
336+
result = connection.execute(sqlalchemy.text(statement))
337+
connection.execute(sqlalchemy.text("COMMIT"))
338+
else:
339+
result = connection.execute(sqlalchemy.text(statement))
340+
341+
# Check for end of transaction
342+
if command in ["COMMIT", "ROLLBACK"]:
343+
self._autocommit = True
321344

322345
# Return value
323346
ret = True
@@ -359,13 +382,15 @@ def shutdown_session(exception=None):
359382

360383
# If constraint violated, return None
361384
except sqlalchemy.exc.IntegrityError as e:
385+
self._disconnect()
362386
self._logger.debug(termcolor.colored(statement, "yellow"))
363387
e = RuntimeError(e.orig)
364388
e.__cause__ = None
365389
raise e
366390

367391
# If user errror
368392
except sqlalchemy.exc.OperationalError as e:
393+
self._disconnect()
369394
self._logger.debug(termcolor.colored(statement, "red"))
370395
e = RuntimeError(e.orig)
371396
e.__cause__ = None

‎tests/foo.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55

66
import cs50
77

8+
"""
89
db = cs50.SQL("sqlite:///foo.db")
910
1011
logging.getLogger("cs50").disabled = False
1112
12-
"""
1313
#db.execute("SELECT ? FROM ? ORDER BY ?", "a", "tbl", "c")
1414
db.execute("CREATE TABLE IF NOT EXISTS bar (firstname STRING)")
1515
@@ -36,16 +36,13 @@
3636
print(db.execute("DROP TABLE IF EXISTS cs50"))
3737
print(db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)"))
3838
print(db.execute("INSERT INTO cs50 (val) VALUES('foo')"))
39-
print(db.execute("INSERT INTO cs50 (val) VALUES('foo')"))
40-
print(db.execute("INSERT INTO cs50 (val) VALUES('foo')"))
39+
print(db.execute("INSERT INTO cs50 (val) VALUES('bar')"))
40+
print(db.execute("INSERT INTO cs50 (val) VALUES('baz')"))
4141
print(db.execute("SELECT * FROM cs50"))
42-
print(db.execute("COMMIT"))
43-
"""
4442
try:
4543
print(db.execute("INSERT INTO cs50 (id, val) VALUES(1, 'bar')"))
4644
except Exception as e:
4745
print(e)
4846
pass
49-
print(db.execute("INSERT INTO cs50 (val) VALUES('foo')"))
50-
print(db.execute("DELETE FROM cs50"))
51-
"""
47+
print(db.execute("INSERT INTO cs50 (val) VALUES('qux')"))
48+
#print(db.execute("DELETE FROM cs50"))

0 commit comments

Comments
 (0)
Please sign in to comment.