Skip to content

Commit 2d40229

Browse files
author
Kareem Zidane
authoredAug 19, 2020
Merge pull request #124 from cs50/develop
Re-implements SQL support with sessions
2 parents 3cb3ff0 + 4924fa1 commit 2d40229

File tree

5 files changed

+194
-31
lines changed

5 files changed

+194
-31
lines changed
 

‎setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,5 @@
1616
package_dir={"": "src"},
1717
packages=["cs50"],
1818
url="https://github.com/cs50/python-cs50",
19-
version="5.0.4"
19+
version="5.1.0"
2020
)

‎src/cs50/sql.py

+43-23
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def __init__(self, url, **kwargs):
4343
import os
4444
import re
4545
import sqlalchemy
46+
import sqlalchemy.orm
4647
import sqlite3
4748

4849
# Get logger
@@ -59,6 +60,11 @@ def __init__(self, url, **kwargs):
5960
# Create engine, disabling SQLAlchemy's own autocommit mode, raising exception if back end's module not installed
6061
self._engine = sqlalchemy.create_engine(url, **kwargs).execution_options(autocommit=False)
6162

63+
# Create a variable to hold the session. If None, autocommit is on.
64+
self._Session = sqlalchemy.orm.session.sessionmaker(bind=self._engine)
65+
self._session = None
66+
self._in_transaction = False
67+
6268
# Listener for connections
6369
def connect(dbapi_connection, connection_record):
6470

@@ -90,9 +96,8 @@ def connect(dbapi_connection, connection_record):
9096
self._logger.disabled = disabled
9197

9298
def __del__(self):
93-
"""Close database connection."""
94-
if hasattr(self, "_connection"):
95-
self._connection.close()
99+
"""Close database session and connection."""
100+
self._close_session()
96101

97102
@_enable_logging
98103
def execute(self, sql, *args, **kwargs):
@@ -125,6 +130,12 @@ def execute(self, sql, *args, **kwargs):
125130
if token.ttype in [sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML]:
126131
command = token.value.upper()
127132
break
133+
134+
# Begin a new session, if transaction opened by caller (not using autocommit)
135+
elif token.value.upper() in ["BEGIN", "START"]:
136+
if self._in_transaction:
137+
raise RuntimeError("transaction already open")
138+
self._in_transaction = True
128139
else:
129140
command = None
130141

@@ -272,6 +283,10 @@ def execute(self, sql, *args, **kwargs):
272283
statement = "".join([str(token) for token in tokens])
273284

274285
# Connect to database (for transactions' sake)
286+
if self._session is None:
287+
self._session = self._Session()
288+
289+
# Set up a Flask app teardown function to close session at teardown
275290
try:
276291

277292
# Infer whether Flask is installed
@@ -280,29 +295,17 @@ def execute(self, sql, *args, **kwargs):
280295
# Infer whether app is defined
281296
assert flask.current_app
282297

283-
# If no connection for app's current request yet
284-
if not hasattr(flask.g, "_connection"):
298+
# Disconnect later - but only once
299+
if not hasattr(self, "_teardown_appcontext_added"):
300+
self._teardown_appcontext_added = True
285301

286-
# Connect now
287-
flask.g._connection = self._engine.connect()
288-
289-
# Disconnect later
290302
@flask.current_app.teardown_appcontext
291303
def shutdown_session(exception=None):
292-
if hasattr(flask.g, "_connection"):
293-
flask.g._connection.close()
294-
295-
# Use this connection
296-
connection = flask.g._connection
304+
"""Close any existing session on app context teardown."""
305+
self._close_session()
297306

298307
except (ModuleNotFoundError, AssertionError):
299-
300-
# If no connection yet
301-
if not hasattr(self, "_connection"):
302-
self._connection = self._engine.connect()
303-
304-
# Use this connection
305-
connection = self._connection
308+
pass
306309

307310
# Catch SQLAlchemy warnings
308311
with warnings.catch_warnings():
@@ -316,8 +319,14 @@ def shutdown_session(exception=None):
316319
# Join tokens into statement, abbreviating binary data as <class 'bytes'>
317320
_statement = "".join([str(bytes) if token.ttype == sqlparse.tokens.Other else str(token) for token in tokens])
318321

322+
# If COMMIT or ROLLBACK, turn on autocommit mode
323+
if command in ["COMMIT", "ROLLBACK"] and "TO" not in (token.value for token in tokens):
324+
if not self._in_transaction:
325+
raise RuntimeError("transactions must be opened with BEGIN or START TRANSACTION")
326+
self._in_transaction = False
327+
319328
# Execute statement
320-
result = connection.execute(sqlalchemy.text(statement))
329+
result = self._session.execute(sqlalchemy.text(statement))
321330

