From 6e1798224295a0e0aaa076164694a8e262f6ca03 Mon Sep 17 00:00:00 2001 From: Joshua Archibald <jarchibald121@gmail.com> Date: Thu, 4 Jun 2020 12:17:45 -0500 Subject: [PATCH 1/6] Use sessions to handle transactions, allowing for both auto and manual commit modes. Registers Flask appcontext teardown function only once per database instance, and also allows for multiple database connections in a single Flask request. Add unit tests for SQL savepoints, autocommit mode, manual transaction mode. Add integration tests for Flask. --- src/cs50/sql.py | 66 +++++++++++++++++++++++++------------- tests/flask/application.py | 56 ++++++++++++++++++++++++++++++-- tests/flask/test.py | 49 ++++++++++++++++++++++++++++ tests/sql.py | 29 +++++++++++++++-- 4 files changed, 173 insertions(+), 27 deletions(-) create mode 100644 tests/flask/test.py diff --git a/src/cs50/sql.py b/src/cs50/sql.py index cd8ae88..55cf058 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -43,6 +43,7 @@ def __init__(self, url, **kwargs): import os import re import sqlalchemy + import sqlalchemy.orm as orm import sqlite3 # Get logger @@ -56,9 +57,16 @@ def __init__(self, url, **kwargs): if not os.path.isfile(matches.group(1)): raise RuntimeError("not a file: {}".format(matches.group(1))) + # Record the URL (used in testing) + self.url = url + # Create engine, disabling SQLAlchemy's own autocommit mode, raising exception if back end's module not installed self._engine = sqlalchemy.create_engine(url, **kwargs).execution_options(autocommit=False) + # Create a variable to hold the session. If None, autocommit is on. + self.Session = orm.sessionmaker(bind=self._engine) + self._session = None + # Listener for connections def connect(dbapi_connection, connection_record): @@ -90,9 +98,9 @@ def connect(dbapi_connection, connection_record): self._logger.disabled = disabled def __del__(self): - """Close database connection.""" - if hasattr(self, "_connection"): - self._connection.close() + """Close database session and connection.""" + if self._session is not None: + self._session.close() @_enable_logging def execute(self, sql, *args, **kwargs): @@ -125,6 +133,12 @@ def execute(self, sql, *args, **kwargs): if token.ttype in [sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML]: command = token.value.upper() break + + # Begin a new transaction session, if done manually + elif token.value.upper() in ["BEGIN", "START"]: + if self._session is not None: + self._session.close() + self._session = self.Session() else: command = None @@ -272,6 +286,11 @@ def execute(self, sql, *args, **kwargs): statement = "".join([str(token) for token in tokens]) # Connect to database (for transactions' sake) + session = self._session + if session is None: + session = self.Session() + + # Set up a Flask app teardown function to close session at teardown try: # Infer whether Flask is installed @@ -280,29 +299,18 @@ def execute(self, sql, *args, **kwargs): # Infer whether app is defined assert flask.current_app - # If no connection for app's current request yet - if not hasattr(flask.g, "_connection"): - - # Connect now - flask.g._connection = self._engine.connect() + # Disconnect later - but only once + if not hasattr(self, "teardown_appcontext_added"): + self.teardown_appcontext_added = True - # Disconnect later + # Register shutdown_session on app context teardown @flask.current_app.teardown_appcontext def shutdown_session(exception=None): - if hasattr(flask.g, "_connection"): - flask.g._connection.close() - - # Use this connection - connection = flask.g._connection + if self._session is not None: + self._session.close() except (ModuleNotFoundError, AssertionError): - - # If no connection yet - if not hasattr(self, "_connection"): - self._connection = self._engine.connect() - - # Use this connection - connection = self._connection + pass # Catch SQLAlchemy warnings with warnings.catch_warnings(): @@ -317,7 +325,7 @@ def shutdown_session(exception=None): _statement = "".join([str(bytes) if token.ttype == sqlparse.tokens.Other else str(token) for token in tokens]) # Execute statement - result = connection.execute(sqlalchemy.text(statement)) + result = session.execute(sqlalchemy.text(statement)) # Return value ret = True @@ -346,7 +354,7 @@ def shutdown_session(exception=None): elif command == "INSERT": if self._engine.url.get_backend_name() in ["postgres", "postgresql"]: try: - result = connection.execute("SELECT LASTVAL()") + result = session.execute("SELECT LASTVAL()") ret = result.first()[0] except sqlalchemy.exc.OperationalError: # If lastval is not yet defined in this session ret = None @@ -357,6 +365,18 @@ def shutdown_session(exception=None): elif command in ["DELETE", "UPDATE"]: ret = result.rowcount + # If COMMIT or ROLLBACK, turn on autocommit mode + elif command in ["COMMIT", "ROLLBACK"] and "TO" not in statement: + session.close() + self._session = None + + + # If autocommit is on, commit and close + if self._session is None and command not in ["COMMIT", "ROLLBACK"]: + if command not in ["SELECT"]: + session.commit() + session.close() + # If constraint violated, return None except sqlalchemy.exc.IntegrityError as e: self._logger.debug(termcolor.colored(statement, "yellow")) diff --git a/tests/flask/application.py b/tests/flask/application.py index 939a8f9..e3f0768 100644 --- a/tests/flask/application.py +++ b/tests/flask/application.py @@ -1,3 +1,5 @@ +import logging +import os import requests import sys from flask import Flask, render_template @@ -9,14 +11,64 @@ app = Flask(__name__) -db = cs50.SQL("sqlite:///../sqlite.db") +logging.disable(logging.CRITICAL) +os.environ["WERKZEUG_RUN_MAIN"] = "true" + +db = cs50.SQL("sqlite:///../test.db") @app.route("/") def index(): - db.execute("SELECT 1") """ def f(): res = requests.get("cs50.harvard.edu") f() """ return render_template("index.html") + +@app.route("/autocommit") +def autocommit(): + db.execute("INSERT INTO test (val) VALUES (?)", "def") + db2 = cs50.SQL(db.url) + ret = db2.execute("SELECT val FROM test WHERE val=?", "def") + return str(ret == [{"val": "def"}]) + +@app.route("/create") +def create(): + ret = db.execute("CREATE TABLE test (id INTEGER PRIMARY KEY AUTOINCREMENT, val VARCHAR(16))") + return str(ret) + +@app.route("/delete") +def delete(): + ret = db.execute("DELETE FROM test") + return str(ret > 0) + +@app.route("/drop") +def drop(): + ret = db.execute("DROP TABLE test") + return str(ret) + +@app.route("/insert") +def insert(): + ret = db.execute("INSERT INTO test (val) VALUES (?)", "abc") + return str(ret > 0) + +@app.route("/multiple_connections") +def multiple_connections(): + ctx = len(app.teardown_appcontext_funcs) + db1 = cs50.SQL(db.url) + td1 = (len(app.teardown_appcontext_funcs) == ctx + 1) + db2 = cs50.SQL(db.url) + td2 = (len(app.teardown_appcontext_funcs) == ctx + 2) + return str(td1 and td2) + +@app.route("/select") +def select(): + ret = db.execute("SELECT val FROM test") + return str(ret == [{"val": "abc"}]) + +@app.route("/single_teardown") +def single_teardown(): + db.execute("SELECT * FROM test") + ctx = len(app.teardown_appcontext_funcs) + db.execute("SELECT COUNT(id) FROM test") + return str(ctx == len(app.teardown_appcontext_funcs)) diff --git a/tests/flask/test.py b/tests/flask/test.py new file mode 100644 index 0000000..9a4134b --- /dev/null +++ b/tests/flask/test.py @@ -0,0 +1,49 @@ +from application import app +import logging +import requests +import sys +import threading +import time +import unittest + + +def request(route): + r = requests.get("http://localhost:5000/{}".format(route)) + return r.text == "True" + +class FlaskTests(unittest.TestCase): + + def test__create(self): + self.assertTrue(request("create")) + + def test_autocommit(self): + self.assertTrue(request("autocommit")) + + def test_delete(self): + self.assertTrue(request("delete")) + + def test_insert(self): + self.assertTrue(request("insert")) + + def test_multiple_connections(self): + self.assertTrue(request("multiple_connections")) + + def test_select(self): + self.assertTrue(request("select")) + + def test_single_teardown(self): + self.assertTrue(request("single_teardown")) + + def test_zdrop(self): + self.assertTrue(request("drop")) + + +if __name__ == "__main__": + t = threading.Thread(target=app.run, daemon=True) + t.start() + + suite = unittest.TestSuite([ + unittest.TestLoader().loadTestsFromTestCase(FlaskTests) + ]) + + sys.exit(not unittest.TextTestRunner(verbosity=2).run(suite).wasSuccessful()) diff --git a/tests/sql.py b/tests/sql.py index 9ad463f..57974a6 100644 --- a/tests/sql.py +++ b/tests/sql.py @@ -115,11 +115,22 @@ def test_blob(self): self.db.execute("INSERT INTO cs50(bin) VALUES(:bin)", bin=row["bin"]) self.assertEqual(self.db.execute("SELECT id, bin FROM cs50"), rows) + def test_autocommit(self): + self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('foo')"), 1) + self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('bar')"), 2) + + # Load a new database instance to confirm the INSERTs were committed + db2 = SQL(self.db.url) + self.assertEqual(db2.execute("DELETE FROM cs50 WHERE id < 3"), 2) + def test_commit(self): self.db.execute("BEGIN") self.db.execute("INSERT INTO cs50 (val) VALUES('foo')") self.db.execute("COMMIT") - self.assertEqual(self.db.execute("SELECT val FROM cs50"), [{"val": "foo"}]) + + # Load a new database instance to confirm the INSERT was committed + db2 = SQL(self.db.url) + self.assertEqual(db2.execute("SELECT val FROM cs50"), [{"val": "foo"}]) def test_rollback(self): self.db.execute("BEGIN") @@ -128,6 +139,17 @@ def test_rollback(self): self.db.execute("ROLLBACK") self.assertEqual(self.db.execute("SELECT val FROM cs50"), []) + def test_savepoint(self): + self.db.execute("BEGIN") + self.db.execute("INSERT INTO cs50 (val) VALUES('foo')") + self.db.execute("SAVEPOINT sp1") + self.db.execute("INSERT INTO cs50 (val) VALUES('bar')") + self.assertEqual(self.db.execute("SELECT val FROM cs50"), [{"val": "foo"}, {"val": "bar"}]) + self.db.execute("ROLLBACK TO sp1") + self.assertEqual(self.db.execute("SELECT val FROM cs50"), [{"val": "foo"}]) + self.db.execute("ROLLBACK") + self.assertEqual(self.db.execute("SELECT val FROM cs50"), []) + def tearDown(self): self.db.execute("DROP TABLE cs50") self.db.execute("DROP TABLE IF EXISTS foo") @@ -146,6 +168,7 @@ class MySQLTests(SQLTests): @classmethod def setUpClass(self): self.db = SQL("mysql://root@localhost/test") + print("\nMySQL tests") def setUp(self): self.db.execute("CREATE TABLE cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), bin BLOB, PRIMARY KEY (id))") @@ -153,7 +176,8 @@ def setUp(self): class PostgresTests(SQLTests): @classmethod def setUpClass(self): - self.db = SQL("postgresql://postgres@localhost/test") + self.db = SQL("postgresql://root:test@localhost/test") + print("\nPOSTGRES tests") def setUp(self): self.db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)") @@ -166,6 +190,7 @@ class SQLiteTests(SQLTests): def setUpClass(self): open("test.db", "w").close() self.db = SQL("sqlite:///test.db") + print("\nSQLite tests") def setUp(self): self.db.execute("CREATE TABLE cs50(id INTEGER PRIMARY KEY, val TEXT, bin BLOB)") From 0dff7160cc684ffb03c3baeb092e803ddff2abe0 Mon Sep 17 00:00:00 2001 From: Joshua Archibald <jarchibald121@gmail.com> Date: Thu, 4 Jun 2020 14:29:38 -0500 Subject: [PATCH 2/6] Style fixes. Minor design improvements, including removing SQL class URL variable, and always committing session so as to release locks. --- src/cs50/sql.py | 21 +++++++++------------ tests/flask/application.py | 12 +++++++----- tests/flask/test.py | 2 +- tests/sql.py | 13 ++++++++----- 4 files changed, 25 insertions(+), 23 deletions(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 55cf058..8acf194 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -43,7 +43,7 @@ def __init__(self, url, **kwargs): import os import re import sqlalchemy - import sqlalchemy.orm as orm + import sqlalchemy.orm import sqlite3 # Get logger @@ -57,14 +57,11 @@ def __init__(self, url, **kwargs): if not os.path.isfile(matches.group(1)): raise RuntimeError("not a file: {}".format(matches.group(1))) - # Record the URL (used in testing) - self.url = url - # Create engine, disabling SQLAlchemy's own autocommit mode, raising exception if back end's module not installed self._engine = sqlalchemy.create_engine(url, **kwargs).execution_options(autocommit=False) # Create a variable to hold the session. If None, autocommit is on. - self.Session = orm.sessionmaker(bind=self._engine) + self._Session = sqlalchemy.orm.session.sessionmaker(bind=self._engine) self._session = None # Listener for connections @@ -101,6 +98,7 @@ def __del__(self): """Close database session and connection.""" if self._session is not None: self._session.close() + self._session = None @_enable_logging def execute(self, sql, *args, **kwargs): @@ -134,11 +132,11 @@ def execute(self, sql, *args, **kwargs): command = token.value.upper() break - # Begin a new transaction session, if done manually + # Begin a new session, if transaction started by caller (not using autocommit) elif token.value.upper() in ["BEGIN", "START"]: if self._session is not None: self._session.close() - self._session = self.Session() + self._session = self._Session() else: command = None @@ -288,7 +286,7 @@ def execute(self, sql, *args, **kwargs): # Connect to database (for transactions' sake) session = self._session if session is None: - session = self.Session() + session = self._Session() # Set up a Flask app teardown function to close session at teardown try: @@ -303,11 +301,12 @@ def execute(self, sql, *args, **kwargs): if not hasattr(self, "teardown_appcontext_added"): self.teardown_appcontext_added = True - # Register shutdown_session on app context teardown @flask.current_app.teardown_appcontext def shutdown_session(exception=None): + """Close any existing session on app context teardown.""" if self._session is not None: self._session.close() + self._session = None except (ModuleNotFoundError, AssertionError): pass @@ -370,11 +369,9 @@ def shutdown_session(exception=None): session.close() self._session = None - # If autocommit is on, commit and close if self._session is None and command not in ["COMMIT", "ROLLBACK"]: - if command not in ["SELECT"]: - session.commit() + session.commit() session.close() # If constraint violated, return None diff --git a/tests/flask/application.py b/tests/flask/application.py index e3f0768..404b1d4 100644 --- a/tests/flask/application.py +++ b/tests/flask/application.py @@ -2,19 +2,21 @@ import os import requests import sys -from flask import Flask, render_template sys.path.insert(0, "../../src") import cs50 import cs50.flask +from flask import Flask, render_template + app = Flask(__name__) logging.disable(logging.CRITICAL) os.environ["WERKZEUG_RUN_MAIN"] = "true" -db = cs50.SQL("sqlite:///../test.db") +db_url = "sqlite:///../test.db" +db = cs50.SQL(db_url) @app.route("/") def index(): @@ -28,7 +30,7 @@ def f(): @app.route("/autocommit") def autocommit(): db.execute("INSERT INTO test (val) VALUES (?)", "def") - db2 = cs50.SQL(db.url) + db2 = cs50.SQL(db_url) ret = db2.execute("SELECT val FROM test WHERE val=?", "def") return str(ret == [{"val": "def"}]) @@ -55,9 +57,9 @@ def insert(): @app.route("/multiple_connections") def multiple_connections(): ctx = len(app.teardown_appcontext_funcs) - db1 = cs50.SQL(db.url) + db1 = cs50.SQL(db_url) td1 = (len(app.teardown_appcontext_funcs) == ctx + 1) - db2 = cs50.SQL(db.url) + db2 = cs50.SQL(db_url) td2 = (len(app.teardown_appcontext_funcs) == ctx + 2) return str(td1 and td2) diff --git a/tests/flask/test.py b/tests/flask/test.py index 9a4134b..0b084d6 100644 --- a/tests/flask/test.py +++ b/tests/flask/test.py @@ -1,4 +1,3 @@ -from application import app import logging import requests import sys @@ -6,6 +5,7 @@ import time import unittest +from application import app def request(route): r = requests.get("http://localhost:5000/{}".format(route)) diff --git a/tests/sql.py b/tests/sql.py index 57974a6..7694ad9 100644 --- a/tests/sql.py +++ b/tests/sql.py @@ -120,7 +120,7 @@ def test_autocommit(self): self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('bar')"), 2) # Load a new database instance to confirm the INSERTs were committed - db2 = SQL(self.db.url) + db2 = SQL(self.db_url) self.assertEqual(db2.execute("DELETE FROM cs50 WHERE id < 3"), 2) def test_commit(self): @@ -129,7 +129,7 @@ def test_commit(self): self.db.execute("COMMIT") # Load a new database instance to confirm the INSERT was committed - db2 = SQL(self.db.url) + db2 = SQL(self.db_url) self.assertEqual(db2.execute("SELECT val FROM cs50"), [{"val": "foo"}]) def test_rollback(self): @@ -167,7 +167,8 @@ def tearDownClass(self): class MySQLTests(SQLTests): @classmethod def setUpClass(self): - self.db = SQL("mysql://root@localhost/test") + self.db_url = "mysql://root@localhost/test" + self.db = SQL(self.db_url) print("\nMySQL tests") def setUp(self): @@ -176,7 +177,8 @@ def setUp(self): class PostgresTests(SQLTests): @classmethod def setUpClass(self): - self.db = SQL("postgresql://root:test@localhost/test") + self.db_url = "postgresql://root:test@localhost/test" + self.db = SQL(self.db_url) print("\nPOSTGRES tests") def setUp(self): @@ -189,7 +191,8 @@ class SQLiteTests(SQLTests): @classmethod def setUpClass(self): open("test.db", "w").close() - self.db = SQL("sqlite:///test.db") + self.db_url = "sqlite:///test.db" + self.db = SQL(self.db_url) print("\nSQLite tests") def setUp(self): From c227d18b5200b4474eedc545b76f0143c1ca1e51 Mon Sep 17 00:00:00 2001 From: Joshua Archibald <jarchibald121@gmail.com> Date: Thu, 4 Jun 2020 14:35:17 -0500 Subject: [PATCH 3/6] Fix sql.py in tests for Travis CI. --- tests/sql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/sql.py b/tests/sql.py index 7694ad9..8742702 100644 --- a/tests/sql.py +++ b/tests/sql.py @@ -177,7 +177,7 @@ def setUp(self): class PostgresTests(SQLTests): @classmethod def setUpClass(self): - self.db_url = "postgresql://root:test@localhost/test" + self.db_url = "postgresql://postgres@localhost/test" self.db = SQL(self.db_url) print("\nPOSTGRES tests") From ee4128311e8c3c6962d76f0b4c718b5eaecc5530 Mon Sep 17 00:00:00 2001 From: Joshua Archibald <jarchibald121@gmail.com> Date: Fri, 5 Jun 2020 17:23:25 -0500 Subject: [PATCH 4/6] Requested changes to code design, including some renaming, a new instance variable to track transaction status, and retaining session between calls to execute, among other things. --- src/cs50/sql.py | 49 ++++++++++++++++++++++++++----------------------- tests/sql.py | 14 +++++++++++++- 2 files changed, 39 insertions(+), 24 deletions(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 8acf194..d8af011 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -63,6 +63,7 @@ def __init__(self, url, **kwargs): # Create a variable to hold the session. If None, autocommit is on. self._Session = sqlalchemy.orm.session.sessionmaker(bind=self._engine) self._session = None + self._in_transaction = False # Listener for connections def connect(dbapi_connection, connection_record): @@ -96,9 +97,7 @@ def connect(dbapi_connection, connection_record): def __del__(self): """Close database session and connection.""" - if self._session is not None: - self._session.close() - self._session = None + self._close_session() @_enable_logging def execute(self, sql, *args, **kwargs): @@ -134,9 +133,9 @@ def execute(self, sql, *args, **kwargs): # Begin a new session, if transaction started by caller (not using autocommit) elif token.value.upper() in ["BEGIN", "START"]: - if self._session is not None: - self._session.close() - self._session = self._Session() + if self._in_transaction: + raise RuntimeError("transaction already open") + self._in_transaction = True else: command = None @@ -284,9 +283,8 @@ def execute(self, sql, *args, **kwargs): statement = "".join([str(token) for token in tokens]) # Connect to database (for transactions' sake) - session = self._session - if session is None: - session = self._Session() + if self._session is None: + self._session = self._Session() # Set up a Flask app teardown function to close session at teardown try: @@ -304,9 +302,7 @@ def execute(self, sql, *args, **kwargs): @flask.current_app.teardown_appcontext def shutdown_session(exception=None): """Close any existing session on app context teardown.""" - if self._session is not None: - self._session.close() - self._session = None + self._close_session() except (ModuleNotFoundError, AssertionError): pass @@ -323,8 +319,14 @@ def shutdown_session(exception=None): # Join tokens into statement, abbreviating binary data as <class 'bytes'> _statement = "".join([str(bytes) if token.ttype == sqlparse.tokens.Other else str(token) for token in tokens]) + # If COMMIT or ROLLBACK, turn on autocommit mode + if command in ["COMMIT", "ROLLBACK"] and "TO" not in statement: + if not self._in_transaction: + raise RuntimeError("transactions must be initiated with BEGIN or START TRANSACTION") + self._in_transaction = False + # Execute statement - result = session.execute(sqlalchemy.text(statement)) + result = self._session.execute(sqlalchemy.text(statement)) # Return value ret = True @@ -353,7 +355,7 @@ def shutdown_session(exception=None): elif command == "INSERT": if self._engine.url.get_backend_name() in ["postgres", "postgresql"]: try: - result = session.execute("SELECT LASTVAL()") + result = self._session.execute("SELECT LASTVAL()") ret = result.first()[0] except sqlalchemy.exc.OperationalError: # If lastval is not yet defined in this session ret = None @@ -364,15 +366,9 @@ def shutdown_session(exception=None): elif command in ["DELETE", "UPDATE"]: ret = result.rowcount - # If COMMIT or ROLLBACK, turn on autocommit mode - elif command in ["COMMIT", "ROLLBACK"] and "TO" not in statement: - session.close() - self._session = None - - # If autocommit is on, commit and close - if self._session is None and command not in ["COMMIT", "ROLLBACK"]: - session.commit() - session.close() + # If autocommit is on, commit + if not self._in_transaction: + self._session.commit() # If constraint violated, return None except sqlalchemy.exc.IntegrityError as e: @@ -393,6 +389,13 @@ def shutdown_session(exception=None): self._logger.debug(termcolor.colored(_statement, "green")) return ret + def _close_session(self): + """Closes any existing session and resets instance variables.""" + if self._session is not None: + self._session.close() + self._session = None + self._in_transaction = False + def _escape(self, value): """ Escapes value using engine's conversion function. diff --git a/tests/sql.py b/tests/sql.py index 8742702..661920e 100644 --- a/tests/sql.py +++ b/tests/sql.py @@ -123,6 +123,12 @@ def test_autocommit(self): db2 = SQL(self.db_url) self.assertEqual(db2.execute("DELETE FROM cs50 WHERE id < 3"), 2) + def test_commit_no_transaction(self): + with self.assertRaises(RuntimeError): + self.db.execute("COMMIT") + with self.assertRaises(RuntimeError): + self.db.execute("ROLLBACK") + def test_commit(self): self.db.execute("BEGIN") self.db.execute("INSERT INTO cs50 (val) VALUES('foo')") @@ -132,6 +138,12 @@ def test_commit(self): db2 = SQL(self.db_url) self.assertEqual(db2.execute("SELECT val FROM cs50"), [{"val": "foo"}]) + def test_double_begin(self): + self.db.execute("BEGIN") + with self.assertRaises(RuntimeError): + self.db.execute("BEGIN") + self.db.execute("ROLLBACK") + def test_rollback(self): self.db.execute("BEGIN") self.db.execute("INSERT INTO cs50 (val) VALUES('foo')") @@ -177,7 +189,7 @@ def setUp(self): class PostgresTests(SQLTests): @classmethod def setUpClass(self): - self.db_url = "postgresql://postgres@localhost/test" + self.db_url = "postgresql://root:test@localhost/test" self.db = SQL(self.db_url) print("\nPOSTGRES tests") From c60d67d908e507c2fedfb3fb44829e07571ca7a6 Mon Sep 17 00:00:00 2001 From: Joshua Archibald <jarchibald121@gmail.com> Date: Fri, 5 Jun 2020 17:24:25 -0500 Subject: [PATCH 5/6] Messed up the tests for Travis CI again. Fixed. --- tests/sql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/sql.py b/tests/sql.py index 661920e..95301eb 100644 --- a/tests/sql.py +++ b/tests/sql.py @@ -189,7 +189,7 @@ def setUp(self): class PostgresTests(SQLTests): @classmethod def setUpClass(self): - self.db_url = "postgresql://root:test@localhost/test" + self.db_url = "postgresql://postgres@localhost/test" self.db = SQL(self.db_url) print("\nPOSTGRES tests") From dac2ae88c7533e5aad4e9671395c591692694bed Mon Sep 17 00:00:00 2001 From: Joshua Archibald <jarchibald121@gmail.com> Date: Wed, 10 Jun 2020 23:32:18 -0500 Subject: [PATCH 6/6] Stylistic changes. --- src/cs50/sql.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index d8af011..f47e2b6 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -135,6 +135,7 @@ def execute(self, sql, *args, **kwargs): elif token.value.upper() in ["BEGIN", "START"]: if self._in_transaction: raise RuntimeError("transaction already open") + self._in_transaction = True else: command = None @@ -296,8 +297,8 @@ def execute(self, sql, *args, **kwargs): assert flask.current_app # Disconnect later - but only once - if not hasattr(self, "teardown_appcontext_added"): - self.teardown_appcontext_added = True + if not hasattr(self, "_teardown_appcontext_added"): + self._teardown_appcontext_added = True @flask.current_app.teardown_appcontext def shutdown_session(exception=None): @@ -320,9 +321,10 @@ def shutdown_session(exception=None): _statement = "".join([str(bytes) if token.ttype == sqlparse.tokens.Other else str(token) for token in tokens]) # If COMMIT or ROLLBACK, turn on autocommit mode - if command in ["COMMIT", "ROLLBACK"] and "TO" not in statement: + if command in ["COMMIT", "ROLLBACK"] and "TO" not in (token.value for token in tokens): if not self._in_transaction: raise RuntimeError("transactions must be initiated with BEGIN or START TRANSACTION") + self._in_transaction = False # Execute statement @@ -393,6 +395,7 @@ def _close_session(self): """Closes any existing session and resets instance variables.""" if self._session is not None: self._session.close() + self._session = None self._in_transaction = False