diff --git a/src/cs50/sql.py b/src/cs50/sql.py index cd8ae88..f47e2b6 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -43,6 +43,7 @@ def __init__(self, url, **kwargs): import os import re import sqlalchemy + import sqlalchemy.orm import sqlite3 # Get logger @@ -59,6 +60,11 @@ def __init__(self, url, **kwargs): # Create engine, disabling SQLAlchemy's own autocommit mode, raising exception if back end's module not installed self._engine = sqlalchemy.create_engine(url, **kwargs).execution_options(autocommit=False) + # Create a variable to hold the session. If None, autocommit is on. + self._Session = sqlalchemy.orm.session.sessionmaker(bind=self._engine) + self._session = None + self._in_transaction = False + # Listener for connections def connect(dbapi_connection, connection_record): @@ -90,9 +96,8 @@ def connect(dbapi_connection, connection_record): self._logger.disabled = disabled def __del__(self): - """Close database connection.""" - if hasattr(self, "_connection"): - self._connection.close() + """Close database session and connection.""" + self._close_session() @_enable_logging def execute(self, sql, *args, **kwargs): @@ -125,6 +130,13 @@ def execute(self, sql, *args, **kwargs): if token.ttype in [sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML]: command = token.value.upper() break + + # Begin a new session, if transaction started by caller (not using autocommit) + elif token.value.upper() in ["BEGIN", "START"]: + if self._in_transaction: + raise RuntimeError("transaction already open") + + self._in_transaction = True else: command = None @@ -272,6 +284,10 @@ def execute(self, sql, *args, **kwargs): statement = "".join([str(token) for token in tokens]) # Connect to database (for transactions' sake) + if self._session is None: + self._session = self._Session() + + # Set up a Flask app teardown function to close session at teardown try: # Infer whether Flask is installed @@ -280,29 +296,17 @@ def execute(self, sql, *args, **kwargs): # Infer whether app is defined assert flask.current_app - # If no connection for app's current request yet - if not hasattr(flask.g, "_connection"): + # Disconnect later - but only once + if not hasattr(self, "_teardown_appcontext_added"): + self._teardown_appcontext_added = True - # Connect now - flask.g._connection = self._engine.connect() - - # Disconnect later @flask.current_app.teardown_appcontext def shutdown_session(exception=None): - if hasattr(flask.g, "_connection"): - flask.g._connection.close() - - # Use this connection - connection = flask.g._connection + """Close any existing session on app context teardown.""" + self._close_session() except (ModuleNotFoundError, AssertionError): - - # If no connection yet - if not hasattr(self, "_connection"): - self._connection = self._engine.connect() - - # Use this connection - connection = self._connection + pass # Catch SQLAlchemy warnings with warnings.catch_warnings(): @@ -316,8 +320,15 @@ def shutdown_session(exception=None): # Join tokens into statement, abbreviating binary data as <class 'bytes'> _statement = "".join([str(bytes) if token.ttype == sqlparse.tokens.Other else str(token) for token in tokens]) + # If COMMIT or ROLLBACK, turn on autocommit mode + if command in ["COMMIT", "ROLLBACK"] and "TO" not in (token.value for token in tokens): + if not self._in_transaction: + raise RuntimeError("transactions must be initiated with BEGIN or START TRANSACTION") + + self._in_transaction = False + # Execute statement - result = connection.execute(sqlalchemy.text(statement)) + result = self._session.execute(sqlalchemy.text(statement)) # Return value ret = True @@ -346,7 +357,7 @@ def shutdown_session(exception=None): elif command == "INSERT": if self._engine.url.get_backend_name() in ["postgres", "postgresql"]: try: - result = connection.execute("SELECT LASTVAL()") + result = self._session.execute("SELECT LASTVAL()") ret = result.first()[0] except sqlalchemy.exc.OperationalError: # If lastval is not yet defined in this session ret = None @@ -357,6 +368,10 @@ def shutdown_session(exception=None): elif command in ["DELETE", "UPDATE"]: ret = result.rowcount + # If autocommit is on, commit + if not self._in_transaction: + self._session.commit() + # If constraint violated, return None except sqlalchemy.exc.IntegrityError as e: self._logger.debug(termcolor.colored(statement, "yellow")) @@ -376,6 +391,14 @@ def shutdown_session(exception=None): self._logger.debug(termcolor.colored(_statement, "green")) return ret + def _close_session(self): + """Closes any existing session and resets instance variables.""" + if self._session is not None: + self._session.close() + + self._session = None + self._in_transaction = False + def _escape(self, value): """ Escapes value using engine's conversion function. diff --git a/tests/flask/application.py b/tests/flask/application.py index 939a8f9..404b1d4 100644 --- a/tests/flask/application.py +++ b/tests/flask/application.py @@ -1,22 +1,76 @@ +import logging +import os import requests import sys -from flask import Flask, render_template sys.path.insert(0, "../../src") import cs50 import cs50.flask +from flask import Flask, render_template + app = Flask(__name__) -db = cs50.SQL("sqlite:///../sqlite.db") +logging.disable(logging.CRITICAL) +os.environ["WERKZEUG_RUN_MAIN"] = "true" + +db_url = "sqlite:///../test.db" +db = cs50.SQL(db_url) @app.route("/") def index(): - db.execute("SELECT 1") """ def f(): res = requests.get("cs50.harvard.edu") f() """ return render_template("index.html") + +@app.route("/autocommit") +def autocommit(): + db.execute("INSERT INTO test (val) VALUES (?)", "def") + db2 = cs50.SQL(db_url) + ret = db2.execute("SELECT val FROM test WHERE val=?", "def") + return str(ret == [{"val": "def"}]) + +@app.route("/create") +def create(): + ret = db.execute("CREATE TABLE test (id INTEGER PRIMARY KEY AUTOINCREMENT, val VARCHAR(16))") + return str(ret) + +@app.route("/delete") +def delete(): + ret = db.execute("DELETE FROM test") + return str(ret > 0) + +@app.route("/drop") +def drop(): + ret = db.execute("DROP TABLE test") + return str(ret) + +@app.route("/insert") +def insert(): + ret = db.execute("INSERT INTO test (val) VALUES (?)", "abc") + return str(ret > 0) + +@app.route("/multiple_connections") +def multiple_connections(): + ctx = len(app.teardown_appcontext_funcs) + db1 = cs50.SQL(db_url) + td1 = (len(app.teardown_appcontext_funcs) == ctx + 1) + db2 = cs50.SQL(db_url) + td2 = (len(app.teardown_appcontext_funcs) == ctx + 2) + return str(td1 and td2) + +@app.route("/select") +def select(): + ret = db.execute("SELECT val FROM test") + return str(ret == [{"val": "abc"}]) + +@app.route("/single_teardown") +def single_teardown(): + db.execute("SELECT * FROM test") + ctx = len(app.teardown_appcontext_funcs) + db.execute("SELECT COUNT(id) FROM test") + return str(ctx == len(app.teardown_appcontext_funcs)) diff --git a/tests/flask/test.py b/tests/flask/test.py new file mode 100644 index 0000000..0b084d6 --- /dev/null +++ b/tests/flask/test.py @@ -0,0 +1,49 @@ +import logging +import requests +import sys +import threading +import time +import unittest + +from application import app + +def request(route): + r = requests.get("http://localhost:5000/{}".format(route)) + return r.text == "True" + +class FlaskTests(unittest.TestCase): + + def test__create(self): + self.assertTrue(request("create")) + + def test_autocommit(self): + self.assertTrue(request("autocommit")) + + def test_delete(self): + self.assertTrue(request("delete")) + + def test_insert(self): + self.assertTrue(request("insert")) + + def test_multiple_connections(self): + self.assertTrue(request("multiple_connections")) + + def test_select(self): + self.assertTrue(request("select")) + + def test_single_teardown(self): + self.assertTrue(request("single_teardown")) + + def test_zdrop(self): + self.assertTrue(request("drop")) + + +if __name__ == "__main__": + t = threading.Thread(target=app.run, daemon=True) + t.start() + + suite = unittest.TestSuite([ + unittest.TestLoader().loadTestsFromTestCase(FlaskTests) + ]) + + sys.exit(not unittest.TextTestRunner(verbosity=2).run(suite).wasSuccessful()) diff --git a/tests/sql.py b/tests/sql.py index 9ad463f..95301eb 100644 --- a/tests/sql.py +++ b/tests/sql.py @@ -115,11 +115,34 @@ def test_blob(self): self.db.execute("INSERT INTO cs50(bin) VALUES(:bin)", bin=row["bin"]) self.assertEqual(self.db.execute("SELECT id, bin FROM cs50"), rows) + def test_autocommit(self): + self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('foo')"), 1) + self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('bar')"), 2) + + # Load a new database instance to confirm the INSERTs were committed + db2 = SQL(self.db_url) + self.assertEqual(db2.execute("DELETE FROM cs50 WHERE id < 3"), 2) + + def test_commit_no_transaction(self): + with self.assertRaises(RuntimeError): + self.db.execute("COMMIT") + with self.assertRaises(RuntimeError): + self.db.execute("ROLLBACK") + def test_commit(self): self.db.execute("BEGIN") self.db.execute("INSERT INTO cs50 (val) VALUES('foo')") self.db.execute("COMMIT") - self.assertEqual(self.db.execute("SELECT val FROM cs50"), [{"val": "foo"}]) + + # Load a new database instance to confirm the INSERT was committed + db2 = SQL(self.db_url) + self.assertEqual(db2.execute("SELECT val FROM cs50"), [{"val": "foo"}]) + + def test_double_begin(self): + self.db.execute("BEGIN") + with self.assertRaises(RuntimeError): + self.db.execute("BEGIN") + self.db.execute("ROLLBACK") def test_rollback(self): self.db.execute("BEGIN") @@ -128,6 +151,17 @@ def test_rollback(self): self.db.execute("ROLLBACK") self.assertEqual(self.db.execute("SELECT val FROM cs50"), []) + def test_savepoint(self): + self.db.execute("BEGIN") + self.db.execute("INSERT INTO cs50 (val) VALUES('foo')") + self.db.execute("SAVEPOINT sp1") + self.db.execute("INSERT INTO cs50 (val) VALUES('bar')") + self.assertEqual(self.db.execute("SELECT val FROM cs50"), [{"val": "foo"}, {"val": "bar"}]) + self.db.execute("ROLLBACK TO sp1") + self.assertEqual(self.db.execute("SELECT val FROM cs50"), [{"val": "foo"}]) + self.db.execute("ROLLBACK") + self.assertEqual(self.db.execute("SELECT val FROM cs50"), []) + def tearDown(self): self.db.execute("DROP TABLE cs50") self.db.execute("DROP TABLE IF EXISTS foo") @@ -145,7 +179,9 @@ def tearDownClass(self): class MySQLTests(SQLTests): @classmethod def setUpClass(self): - self.db = SQL("mysql://root@localhost/test") + self.db_url = "mysql://root@localhost/test" + self.db = SQL(self.db_url) + print("\nMySQL tests") def setUp(self): self.db.execute("CREATE TABLE cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), bin BLOB, PRIMARY KEY (id))") @@ -153,7 +189,9 @@ def setUp(self): class PostgresTests(SQLTests): @classmethod def setUpClass(self): - self.db = SQL("postgresql://postgres@localhost/test") + self.db_url = "postgresql://postgres@localhost/test" + self.db = SQL(self.db_url) + print("\nPOSTGRES tests") def setUp(self): self.db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)") @@ -165,7 +203,9 @@ class SQLiteTests(SQLTests): @classmethod def setUpClass(self): open("test.db", "w").close() - self.db = SQL("sqlite:///test.db") + self.db_url = "sqlite:///test.db" + self.db = SQL(self.db_url) + print("\nSQLite tests") def setUp(self): self.db.execute("CREATE TABLE cs50(id INTEGER PRIMARY KEY, val TEXT, bin BLOB)")