Skip to content

Commit e9464d8

Browse files
authored
Merge pull request #133 from cs50/transactions
Fixes support for transactions for SQLite, PostgreSQL, and MySQL
2 parents a598334 + 16e52bf commit e9464d8

File tree

5 files changed

+174
-24
lines changed

5 files changed

+174
-24
lines changed

README.md

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,74 @@ f = cs50.get_float();
1919
i = cs50.get_int();
2020
s = cs50.get_string();
2121
```
22+
23+
## Testing
24+
25+
1. Run `cli50` in `python-cs50`.
26+
1. Run `sudo su -`.
27+
1. Run `apt install -y libmysqlclient-dev mysql-server postgresql`.
28+
1. Run `pip3 install mysqlclient psycopg2-binary`.
29+
1. In `/etc/mysql/mysql.conf.d/mysqld.cnf`, add `skip-grant-tables` under `[mysqld]`.
30+
1. In `/etc/profile.d/cli.sh`, remove `valgrind` function for now.
31+
1. Run `service mysql start`.
32+
1. Run `mysql -e 'CREATE DATABASE IF NOT EXISTS test;'`.
33+
1. In `/etc/postgresql/10/main/pg_hba.conf, change:
34+
```
35+
local all postgres peer
36+
host all all 127.0.0.1/32 md5
37+
```
38+
to:
39+
```
40+
local all postgres trust
41+
host all all 127.0.0.1/32 trust
42+
```
43+
1. Run `service postgresql start`.
44+
1. Run `psql -c 'create database test;' -U postgres`.
45+
1. Run `touch test.db`.
46+
47+
### Sample Tests
48+
49+
```
50+
import cs50
51+
db = cs50.SQL("sqlite:///foo.db")
52+
db.execute("CREATE TABLE IF NOT EXISTS cs50 (id INTEGER PRIMARY KEY, val TEXT, bin BLOB)")
53+
db.execute("INSERT INTO cs50 (val) VALUES('a')")
54+
db.execute("INSERT INTO cs50 (val) VALUES('b')")
55+
db.execute("BEGIN")
56+
db.execute("INSERT INTO cs50 (val) VALUES('c')")
57+
db.execute("INSERT INTO cs50 (val) VALUES('x')")
58+
db.execute("INSERT INTO cs50 (val) VALUES('y')")
59+
db.execute("ROLLBACK")
60+
db.execute("INSERT INTO cs50 (val) VALUES('z')")
61+
db.execute("COMMIT")
62+
63+
---
64+
65+
import cs50
66+
db = cs50.SQL("mysql://root@localhost/test")
67+
db.execute("CREATE TABLE IF NOT EXISTS cs50 (id INTEGER PRIMARY KEY, val TEXT, bin BLOB)")
68+
db.execute("INSERT INTO cs50 (val) VALUES('a')")
69+
db.execute("INSERT INTO cs50 (val) VALUES('b')")
70+
db.execute("BEGIN")
71+
db.execute("INSERT INTO cs50 (val) VALUES('c')")
72+
db.execute("INSERT INTO cs50 (val) VALUES('x')")
73+
db.execute("INSERT INTO cs50 (val) VALUES('y')")
74+
db.execute("ROLLBACK")
75+
db.execute("INSERT INTO cs50 (val) VALUES('z')")
76+
db.execute("COMMIT")
77+
78+
---
79+
80+
import cs50
81+
db = cs50.SQL("postgresql://postgres@localhost/test")
82+
db.execute("CREATE TABLE IF NOT EXISTS cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)")
83+
db.execute("INSERT INTO cs50 (val) VALUES('a')")
84+
db.execute("INSERT INTO cs50 (val) VALUES('b')")
85+
db.execute("BEGIN")
86+
db.execute("INSERT INTO cs50 (val) VALUES('c')")
87+
db.execute("INSERT INTO cs50 (val) VALUES('x')")
88+
db.execute("INSERT INTO cs50 (val) VALUES('y')")
89+
db.execute("ROLLBACK")
90+
db.execute("INSERT INTO cs50 (val) VALUES('z')")
91+
db.execute("COMMIT")
92+
```

