Skip to content

Fixes support for transactions for SQLite, PostgreSQL, and MySQL #133

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Nov 23, 2020
71 changes: 71 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,74 @@ f = cs50.get_float();
i = cs50.get_int();
s = cs50.get_string();
```

## Testing

1. Run `cli50` in `python-cs50`.
1. Run `sudo su -`.
1. Run `apt install -y libmysqlclient-dev mysql-server postgresql`.
1. Run `pip3 install mysqlclient psycopg2-binary`.
1. In `/etc/mysql/mysql.conf.d/mysqld.cnf`, add `skip-grant-tables` under `[mysqld]`.
1. In `/etc/profile.d/cli.sh`, remove `valgrind` function for now.
1. Run `service mysql start`.
1. Run `mysql -e 'CREATE DATABASE IF NOT EXISTS test;'`.
1. In `/etc/postgresql/10/main/pg_hba.conf, change:
```
local all postgres peer
host all all 127.0.0.1/32 md5
```
to:
```
local all postgres trust
host all all 127.0.0.1/32 trust
```
1. Run `service postgresql start`.
1. Run `psql -c 'create database test;' -U postgres`.
1. Run `touch test.db`.

### Sample Tests

```
import cs50
db = cs50.SQL("sqlite:///foo.db")
db.execute("CREATE TABLE IF NOT EXISTS cs50 (id INTEGER PRIMARY KEY, val TEXT, bin BLOB)")
db.execute("INSERT INTO cs50 (val) VALUES('a')")
db.execute("INSERT INTO cs50 (val) VALUES('b')")
db.execute("BEGIN")
db.execute("INSERT INTO cs50 (val) VALUES('c')")
db.execute("INSERT INTO cs50 (val) VALUES('x')")
db.execute("INSERT INTO cs50 (val) VALUES('y')")
db.execute("ROLLBACK")
db.execute("INSERT INTO cs50 (val) VALUES('z')")
db.execute("COMMIT")

---

import cs50
db = cs50.SQL("mysql://root@localhost/test")
db.execute("CREATE TABLE IF NOT EXISTS cs50 (id INTEGER PRIMARY KEY, val TEXT, bin BLOB)")
db.execute("INSERT INTO cs50 (val) VALUES('a')")
db.execute("INSERT INTO cs50 (val) VALUES('b')")
db.execute("BEGIN")
db.execute("INSERT INTO cs50 (val) VALUES('c')")
db.execute("INSERT INTO cs50 (val) VALUES('x')")
db.execute("INSERT INTO cs50 (val) VALUES('y')")
db.execute("ROLLBACK")
db.execute("INSERT INTO cs50 (val) VALUES('z')")
db.execute("COMMIT")

---

import cs50
db = cs50.SQL("postgresql://postgres@localhost/test")
db.execute("CREATE TABLE IF NOT EXISTS cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)")
db.execute("INSERT INTO cs50 (val) VALUES('a')")
db.execute("INSERT INTO cs50 (val) VALUES('b')")
db.execute("BEGIN")
db.execute("INSERT INTO cs50 (val) VALUES('c')")
db.execute("INSERT INTO cs50 (val) VALUES('x')")
db.execute("INSERT INTO cs50 (val) VALUES('y')")
db.execute("ROLLBACK")
db.execute("INSERT INTO cs50 (val) VALUES('z')")
db.execute("COMMIT")
```
9 changes: 5 additions & 4 deletions src/cs50/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,21 @@ def _wrap_flask(f):

f.logging.default_handler.formatter.formatException = lambda exc_info: _formatException(*exc_info)

if os.getenv("CS50_IDE_TYPE") == "online":
if os.getenv("CS50_IDE_TYPE"):
from werkzeug.middleware.proxy_fix import ProxyFix
_flask_init_before = f.Flask.__init__
def _flask_init_after(self, *args, **kwargs):
_flask_init_before(self, *args, **kwargs)
self.wsgi_app = ProxyFix(self.wsgi_app, x_proto=1)
self.config["TEMPLATES_AUTO_RELOAD"] = True # Automatically reload templates
self.wsgi_app = ProxyFix(self.wsgi_app, x_proto=1) # For HTTPS-to-HTTP proxy
f.Flask.__init__ = _flask_init_after


# Flask was imported before cs50
# If Flask was imported before cs50
if "flask" in sys.modules:
_wrap_flask(sys.modules["flask"])