322331
# Return value
323332
ret = True
@@ -346,7 +355,7 @@ def shutdown_session(exception=None):
346355
elif command == "INSERT":
347356
if self._engine.url.get_backend_name() in ["postgres", "postgresql"]:
348357
try:
349-
result = connection.execute("SELECT LASTVAL()")
358+
result = self._session.execute("SELECT LASTVAL()")
350359
ret = result.first()[0]
351360
except sqlalchemy.exc.OperationalError: # If lastval is not yet defined in this session
352361
ret = None
@@ -357,6 +366,10 @@ def shutdown_session(exception=None):
357366
elif command in ["DELETE", "UPDATE"]:
358367
ret = result.rowcount
359368

369+
# If autocommit is on, commit
370+
if not self._in_transaction:
371+
self._session.commit()
372+
360373
# If constraint violated, return None
361374
except sqlalchemy.exc.IntegrityError as e:
362375
self._logger.debug(termcolor.colored(statement, "yellow"))
@@ -376,6 +389,13 @@ def shutdown_session(exception=None):
376389
self._logger.debug(termcolor.colored(_statement, "green"))
377390
return ret
378391

392+
def _close_session(self):
393+
"""Closes any existing session and resets instance variables."""
394+
if self._session is not None:
395+
self._session.close()
396+
self._session = None
397+
self._in_transaction = False
398+
379399
def _escape(self, value):
380400
"""
381401
Escapes value using engine's conversion function.

‎tests/flask/application.py

+57-3
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,76 @@
1+
import logging
2+
import os
13
import requests
24
import sys
3-
from flask import Flask, render_template
45

56
sys.path.insert(0, "../../src")
67

78
import cs50
89
import cs50.flask
910

11+
from flask import Flask, render_template
12+
1013
app = Flask(__name__)
1114

12-
db = cs50.SQL("sqlite:///../sqlite.db")
15+
logging.disable(logging.CRITICAL)
16+
os.environ["WERKZEUG_RUN_MAIN"] = "true"
17+
18+
db_url = "sqlite:///../test.db"
19+
db = cs50.SQL(db_url)
1320

1421
@app.route("/")
1522
def index():
16-
db.execute("SELECT 1")
1723
"""
1824
def f():
1925
res = requests.get("cs50.harvard.edu")
2026
f()
2127
"""
2228
return render_template("index.html")
29+
30+
@app.route("/autocommit")
31+
def autocommit():
32+
db.execute("INSERT INTO test (val) VALUES (?)", "def")
33+
db2 = cs50.SQL(db_url)
34+
ret = db2.execute("SELECT val FROM test WHERE val=?", "def")
35+
return str(ret == [{"val": "def"}])
36+
37+
@app.route("/create")
38+
def create():
39+
ret = db.execute("CREATE TABLE test (id INTEGER PRIMARY KEY AUTOINCREMENT, val VARCHAR(16))")
40+
return str(ret)
41+
42+
@app.route("/delete")
43+
def delete():
44+
ret = db.execute("DELETE FROM test")
45+
return str(ret > 0)
46+
47+
@app.route("/drop")
48+
def drop():
49+
ret = db.execute("DROP TABLE test")
50+
return str(ret)
51+
52+
@app.route("/insert")
53+
def insert():
54+
ret = db.execute("INSERT INTO test (val) VALUES (?)", "abc")
55+
return str(ret > 0)
56+
57+
@app.route("/multiple_connections")
58+
def multiple_connections():
59+
ctx = len(app.teardown_appcontext_funcs)
60+
db1 = cs50.SQL(db_url)
61+
td1 = (len(app.teardown_appcontext_funcs) == ctx + 1)
62+
db2 = cs50.SQL(db_url)
63+
td2 = (len(app.teardown_appcontext_funcs) == ctx + 2)
64+
return str(td1 and td2)
65+
66+
@app.route("/select")
67+
def select():
68+
ret = db.execute("SELECT val FROM test")
69+
return str(ret == [{"val": "abc"}])
70+
71+
@app.route("/single_teardown")
72+
def single_teardown():
73+
db.execute("SELECT * FROM test")
74+
ctx = len(app.teardown_appcontext_funcs)
75+
db.execute("SELECT COUNT(id) FROM test")
76+
return str(ctx == len(app.teardown_appcontext_funcs))

‎tests/flask/test.py

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import logging
2+
import requests
3+
import sys
4+
import threading
5+
import time
6+
import unittest
7+
8+
from application import app
9+
10+
def request(route):
11+
r = requests.get("http://localhost:5000/{}".format(route))
12+
return r.text == "True"
13+
14+
class FlaskTests(unittest.TestCase):
15+
16+
def test__create(self):
17+
self.assertTrue(request("create"))
18+
19+
def test_autocommit(self):
20+
self.assertTrue(request("autocommit"))
21+
22+
def test_delete(self):
23+
self.assertTrue(request("delete"))
24+
25+
def test_insert(self):
26+
self.assertTrue(request("insert"))
27+
28+
def test_multiple_connections(self):
29+
self.assertTrue(request("multiple_connections"))
30+
31+
def test_select(self):
32+
self.assertTrue(request("select"))
33+
34+
def test_single_teardown(self):
35+
self.assertTrue(request("single_teardown"))
36+
37+
def test_zdrop(self):
38+
self.assertTrue(request("drop"))
39+
40+
41+
if __name__ == "__main__":
42+
t = threading.Thread(target=app.run, daemon=True)
43+
t.start()
44+
45+
suite = unittest.TestSuite([
46+
unittest.TestLoader().loadTestsFromTestCase(FlaskTests)
47+
])
48+
49+
sys.exit(not unittest.TextTestRunner(verbosity=2).run(suite).wasSuccessful())

‎tests/sql.py

+44-4
Original file line numberDiff line numberDiff line change
@@ -115,11 +115,34 @@ def test_blob(self):
115115
self.db.execute("INSERT INTO cs50(bin) VALUES(:bin)", bin=row["bin"])
116116
self.assertEqual(self.db.execute("SELECT id, bin FROM cs50"), rows)
117117

118+
def test_autocommit(self):
119+
self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('foo')"), 1)
120+
self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('bar')"), 2)
121+
122+
# Load a new database instance to confirm the INSERTs were committed
123+
db2 = SQL(self.db_url)
124+
self.assertEqual(db2.execute("DELETE FROM cs50 WHERE id < 3"), 2)
125+
126+
def test_commit_no_transaction(self):
127+
with self.assertRaises(RuntimeError):
128+
self.db.execute("COMMIT")
129+
with self.assertRaises(RuntimeError):
130+
self.db.execute("ROLLBACK")
131+
118132
def test_commit(self):
119133
self.db.execute("BEGIN")
120134
self.db.execute("INSERT INTO cs50 (val) VALUES('foo')")
121135
self.db.execute("COMMIT")
122-
self.assertEqual(self.db.execute("SELECT val FROM cs50"), [{"val": "foo"}])
136+
137+
# Load a new database instance to confirm the INSERT was committed
138+
db2 = SQL(self.db_url)
139+
self.assertEqual(db2.execute("SELECT val FROM cs50"), [{"val": "foo"}])
140+
141+
def test_double_begin(self):
142+
self.db.execute("BEGIN")
143+
with self.assertRaises(RuntimeError):
144+
self.db.execute("BEGIN")
145+
self.db.execute("ROLLBACK")
123146

124147
def test_rollback(self):
125148
self.db.execute("BEGIN")
@@ -128,6 +151,17 @@ def test_rollback(self):
128151
self.db.execute("ROLLBACK")
129152
self.assertEqual(self.db.execute("SELECT val FROM cs50"), [])
130153

154+
def test_savepoint(self):
155+
self.db.execute("BEGIN")
156+
self.db.execute("INSERT INTO cs50 (val) VALUES('foo')")
157+
self.db.execute("SAVEPOINT sp1")
158+
self.db.execute("INSERT INTO cs50 (val) VALUES('bar')")
159+
self.assertEqual(self.db.execute("SELECT val FROM cs50"), [{"val": "foo"}, {"val": "bar"}])
160+
self.db.execute("ROLLBACK TO sp1")
161+
self.assertEqual(self.db.execute("SELECT val FROM cs50"), [{"val": "foo"}])
162+
self.db.execute("ROLLBACK")
163+
self.assertEqual(self.db.execute("SELECT val FROM cs50"), [])
164+
131165
def tearDown(self):
132166
self.db.execute("DROP TABLE cs50")
133167
self.db.execute("DROP TABLE IF EXISTS foo")
@@ -145,15 +179,19 @@ def tearDownClass(self):
145179
class MySQLTests(SQLTests):
146180
@classmethod
147181
def setUpClass(self):
148-
self.db = SQL("mysql://root@localhost/test")
182+
self.db_url = "mysql://root@localhost/test"
183+
self.db = SQL(self.db_url)
184+
print("\nMySQL tests")
149185

150186
def setUp(self):
151187
self.db.execute("CREATE TABLE cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), bin BLOB, PRIMARY KEY (id))")
152188

153189
class PostgresTests(SQLTests):
154190
@classmethod
155191
def setUpClass(self):
156-
self.db = SQL("postgresql://postgres@localhost/test")
192+
self.db_url = "postgresql://postgres@localhost/test"
193+
self.db = SQL(self.db_url)
194+
print("\nPOSTGRES tests")
157195

158196
def setUp(self):
159197
self.db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)")
@@ -165,7 +203,9 @@ class SQLiteTests(SQLTests):
165203
@classmethod
166204
def setUpClass(self):
167205
open("test.db", "w").close()
168-
self.db = SQL("sqlite:///test.db")
206+
self.db_url = "sqlite:///test.db"
207+
self.db = SQL(self.db_url)
208+
print("\nSQLite tests")
169209

170210
def setUp(self):
171211
self.db.execute("CREATE TABLE cs50(id INTEGER PRIMARY KEY, val TEXT, bin BLOB)")

0 commit comments

Comments
 (0)