src/cs50/flask.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,21 @@ def _wrap_flask(f):
1414

1515
f.logging.default_handler.formatter.formatException = lambda exc_info: _formatException(*exc_info)
1616

17-
if os.getenv("CS50_IDE_TYPE") == "online":
17+
if os.getenv("CS50_IDE_TYPE"):
1818
from werkzeug.middleware.proxy_fix import ProxyFix
1919
_flask_init_before = f.Flask.__init__
2020
def _flask_init_after(self, *args, **kwargs):
2121
_flask_init_before(self, *args, **kwargs)
22-
self.wsgi_app = ProxyFix(self.wsgi_app, x_proto=1)
22+
self.config["TEMPLATES_AUTO_RELOAD"] = True # Automatically reload templates
23+
self.wsgi_app = ProxyFix(self.wsgi_app, x_proto=1) # For HTTPS-to-HTTP proxy
2324
f.Flask.__init__ = _flask_init_after
2425

2526

26-
# Flask was imported before cs50
27+
# If Flask was imported before cs50
2728
if "flask" in sys.modules:
2829
_wrap_flask(sys.modules["flask"])
2930

30-
# Flask wasn't imported
31+
# If Flask wasn't imported
3132
else:
3233
flask_loader = pkgutil.get_loader('flask')
3334
if flask_loader:

src/cs50/sql.py

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ def __init__(self, url, **kwargs):
6262
# Listener for connections
6363
def connect(dbapi_connection, connection_record):
6464

65-
# Disable underlying API's own emitting of BEGIN and COMMIT
65+
# Disable underlying API's own emitting of BEGIN and COMMIT so we can ourselves
66+
# https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl
6667
dbapi_connection.isolation_level = None
6768

6869
# Enable foreign key constraints
@@ -71,6 +72,9 @@ def connect(dbapi_connection, connection_record):
7172
cursor.execute("PRAGMA foreign_keys=ON")
7273
cursor.close()
7374

75+
# Autocommit by default
76+
self._autocommit = True
77+
7478
# Register listener
7579
sqlalchemy.event.listen(self._engine, "connect", connect)
7680

@@ -90,9 +94,14 @@ def connect(dbapi_connection, connection_record):
9094
self._logger.disabled = disabled
9195

9296
def __del__(self):
97+
"""Disconnect from database."""
98+
self._disconnect()
99+
100+
def _disconnect(self):
93101
"""Close database connection."""
94102
if hasattr(self, "_connection"):
95103
self._connection.close()
104+
delattr(self, "_connection")
96105

97106
@_enable_logging
98107
def execute(self, sql, *args, **kwargs):
@@ -107,7 +116,7 @@ def execute(self, sql, *args, **kwargs):
107116
import warnings
108117

109118
# Parse statement, stripping comments and then leading/trailing whitespace
110-
statements = sqlparse.parse(sqlparse.format(sql, strip_comments=True).strip())
119+
statements = sqlparse.parse(sqlparse.format(sql, keyword_case="upper", strip_comments=True).strip())
111120

112121
# Allow only one statement at a time, since SQLite doesn't support multiple
113122
# https://docs.python.org/3/library/sqlite3.html#sqlite3.Cursor.execute
@@ -122,9 +131,10 @@ def execute(self, sql, *args, **kwargs):
122131

123132
# Infer command from (unflattened) statement
124133
for token in statements[0]:
125-
if token.ttype in [sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML]:
126-
command = token.value.upper()
127-
break
134+
if token.ttype in [sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML]:
135+
if token.value in ["BEGIN", "DELETE", "INSERT", "SELECT", "START", "UPDATE"]:
136+
command = token.value
137+
break
128138
else:
129139
command = None
130140

@@ -271,7 +281,7 @@ def execute(self, sql, *args, **kwargs):
271281
# Join tokens into statement
272282
statement = "".join([str(token) for token in tokens])
273283