# Flask wasn't imported
# If Flask wasn't imported
else:
flask_loader = pkgutil.get_loader('flask')
if flask_loader:
Expand Down
55 changes: 41 additions & 14 deletions src/cs50/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ def __init__(self, url, **kwargs):
# Listener for connections
def connect(dbapi_connection, connection_record):

# Disable underlying API's own emitting of BEGIN and COMMIT
# Disable underlying API's own emitting of BEGIN and COMMIT so we can ourselves
# https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl
dbapi_connection.isolation_level = None

# Enable foreign key constraints
Expand All @@ -71,6 +72,9 @@ def connect(dbapi_connection, connection_record):
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()

# Autocommit by default
self._autocommit = True

# Register listener
sqlalchemy.event.listen(self._engine, "connect", connect)

Expand All @@ -90,9 +94,14 @@ def connect(dbapi_connection, connection_record):
self._logger.disabled = disabled

def __del__(self):
"""Disconnect from database."""
self._disconnect()

def _disconnect(self):
"""Close database connection."""
if hasattr(self, "_connection"):
self._connection.close()
delattr(self, "_connection")

@_enable_logging
def execute(self, sql, *args, **kwargs):
Expand All @@ -107,7 +116,7 @@ def execute(self, sql, *args, **kwargs):
import warnings

# Parse statement, stripping comments and then leading/trailing whitespace
statements = sqlparse.parse(sqlparse.format(sql, strip_comments=True).strip())
statements = sqlparse.parse(sqlparse.format(sql, keyword_case="upper", strip_comments=True).strip())

# Allow only one statement at a time, since SQLite doesn't support multiple
# https://docs.python.org/3/library/sqlite3.html#sqlite3.Cursor.execute
Expand All @@ -122,9 +131,10 @@ def execute(self, sql, *args, **kwargs):

# Infer command from (unflattened) statement
for token in statements[0]:
if token.ttype in [sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML]:
command = token.value.upper()
break
if token.ttype in [sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML]:
if token.value in ["BEGIN", "DELETE", "INSERT", "SELECT", "START", "UPDATE"]:
command = token.value
break
else:
command = None

Expand Down Expand Up @@ -271,7 +281,7 @@ def execute(self, sql, *args, **kwargs):
# Join tokens into statement
statement = "".join([str(token) for token in tokens])

# Connect to database (for transactions' sake)
# Connect to database
try:

# Infer whether Flask is installed
Expand All @@ -280,19 +290,23 @@ def execute(self, sql, *args, **kwargs):
# Infer whether app is defined
assert flask.current_app

# If no connection for app's current request yet
# If new context
if not hasattr(flask.g, "_connection"):

# Connect now
flask.g._connection = self._engine.connect()
# Ready to connect
flask.g._connection = None

# Disconnect later
@flask.current_app.teardown_appcontext
def shutdown_session(exception=None):
if hasattr(flask.g, "_connection"):
if flask.g._connection:
flask.g._connection.close()

# Use this connection
# If no connection for context yet
if not flask.g._connection:
flas.g._connection = self._engine.connect()

# Use context's connection
connection = flask.g._connection

except (ModuleNotFoundError, AssertionError):
Expand All @@ -316,8 +330,20 @@ def shutdown_session(exception=None):
# Join tokens into statement, abbreviating binary data as <class 'bytes'>
_statement = "".join([str(bytes) if token.ttype == sqlparse.tokens.Other else str(token) for token in tokens])

# Check for start of transaction
if command in ["BEGIN", "START"]:
self._autocommit = False

# Execute statement
if self._autocommit:
connection.execute(sqlalchemy.text("BEGIN"))
result = connection.execute(sqlalchemy.text(statement))
if self._autocommit:
connection.execute(sqlalchemy.text("COMMIT"))

# Check for end of transaction
if command in ["COMMIT", "ROLLBACK"]:
self._autocommit = True

# Return value
ret = True
Expand Down Expand Up @@ -360,12 +386,13 @@ def shutdown_session(exception=None):
# If constraint violated, return None
except sqlalchemy.exc.IntegrityError as e:
self._logger.debug(termcolor.colored(statement, "yellow"))
e = RuntimeError(e.orig)
e = ValueError(e.orig)
e.__cause__ = None
raise e

# If user errror
except sqlalchemy.exc.OperationalError as e:
# If user error
except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as e:
self._disconnect()
self._logger.debug(termcolor.colored(statement, "red"))
e = RuntimeError(e.orig)
e.__cause__ = None
Expand Down
48 changes: 48 additions & 0 deletions tests/foo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import logging
import sys

