Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transaction sessions; appcontext teardown fix; multiple databases in request fix. #122

Merged
merged 6 commits into from
Jun 12, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 46 additions & 23 deletions src/cs50/sql.py
Original file line number Diff line number Diff line change
@@ -43,6 +43,7 @@ def __init__(self, url, **kwargs):
import os
import re
import sqlalchemy
import sqlalchemy.orm
import sqlite3

# Get logger
@@ -59,6 +60,11 @@ def __init__(self, url, **kwargs):
# Create engine, disabling SQLAlchemy's own autocommit mode, raising exception if back end's module not installed
self._engine = sqlalchemy.create_engine(url, **kwargs).execution_options(autocommit=False)

# Create a variable to hold the session. If None, autocommit is on.
self._Session = sqlalchemy.orm.session.sessionmaker(bind=self._engine)
self._session = None
self._in_transaction = False

# Listener for connections
def connect(dbapi_connection, connection_record):

@@ -90,9 +96,8 @@ def connect(dbapi_connection, connection_record):
self._logger.disabled = disabled

def __del__(self):
"""Close database connection."""
if hasattr(self, "_connection"):
self._connection.close()
"""Close database session and connection."""
self._close_session()

@_enable_logging
def execute(self, sql, *args, **kwargs):
@@ -125,6 +130,13 @@ def execute(self, sql, *args, **kwargs):
if token.ttype in [sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML]:
command = token.value.upper()
break

# Begin a new session, if transaction started by caller (not using autocommit)
elif token.value.upper() in ["BEGIN", "START"]:
if self._in_transaction:
raise RuntimeError("transaction already open")

self._in_transaction = True
else:
command = None

@@ -272,6 +284,10 @@ def execute(self, sql, *args, **kwargs):
statement = "".join([str(token) for token in tokens])

# Connect to database (for transactions' sake)
if self._session is None:
self._session = self._Session()

# Set up a Flask app teardown function to close session at teardown
try:

# Infer whether Flask is installed
@@ -280,29 +296,17 @@ def execute(self, sql, *args, **kwargs):
# Infer whether app is defined
assert flask.current_app

# If no connection for app's current request yet
if not hasattr(flask.g, "_connection"):
# Disconnect later - but only once
if not hasattr(self, "_teardown_appcontext_added"):
self._teardown_appcontext_added = True

# Connect now
flask.g._connection = self._engine.connect()

# Disconnect later
@flask.current_app.teardown_appcontext
def shutdown_session(exception=None):
if hasattr(flask.g, "_connection"):
flask.g._connection.close()

# Use this connection
connection = flask.g._connection
"""Close any existing session on app context teardown."""
self._close_session()

except (ModuleNotFoundError, AssertionError):

# If no connection yet
if not hasattr(self, "_connection"):
self._connection = self._engine.connect()

# Use this connection
connection = self._connection
pass

# Catch SQLAlchemy warnings
with warnings.catch_warnings():
@@ -316,8 +320,15 @@ def shutdown_session(exception=None):
# Join tokens into statement, abbreviating binary data as <class 'bytes'>
_statement = "".join([str(bytes) if token.ttype == sqlparse.tokens.Other else str(token) for token in tokens])

# If COMMIT or ROLLBACK, turn on autocommit mode
if command in ["COMMIT", "ROLLBACK"] and "TO" not in (token.value for token in tokens):
if not self._in_transaction:
raise RuntimeError("transactions must be initiated with BEGIN or START TRANSACTION")

self._in_transaction = False

# Execute statement
result = connection.execute(sqlalchemy.text(statement))
result = self._session.execute(sqlalchemy.text(statement))

# Return value
ret = True
@@ -346,7 +357,7 @@ def shutdown_session(exception=None):
elif command == "INSERT":
if self._engine.url.get_backend_name() in ["postgres", "postgresql"]:
try:
result = connection.execute("SELECT LASTVAL()")
result = self._session.execute("SELECT LASTVAL()")
ret = result.first()[0]
except sqlalchemy.exc.OperationalError: # If lastval is not yet defined in this session
ret = None
@@ -357,6 +368,10 @@ def shutdown_session(exception=None):
elif command in ["DELETE", "UPDATE"]:
ret = result.rowcount

