diff --git a/README.md b/README.md index fb37280..85d5172 100644 --- a/README.md +++ b/README.md @@ -19,3 +19,74 @@ f = cs50.get_float(); i = cs50.get_int(); s = cs50.get_string(); ``` + +## Testing + +1. Run `cli50` in `python-cs50`. +1. Run `sudo su -`. +1. Run `apt install -y libmysqlclient-dev mysql-server postgresql`. +1. Run `pip3 install mysqlclient psycopg2-binary`. +1. In `/etc/mysql/mysql.conf.d/mysqld.cnf`, add `skip-grant-tables` under `[mysqld]`. +1. In `/etc/profile.d/cli.sh`, remove `valgrind` function for now. +1. Run `service mysql start`. +1. Run `mysql -e 'CREATE DATABASE IF NOT EXISTS test;'`. +1. In `/etc/postgresql/10/main/pg_hba.conf, change: + ``` + local all postgres peer + host all all 127.0.0.1/32 md5 + ``` + to: + ``` + local all postgres trust + host all all 127.0.0.1/32 trust + ``` +1. Run `service postgresql start`. +1. Run `psql -c 'create database test;' -U postgres`. +1. Run `touch test.db`. + +### Sample Tests + +``` +import cs50 +db = cs50.SQL("sqlite:///foo.db") +db.execute("CREATE TABLE IF NOT EXISTS cs50 (id INTEGER PRIMARY KEY, val TEXT, bin BLOB)") +db.execute("INSERT INTO cs50 (val) VALUES('a')") +db.execute("INSERT INTO cs50 (val) VALUES('b')") +db.execute("BEGIN") +db.execute("INSERT INTO cs50 (val) VALUES('c')") +db.execute("INSERT INTO cs50 (val) VALUES('x')") +db.execute("INSERT INTO cs50 (val) VALUES('y')") +db.execute("ROLLBACK") +db.execute("INSERT INTO cs50 (val) VALUES('z')") +db.execute("COMMIT") + +--- + +import cs50 +db = cs50.SQL("mysql://root@localhost/test") +db.execute("CREATE TABLE IF NOT EXISTS cs50 (id INTEGER PRIMARY KEY, val TEXT, bin BLOB)") +db.execute("INSERT INTO cs50 (val) VALUES('a')") +db.execute("INSERT INTO cs50 (val) VALUES('b')") +db.execute("BEGIN") +db.execute("INSERT INTO cs50 (val) VALUES('c')") +db.execute("INSERT INTO cs50 (val) VALUES('x')") +db.execute("INSERT INTO cs50 (val) VALUES('y')") +db.execute("ROLLBACK") +db.execute("INSERT INTO cs50 (val) VALUES('z')") +db.execute("COMMIT") + +--- + +import cs50 +db = cs50.SQL("postgresql://postgres@localhost/test") +db.execute("CREATE TABLE IF NOT EXISTS cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)") +db.execute("INSERT INTO cs50 (val) VALUES('a')") +db.execute("INSERT INTO cs50 (val) VALUES('b')") +db.execute("BEGIN") +db.execute("INSERT INTO cs50 (val) VALUES('c')") +db.execute("INSERT INTO cs50 (val) VALUES('x')") +db.execute("INSERT INTO cs50 (val) VALUES('y')") +db.execute("ROLLBACK") +db.execute("INSERT INTO cs50 (val) VALUES('z')") +db.execute("COMMIT") +``` diff --git a/src/cs50/flask.py b/src/cs50/flask.py index 1d59064..538d32a 100644 --- a/src/cs50/flask.py +++ b/src/cs50/flask.py @@ -14,20 +14,21 @@ def _wrap_flask(f): f.logging.default_handler.formatter.formatException = lambda exc_info: _formatException(*exc_info) - if os.getenv("CS50_IDE_TYPE") == "online": + if os.getenv("CS50_IDE_TYPE"): from werkzeug.middleware.proxy_fix import ProxyFix _flask_init_before = f.Flask.__init__ def _flask_init_after(self, *args, **kwargs): _flask_init_before(self, *args, **kwargs) - self.wsgi_app = ProxyFix(self.wsgi_app, x_proto=1) + self.config["TEMPLATES_AUTO_RELOAD"] = True # Automatically reload templates + self.wsgi_app = ProxyFix(self.wsgi_app, x_proto=1) # For HTTPS-to-HTTP proxy f.Flask.__init__ = _flask_init_after -# Flask was imported before cs50 +# If Flask was imported before cs50 if "flask" in sys.modules: _wrap_flask(sys.modules["flask"]) -# Flask wasn't imported +# If Flask wasn't imported else: flask_loader = pkgutil.get_loader('flask') if flask_loader: diff --git a/src/cs50/sql.py b/src/cs50/sql.py index b9675d3..f6da366 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -62,7 +62,8 @@ def __init__(self, url, **kwargs): # Listener for connections def connect(dbapi_connection, connection_record): - # Disable underlying API's own emitting of BEGIN and COMMIT + # Disable underlying API's own emitting of BEGIN and COMMIT so we can ourselves + # https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl dbapi_connection.isolation_level = None # Enable foreign key constraints @@ -71,6 +72,9 @@ def connect(dbapi_connection, connection_record): cursor.execute("PRAGMA foreign_keys=ON") cursor.close() + # Autocommit by default + self._autocommit = True + # Register listener sqlalchemy.event.listen(self._engine, "connect", connect) @@ -90,9 +94,14 @@ def connect(dbapi_connection, connection_record): self._logger.disabled = disabled def __del__(self): + """Disconnect from database.""" + self._disconnect() + + def _disconnect(self): """Close database connection.""" if hasattr(self, "_connection"): self._connection.close() + delattr(self, "_connection") @_enable_logging def execute(self, sql, *args, **kwargs): @@ -107,7 +116,7 @@ def execute(self, sql, *args, **kwargs): import warnings # Parse statement, stripping comments and then leading/trailing whitespace - statements = sqlparse.parse(sqlparse.format(sql, strip_comments=True).strip()) + statements = sqlparse.parse(sqlparse.format(sql, keyword_case="upper", strip_comments=True).strip()) # Allow only one statement at a time, since SQLite doesn't support multiple # https://docs.python.org/3/library/sqlite3.html#sqlite3.Cursor.execute @@ -122,9 +131,10 @@ def execute(self, sql, *args, **kwargs): # Infer command from (unflattened) statement for token in statements[0]: - if token.ttype in [sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML]: - command = token.value.upper() - break + if token.ttype in [sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML]: + if token.value in ["BEGIN", "DELETE", "INSERT", "SELECT", "START", "UPDATE"]: + command = token.value + break else: command = None @@ -271,7 +281,7 @@ def execute(self, sql, *args, **kwargs): # Join tokens into statement statement = "".join([str(token) for token in tokens]) - # Connect to database (for transactions' sake) + # Connect to database try: # Infer whether Flask is installed @@ -280,19 +290,23 @@ def execute(self, sql, *args, **kwargs): # Infer whether app is defined assert flask.current_app - # If no connection for app's current request yet + # If new context if not hasattr(flask.g, "_connection"): - # Connect now - flask.g._connection = self._engine.connect() + # Ready to connect + flask.g._connection = None # Disconnect later @flask.current_app.teardown_appcontext def shutdown_session(exception=None): - if hasattr(flask.g, "_connection"): + if flask.g._connection: flask.g._connection.close() - # Use this connection + # If no connection for context yet + if not flask.g._connection: + flas.g._connection = self._engine.connect() + + # Use context's connection connection = flask.g._connection except (ModuleNotFoundError, AssertionError): @@ -316,8 +330,20 @@ def shutdown_session(exception=None): # Join tokens into statement, abbreviating binary data as _statement = "".join([str(bytes) if token.ttype == sqlparse.tokens.Other else str(token) for token in tokens]) + # Check for start of transaction + if command in ["BEGIN", "START"]: + self._autocommit = False + # Execute statement + if self._autocommit: + connection.execute(sqlalchemy.text("BEGIN")) result = connection.execute(sqlalchemy.text(statement)) + if self._autocommit: + connection.execute(sqlalchemy.text("COMMIT")) + + # Check for end of transaction + if command in ["COMMIT", "ROLLBACK"]: + self._autocommit = True # Return value ret = True @@ -360,12 +386,13 @@ def shutdown_session(exception=None): # If constraint violated, return None except sqlalchemy.exc.IntegrityError as e: self._logger.debug(termcolor.colored(statement, "yellow")) - e = RuntimeError(e.orig) + e = ValueError(e.orig) e.__cause__ = None raise e - # If user errror - except sqlalchemy.exc.OperationalError as e: + # If user error + except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as e: + self._disconnect() self._logger.debug(termcolor.colored(statement, "red")) e = RuntimeError(e.orig) e.__cause__ = None diff --git a/tests/foo.py b/tests/foo.py new file mode 100644 index 0000000..7f32a00 --- /dev/null +++ b/tests/foo.py @@ -0,0 +1,48 @@ +import logging +import sys + +sys.path.insert(0, "../src") + +import cs50 + +""" +db = cs50.SQL("sqlite:///foo.db") + +logging.getLogger("cs50").disabled = False + +#db.execute("SELECT ? FROM ? ORDER BY ?", "a", "tbl", "c") +db.execute("CREATE TABLE IF NOT EXISTS bar (firstname STRING)") + +db.execute("INSERT INTO bar VALUES (?)", "baz") +db.execute("INSERT INTO bar VALUES (?)", "qux") +db.execute("SELECT * FROM bar WHERE firstname IN (?)", ("baz", "qux")) +db.execute("DELETE FROM bar") +""" + +db = cs50.SQL("postgresql://postgres@localhost/test") + +""" +print(db.execute("DROP TABLE IF EXISTS cs50")) +print(db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)")) +print(db.execute("INSERT INTO cs50 (val) VALUES('foo')")) +print(db.execute("SELECT * FROM cs50")) + +print(db.execute("DROP TABLE IF EXISTS cs50")) +print(db.execute("CREATE TABLE cs50 (val VARCHAR(16), bin BYTEA)")) +print(db.execute("INSERT INTO cs50 (val) VALUES('foo')")) +print(db.execute("SELECT * FROM cs50")) +""" + +print(db.execute("DROP TABLE IF EXISTS cs50")) +print(db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)")) +print(db.execute("INSERT INTO cs50 (val) VALUES('foo')")) +print(db.execute("INSERT INTO cs50 (val) VALUES('bar')")) +print(db.execute("INSERT INTO cs50 (val) VALUES('baz')")) +print(db.execute("SELECT * FROM cs50")) +try: + print(db.execute("INSERT INTO cs50 (id, val) VALUES(1, 'bar')")) +except Exception as e: + print(e) + pass +print(db.execute("INSERT INTO cs50 (val) VALUES('qux')")) +#print(db.execute("DELETE FROM cs50")) diff --git a/tests/sql.py b/tests/sql.py index cbad470..f893895 100644 --- a/tests/sql.py +++ b/tests/sql.py @@ -150,7 +150,8 @@ def setUpClass(self): self.db = SQL("mysql://root@localhost/test") def setUp(self): - self.db.execute("CREATE TABLE cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), bin BLOB, PRIMARY KEY (id))") + self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), bin BLOB, PRIMARY KEY (id))") + self.db.execute("DELETE FROM cs50") class PostgresTests(SQLTests): @@ -159,7 +160,8 @@ def setUpClass(self): self.db = SQL("postgresql://postgres@localhost/test") def setUp(self): - self.db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)") + self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)") + self.db.execute("DELETE FROM cs50") def test_cte(self): self.assertEqual(self.db.execute("WITH foo AS ( SELECT 1 AS bar ) SELECT bar FROM foo"), [{"bar": 1}]) @@ -173,23 +175,24 @@ def setUpClass(self): self.db = SQL("sqlite:///test.db") def setUp(self): - self.db.execute("CREATE TABLE cs50(id INTEGER PRIMARY KEY, val TEXT, bin BLOB)") + self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id INTEGER PRIMARY KEY, val TEXT, bin BLOB)") + self.db.execute("DELETE FROM cs50") def test_lastrowid(self): self.db.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY AUTOINCREMENT, firstname TEXT, lastname TEXT)") self.assertEqual(self.db.execute("INSERT INTO foo (firstname, lastname) VALUES('firstname', 'lastname')"), 1) - self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo (id, firstname, lastname) VALUES(1, 'firstname', 'lastname')") + self.assertRaises(ValueError, self.db.execute, "INSERT INTO foo (id, firstname, lastname) VALUES(1, 'firstname', 'lastname')") self.assertEqual(self.db.execute("INSERT OR IGNORE INTO foo (id, firstname, lastname) VALUES(1, 'firstname', 'lastname')"), None) def test_integrity_constraints(self): self.db.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY)") self.assertEqual(self.db.execute("INSERT INTO foo VALUES(1)"), 1) - self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES(1)") + self.assertRaises(ValueError, self.db.execute, "INSERT INTO foo VALUES(1)") def test_foreign_key_support(self): self.db.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY)") self.db.execute("CREATE TABLE bar(foo_id INTEGER, FOREIGN KEY (foo_id) REFERENCES foo(id))") - self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO bar VALUES(50)") + self.assertRaises(ValueError, self.db.execute, "INSERT INTO bar VALUES(50)") def test_qmark(self): self.db.execute("CREATE TABLE foo (firstname STRING, lastname STRING)")