274-
# Connect to database (for transactions' sake)
284+
# Connect to database
275285
try:
276286

277287
# Infer whether Flask is installed
@@ -280,19 +290,23 @@ def execute(self, sql, *args, **kwargs):
280290
# Infer whether app is defined
281291
assert flask.current_app
282292

283-
# If no connection for app's current request yet
293+
# If new context
284294
if not hasattr(flask.g, "_connection"):
285295

286-
# Connect now
287-
flask.g._connection = self._engine.connect()
296+
# Ready to connect
297+
flask.g._connection = None
288298

289299
# Disconnect later
290300
@flask.current_app.teardown_appcontext
291301
def shutdown_session(exception=None):
292-
if hasattr(flask.g, "_connection"):
302+
if flask.g._connection:
293303
flask.g._connection.close()
294304

295-
# Use this connection
305+
# If no connection for context yet
306+
if not flask.g._connection:
307+
flas.g._connection = self._engine.connect()
308+
309+
# Use context's connection
296310
connection = flask.g._connection
297311

298312
except (ModuleNotFoundError, AssertionError):
@@ -316,8 +330,20 @@ def shutdown_session(exception=None):
316330
# Join tokens into statement, abbreviating binary data as <class 'bytes'>
317331
_statement = "".join([str(bytes) if token.ttype == sqlparse.tokens.Other else str(token) for token in tokens])
318332

333+
# Check for start of transaction
334+
if command in ["BEGIN", "START"]:
335+
self._autocommit = False
336+
319337
# Execute statement
338+
if self._autocommit:
339+
connection.execute(sqlalchemy.text("BEGIN"))
320340
result = connection.execute(sqlalchemy.text(statement))
341+
if self._autocommit:
342+
connection.execute(sqlalchemy.text("COMMIT"))
343+
344+
# Check for end of transaction
345+
if command in ["COMMIT", "ROLLBACK"]:
346+
self._autocommit = True
321347

322348
# Return value
323349
ret = True
@@ -360,12 +386,13 @@ def shutdown_session(exception=None):
360386
# If constraint violated, return None
361387
except sqlalchemy.exc.IntegrityError as e:
362388
self._logger.debug(termcolor.colored(statement, "yellow"))
363-
e = RuntimeError(e.orig)
389+
e = ValueError(e.orig)
364390
e.__cause__ = None
365391
raise e
366392

367-
# If user errror
368-
except sqlalchemy.exc.OperationalError as e:
393+
# If user error
394+
except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as e:
395+
self._disconnect()
369396
self._logger.debug(termcolor.colored(statement, "red"))
370397
e = RuntimeError(e.orig)
371398
e.__cause__ = None

tests/foo.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import logging
2+
import sys
3+
4+
sys.path.insert(0, "../src")
5+
6+
import cs50
7+
8+
"""
9+
db = cs50.SQL("sqlite:///foo.db")
10+
11+
logging.getLogger("cs50").disabled = False
12+
13+
#db.execute("SELECT ? FROM ? ORDER BY ?", "a", "tbl", "c")
14+
db.execute("CREATE TABLE IF NOT EXISTS bar (firstname STRING)")
15+
16+
db.execute("INSERT INTO bar VALUES (?)", "baz")
17+
db.execute("INSERT INTO bar VALUES (?)", "qux")
18+
db.execute("SELECT * FROM bar WHERE firstname IN (?)", ("baz", "qux"))
19+
db.execute("DELETE FROM bar")
20+
"""
21+
22+
db = cs50.SQL("postgresql://postgres@localhost/test")
23+
24+
"""
25+
print(db.execute("DROP TABLE IF EXISTS cs50"))
26+
print(db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)"))
27+
print(db.execute("INSERT INTO cs50 (val) VALUES('foo')"))
28+
print(db.execute("SELECT * FROM cs50"))
29+
30+
print(db.execute("DROP TABLE IF EXISTS cs50"))
31+
print(db.execute("CREATE TABLE cs50 (val VARCHAR(16), bin BYTEA)"))
32+
print(db.execute("INSERT INTO cs50 (val) VALUES('foo')"))
33+
print(db.execute("SELECT * FROM cs50"))
34+
"""
35+
36+
print(db.execute("DROP TABLE IF EXISTS cs50"))
37+
print(db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)"))
38+
print(db.execute("INSERT INTO cs50 (val) VALUES('foo')"))
39+
print(db.execute("INSERT INTO cs50 (val) VALUES('bar')"))
40+
print(db.execute("INSERT INTO cs50 (val) VALUES('baz')"))
41+
print(db.execute("SELECT * FROM cs50"))
42+
try:
43+
print(db.execute("INSERT INTO cs50 (id, val) VALUES(1, 'bar')"))
44+
except Exception as e:
45+
print(e)
46+
pass
47+
print(db.execute("INSERT INTO cs50 (val) VALUES('qux')"))
48+
#print(db.execute("DELETE FROM cs50"))