# If autocommit is on, commit
if not self._in_transaction:
self._session.commit()

# If constraint violated, return None
except sqlalchemy.exc.IntegrityError as e:
self._logger.debug(termcolor.colored(statement, "yellow"))
@@ -376,6 +391,14 @@ def shutdown_session(exception=None):
self._logger.debug(termcolor.colored(_statement, "green"))
return ret

def _close_session(self):
"""Closes any existing session and resets instance variables."""
if self._session is not None:
self._session.close()

self._session = None
self._in_transaction = False

def _escape(self, value):
"""
Escapes value using engine's conversion function.
60 changes: 57 additions & 3 deletions tests/flask/application.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,76 @@
import logging
import os
import requests
import sys
from flask import Flask, render_template

sys.path.insert(0, "../../src")

import cs50
import cs50.flask

from flask import Flask, render_template

app = Flask(__name__)

db = cs50.SQL("sqlite:///../sqlite.db")
logging.disable(logging.CRITICAL)
os.environ["WERKZEUG_RUN_MAIN"] = "true"

db_url = "sqlite:///../test.db"
db = cs50.SQL(db_url)

@app.route("/")
def index():
db.execute("SELECT 1")
"""
def f():
res = requests.get("cs50.harvard.edu")
f()
"""
return render_template("index.html")

@app.route("/autocommit")
def autocommit():
db.execute("INSERT INTO test (val) VALUES (?)", "def")
db2 = cs50.SQL(db_url)
ret = db2.execute("SELECT val FROM test WHERE val=?", "def")
return str(ret == [{"val": "def"}])

@app.route("/create")
def create():
ret = db.execute("CREATE TABLE test (id INTEGER PRIMARY KEY AUTOINCREMENT, val VARCHAR(16))")
return str(ret)

@app.route("/delete")
def delete():
ret = db.execute("DELETE FROM test")
return str(ret > 0)

@app.route("/drop")
def drop():
ret = db.execute("DROP TABLE test")
return str(ret)

@app.route("/insert")
def insert():
ret = db.execute("INSERT INTO test (val) VALUES (?)", "abc")
return str(ret > 0)

@app.route("/multiple_connections")
def multiple_connections():
ctx = len(app.teardown_appcontext_funcs)
db1 = cs50.SQL(db_url)
td1 = (len(app.teardown_appcontext_funcs) == ctx + 1)
db2 = cs50.SQL(db_url)
td2 = (len(app.teardown_appcontext_funcs) == ctx + 2)
return str(td1 and td2)

@app.route("/select")
def select():
ret = db.execute("SELECT val FROM test")
return str(ret == [{"val": "abc"}])

@app.route("/single_teardown")
def single_teardown():
db.execute("SELECT * FROM test")
ctx = len(app.teardown_appcontext_funcs)
db.execute("SELECT COUNT(id) FROM test")
return str(ctx == len(app.teardown_appcontext_funcs))
49 changes: 49 additions & 0 deletions tests/flask/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import logging
import requests
import sys
import threading
import time
import unittest

from application import app

def request(route):
r = requests.get("http://localhost:5000/{}".format(route))
return r.text == "True"

class FlaskTests(unittest.TestCase):

def test__create(self):
self.assertTrue(request("create"))

def test_autocommit(self):
self.assertTrue(request("autocommit"))

def test_delete(self):
self.assertTrue(request("delete"))

def test_insert(self):
self.assertTrue(request("insert"))

def test_multiple_connections(self):
self.assertTrue(request("multiple_connections"))

def test_select(self):
self.assertTrue(request("select"))

def test_single_teardown(self):
self.assertTrue(request("single_teardown"))

def test_zdrop(self):
self.assertTrue(request("drop"))


if __name__ == "__main__":
t = threading.Thread(target=app.run, daemon=True)
t.start()

suite = unittest.TestSuite([
unittest.TestLoader().loadTestsFromTestCase(FlaskTests)
])

sys.exit(not unittest.TextTestRunner(verbosity=2).run(suite).wasSuccessful())
48 changes: 44 additions & 4 deletions tests/sql.py
Original file line number Diff line number Diff line change
@@ -115,11 +115,34 @@ def test_blob(self):
self.db.execute("INSERT INTO cs50(bin) VALUES(:bin)", bin=row["bin"])
self.assertEqual(self.db.execute("SELECT id, bin FROM cs50"), rows)

def test_autocommit(self):
self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('foo')"), 1)
self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('bar')"), 2)

# Load a new database instance to confirm the INSERTs were committed
db2 = SQL(self.db_url)
self.assertEqual(db2.execute("DELETE FROM cs50 WHERE id < 3"), 2)

def test_commit_no_transaction(self):
with self.assertRaises(RuntimeError):
self.db.execute("COMMIT")
with self.assertRaises(RuntimeError):
self.db.execute("ROLLBACK")

def test_commit(self):
self.db.execute("BEGIN")
self.db.execute("INSERT INTO cs50 (val) VALUES('foo')")
self.db.execute("COMMIT")
self.assertEqual(self.db.execute("SELECT val FROM cs50"), [{"val": "foo"}])

# Load a new database instance to confirm the INSERT was committed
db2 = SQL(self.db_url)
self.assertEqual(db2.execute("SELECT val FROM cs50"), [{"val": "foo"}])

def test_double_begin(self):
self.db.execute("BEGIN")
with self.assertRaises(RuntimeError):
self.db.execute("BEGIN")
self.db.execute("ROLLBACK")

def test_rollback(self):
self.db.execute("BEGIN")
@@ -128,6 +151,17 @@ def test_rollback(self):
self.db.execute("ROLLBACK")
self.assertEqual(self.db.execute("SELECT val FROM cs50"), [])

def test_savepoint(self):
self.db.execute("BEGIN")
self.db.execute("INSERT INTO cs50 (val) VALUES('foo')")
self.db.execute("SAVEPOINT sp1")
self.db.execute("INSERT INTO cs50 (val) VALUES('bar')")
self.assertEqual(self.db.execute("SELECT val FROM cs50"), [{"val": "foo"}, {"val": "bar"}])
self.db.execute("ROLLBACK TO sp1")
self.assertEqual(self.db.execute("SELECT val FROM cs50"), [{"val": "foo"}])
self.db.execute("ROLLBACK")
self.assertEqual(self.db.execute("SELECT val FROM cs50"), [])

def tearDown(self):
self.db.execute("DROP TABLE cs50")
self.db.execute("DROP TABLE IF EXISTS foo")
@@ -145,15 +179,19 @@ def tearDownClass(self):
class MySQLTests(SQLTests):
@classmethod
def setUpClass(self):
self.db = SQL("mysql://root@localhost/test")
self.db_url = "mysql://root@localhost/test"
self.db = SQL(self.db_url)
print("\nMySQL tests")

def setUp(self):
self.db.execute("CREATE TABLE cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), bin BLOB, PRIMARY KEY (id))")

class PostgresTests(SQLTests):
@classmethod
def setUpClass(self):
self.db = SQL("postgresql://postgres@localhost/test")
self.db_url = "postgresql://postgres@localhost/test"
self.db = SQL(self.db_url)
print("\nPOSTGRES tests")

def setUp(self):
self.db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)")
@@ -165,7 +203,9 @@ class SQLiteTests(SQLTests):
@classmethod
def setUpClass(self):
open("test.db", "w").close()
self.db = SQL("sqlite:///test.db")
self.db_url = "sqlite:///test.db"
self.db = SQL(self.db_url)
print("\nSQLite tests")

def setUp(self):
self.db.execute("CREATE TABLE cs50(id INTEGER PRIMARY KEY, val TEXT, bin BLOB)")