sys.path.insert(0, "../src")

import cs50

"""
db = cs50.SQL("sqlite:///foo.db")

logging.getLogger("cs50").disabled = False

#db.execute("SELECT ? FROM ? ORDER BY ?", "a", "tbl", "c")
db.execute("CREATE TABLE IF NOT EXISTS bar (firstname STRING)")

db.execute("INSERT INTO bar VALUES (?)", "baz")
db.execute("INSERT INTO bar VALUES (?)", "qux")
db.execute("SELECT * FROM bar WHERE firstname IN (?)", ("baz", "qux"))
db.execute("DELETE FROM bar")
"""

db = cs50.SQL("postgresql://postgres@localhost/test")

"""
print(db.execute("DROP TABLE IF EXISTS cs50"))
print(db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)"))
print(db.execute("INSERT INTO cs50 (val) VALUES('foo')"))
print(db.execute("SELECT * FROM cs50"))

print(db.execute("DROP TABLE IF EXISTS cs50"))
print(db.execute("CREATE TABLE cs50 (val VARCHAR(16), bin BYTEA)"))
print(db.execute("INSERT INTO cs50 (val) VALUES('foo')"))
print(db.execute("SELECT * FROM cs50"))
"""

print(db.execute("DROP TABLE IF EXISTS cs50"))
print(db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)"))
print(db.execute("INSERT INTO cs50 (val) VALUES('foo')"))
print(db.execute("INSERT INTO cs50 (val) VALUES('bar')"))
print(db.execute("INSERT INTO cs50 (val) VALUES('baz')"))
print(db.execute("SELECT * FROM cs50"))
try:
print(db.execute("INSERT INTO cs50 (id, val) VALUES(1, 'bar')"))
except Exception as e:
print(e)
pass
print(db.execute("INSERT INTO cs50 (val) VALUES('qux')"))
#print(db.execute("DELETE FROM cs50"))
15 changes: 9 additions & 6 deletions tests/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,8 @@ def setUpClass(self):
self.db = SQL("mysql://root@localhost/test")

def setUp(self):
self.db.execute("CREATE TABLE cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), bin BLOB, PRIMARY KEY (id))")
self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), bin BLOB, PRIMARY KEY (id))")
self.db.execute("DELETE FROM cs50")


class PostgresTests(SQLTests):
Expand All @@ -159,7 +160,8 @@ def setUpClass(self):
self.db = SQL("postgresql://postgres@localhost/test")

def setUp(self):
self.db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)")
self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)")
self.db.execute("DELETE FROM cs50")

def test_cte(self):
self.assertEqual(self.db.execute("WITH foo AS ( SELECT 1 AS bar ) SELECT bar FROM foo"), [{"bar": 1}])
Expand All @@ -173,23 +175,24 @@ def setUpClass(self):
self.db = SQL("sqlite:///test.db")

def setUp(self):
self.db.execute("CREATE TABLE cs50(id INTEGER PRIMARY KEY, val TEXT, bin BLOB)")
self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id INTEGER PRIMARY KEY, val TEXT, bin BLOB)")
self.db.execute("DELETE FROM cs50")

def test_lastrowid(self):
self.db.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY AUTOINCREMENT, firstname TEXT, lastname TEXT)")
self.assertEqual(self.db.execute("INSERT INTO foo (firstname, lastname) VALUES('firstname', 'lastname')"), 1)
self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo (id, firstname, lastname) VALUES(1, 'firstname', 'lastname')")
self.assertRaises(ValueError, self.db.execute, "INSERT INTO foo (id, firstname, lastname) VALUES(1, 'firstname', 'lastname')")
self.assertEqual(self.db.execute("INSERT OR IGNORE INTO foo (id, firstname, lastname) VALUES(1, 'firstname', 'lastname')"), None)

def test_integrity_constraints(self):
self.db.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY)")
self.assertEqual(self.db.execute("INSERT INTO foo VALUES(1)"), 1)
self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES(1)")
self.assertRaises(ValueError, self.db.execute, "INSERT INTO foo VALUES(1)")

def test_foreign_key_support(self):
self.db.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY)")
self.db.execute("CREATE TABLE bar(foo_id INTEGER, FOREIGN KEY (foo_id) REFERENCES foo(id))")
self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO bar VALUES(50)")
self.assertRaises(ValueError, self.db.execute, "INSERT INTO bar VALUES(50)")

def test_qmark(self):
self.db.execute("CREATE TABLE foo (firstname STRING, lastname STRING)")
Expand Down