Skip to content

Commit 0dff716

Browse files
committedJun 4, 2020
Style fixes. Minor design improvements, including removing SQL class URL variable, and always committing session so as to release locks.
1 parent 6e17982 commit 0dff716

File tree

4 files changed

+25
-23
lines changed

4 files changed

+25
-23
lines changed
 

‎src/cs50/sql.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def __init__(self, url, **kwargs):
4343
import os
4444
import re
4545
import sqlalchemy
46-
import sqlalchemy.orm as orm
46+
import sqlalchemy.orm
4747
import sqlite3
4848

4949
# Get logger
@@ -57,14 +57,11 @@ def __init__(self, url, **kwargs):
5757
if not os.path.isfile(matches.group(1)):
5858
raise RuntimeError("not a file: {}".format(matches.group(1)))
5959

60-
# Record the URL (used in testing)
61-
self.url = url
62-
6360
# Create engine, disabling SQLAlchemy's own autocommit mode, raising exception if back end's module not installed
6461
self._engine = sqlalchemy.create_engine(url, **kwargs).execution_options(autocommit=False)
6562

6663
# Create a variable to hold the session. If None, autocommit is on.
67-
self.Session = orm.sessionmaker(bind=self._engine)
64+
self._Session = sqlalchemy.orm.session.sessionmaker(bind=self._engine)
6865
self._session = None
6966

7067
# Listener for connections
@@ -101,6 +98,7 @@ def __del__(self):
10198
"""Close database session and connection."""
10299
if self._session is not None:
103100
self._session.close()
101+
self._session = None
104102

105103
@_enable_logging
106104
def execute(self, sql, *args, **kwargs):
@@ -134,11 +132,11 @@ def execute(self, sql, *args, **kwargs):
134132
command = token.value.upper()
135133
break
136134

137-
# Begin a new transaction session, if done manually
135+
# Begin a new session, if transaction started by caller (not using autocommit)
138136
elif token.value.upper() in ["BEGIN", "START"]:
139137
if self._session is not None:
140138
self._session.close()
141-
self._session = self.Session()
139+
self._session = self._Session()
142140
else:
143141
command = None
144142

@@ -288,7 +286,7 @@ def execute(self, sql, *args, **kwargs):
288286
# Connect to database (for transactions' sake)
289287
session = self._session
290288
if session is None:
291-
session = self.Session()
289+
session = self._Session()
292290

293291
# Set up a Flask app teardown function to close session at teardown
294292
try:
@@ -303,11 +301,12 @@ def execute(self, sql, *args, **kwargs):
303301
if not hasattr(self, "teardown_appcontext_added"):
304302
self.teardown_appcontext_added = True
305303

306-
# Register shutdown_session on app context teardown
307304
@flask.current_app.teardown_appcontext
308305
def shutdown_session(exception=None):
306+
"""Close any existing session on app context teardown."""
309307
if self._session is not None:
310308
self._session.close()
309+
self._session = None
311310

312311
except (ModuleNotFoundError, AssertionError):
313312
pass
@@ -370,11 +369,9 @@ def shutdown_session(exception=None):
370369
session.close()
371370
self._session = None
372371

373-
374372
# If autocommit is on, commit and close
375373
if self._session is None and command not in ["COMMIT", "ROLLBACK"]:
376-
if command not in ["SELECT"]:
377-
session.commit()
374+
session.commit()
378375
session.close()
379376

380377
# If constraint violated, return None

‎tests/flask/application.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,21 @@
22
import os
33
import requests
44
import sys
5-
from flask import Flask, render_template
65

76
sys.path.insert(0, "../../src")
87

98
import cs50
109
import cs50.flask
1110

11+
from flask import Flask, render_template
12+
1213
app = Flask(__name__)
1314

1415
logging.disable(logging.CRITICAL)
1516
os.environ["WERKZEUG_RUN_MAIN"] = "true"
1617

17-
db = cs50.SQL("sqlite:///../test.db")
18+
db_url = "sqlite:///../test.db"
19+
db = cs50.SQL(db_url)
1820

1921
@app.route("/")
2022
def index():
@@ -28,7 +30,7 @@ def f():
2830
@app.route("/autocommit")
2931
def autocommit():
3032
db.execute("INSERT INTO test (val) VALUES (?)", "def")
31-
db2 = cs50.SQL(db.url)
33+
db2 = cs50.SQL(db_url)
3234
ret = db2.execute("SELECT val FROM test WHERE val=?", "def")
3335
return str(ret == [{"val": "def"}])
3436

@@ -55,9 +57,9 @@ def insert():
5557
@app.route("/multiple_connections")
5658
def multiple_connections():
5759
ctx = len(app.teardown_appcontext_funcs)
58-
db1 = cs50.SQL(db.url)
60+
db1 = cs50.SQL(db_url)
5961
td1 = (len(app.teardown_appcontext_funcs) == ctx + 1)
60-
db2 = cs50.SQL(db.url)
62+
db2 = cs50.SQL(db_url)
6163
td2 = (len(app.teardown_appcontext_funcs) == ctx + 2)
6264
return str(td1 and td2)
6365

‎tests/flask/test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
from application import app
21
import logging
32
import requests
43
import sys
54
import threading
65
import time
76
import unittest
87

8+
from application import app
99

1010
def request(route):
1111
r = requests.get("http://localhost:5000/{}".format(route))

‎tests/sql.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def test_autocommit(self):
120120
self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('bar')"), 2)
121121

122122
# Load a new database instance to confirm the INSERTs were committed
123-
db2 = SQL(self.db.url)
123+
db2 = SQL(self.db_url)
124124
self.assertEqual(db2.execute("DELETE FROM cs50 WHERE id < 3"), 2)
125125

126126
def test_commit(self):
@@ -129,7 +129,7 @@ def test_commit(self):
129129
self.db.execute("COMMIT")
130130

131131
# Load a new database instance to confirm the INSERT was committed
132-
db2 = SQL(self.db.url)
132+
db2 = SQL(self.db_url)
133133
self.assertEqual(db2.execute("SELECT val FROM cs50"), [{"val": "foo"}])
134134

135135
def test_rollback(self):
@@ -167,7 +167,8 @@ def tearDownClass(self):
167167
class MySQLTests(SQLTests):
168168
@classmethod
169169
def setUpClass(self):
170-
self.db = SQL("mysql://root@localhost/test")
170+
self.db_url = "mysql://root@localhost/test"
171+
self.db = SQL(self.db_url)
171172
print("\nMySQL tests")
172173

173174
def setUp(self):
@@ -176,7 +177,8 @@ def setUp(self):
176177
class PostgresTests(SQLTests):
177178
@classmethod
178179
def setUpClass(self):
179-
self.db = SQL("postgresql://root:test@localhost/test")
180+
self.db_url = "postgresql://root:test@localhost/test"
181+
self.db = SQL(self.db_url)
180182
print("\nPOSTGRES tests")
181183

182184
def setUp(self):
@@ -189,7 +191,8 @@ class SQLiteTests(SQLTests):
189191
@classmethod
190192
def setUpClass(self):
191193
open("test.db", "w").close()
192-
self.db = SQL("sqlite:///test.db")
194+
self.db_url = "sqlite:///test.db"
195+
self.db = SQL(self.db_url)
193196
print("\nSQLite tests")
194197

195198
def setUp(self):

0 commit comments

Comments
 (0)
Please sign in to comment.