tests/sql.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,8 @@ def setUpClass(self):
150150
self.db = SQL("mysql://root@localhost/test")
151151

152152
def setUp(self):
153-
self.db.execute("CREATE TABLE cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), bin BLOB, PRIMARY KEY (id))")
153+
self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), bin BLOB, PRIMARY KEY (id))")
154+
self.db.execute("DELETE FROM cs50")
154155

155156

156157
class PostgresTests(SQLTests):
@@ -159,7 +160,8 @@ def setUpClass(self):
159160
self.db = SQL("postgresql://postgres@localhost/test")
160161

161162
def setUp(self):
162-
self.db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)")
163+
self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)")
164+
self.db.execute("DELETE FROM cs50")
163165

164166
def test_cte(self):
165167
self.assertEqual(self.db.execute("WITH foo AS ( SELECT 1 AS bar ) SELECT bar FROM foo"), [{"bar": 1}])
@@ -173,23 +175,24 @@ def setUpClass(self):
173175
self.db = SQL("sqlite:///test.db")
174176

175177
def setUp(self):
176-
self.db.execute("CREATE TABLE cs50(id INTEGER PRIMARY KEY, val TEXT, bin BLOB)")
178+
self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id INTEGER PRIMARY KEY, val TEXT, bin BLOB)")
179+
self.db.execute("DELETE FROM cs50")
177180

178181
def test_lastrowid(self):
179182
self.db.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY AUTOINCREMENT, firstname TEXT, lastname TEXT)")
180183
self.assertEqual(self.db.execute("INSERT INTO foo (firstname, lastname) VALUES('firstname', 'lastname')"), 1)
181-
self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo (id, firstname, lastname) VALUES(1, 'firstname', 'lastname')")
184+
self.assertRaises(ValueError, self.db.execute, "INSERT INTO foo (id, firstname, lastname) VALUES(1, 'firstname', 'lastname')")
182185
self.assertEqual(self.db.execute("INSERT OR IGNORE INTO foo (id, firstname, lastname) VALUES(1, 'firstname', 'lastname')"), None)
183186

184187
def test_integrity_constraints(self):
185188
self.db.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY)")
186189
self.assertEqual(self.db.execute("INSERT INTO foo VALUES(1)"), 1)
187-
self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES(1)")
190+
self.assertRaises(ValueError, self.db.execute, "INSERT INTO foo VALUES(1)")
188191

189192
def test_foreign_key_support(self):
190193
self.db.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY)")
191194
self.db.execute("CREATE TABLE bar(foo_id INTEGER, FOREIGN KEY (foo_id) REFERENCES foo(id))")
192-
self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO bar VALUES(50)")
195+
self.assertRaises(ValueError, self.db.execute, "INSERT INTO bar VALUES(50)")
193196

194197
def test_qmark(self):
195198
self.db.execute("CREATE TABLE foo (firstname STRING, lastname STRING)")

0 commit comments

Comments
 (0)