diff --git a/setup.py b/setup.py index 3b32066..d220295 100644 --- a/setup.py +++ b/setup.py @@ -16,5 +16,5 @@ package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="4.0.4" + version="5.0.0" ) diff --git a/src/cs50/__init__.py b/src/cs50/__init__.py index 54706cb..9502254 100644 --- a/src/cs50/__init__.py +++ b/src/cs50/__init__.py @@ -1,28 +1,27 @@ +import logging import os import sys -try: - - # Save student's sys.path - _path = sys.path[:] - # In case student has files that shadow packages - sys.path = [p for p in sys.path if p not in ("", os.getcwd())] +# Disable cs50 logger by default +logging.getLogger("cs50").disabled = True - # Import cs50_* - from .cs50 import get_char, get_float, get_int, get_string +# In case student has files that shadow packages +for p in ("", os.getcwd()): try: - from .cs50 import get_long - except ImportError: + sys.path.remove(p) + except ValueError: pass - # Replace Flask's logger - from . import flask - - # Wrap SQLAlchemy - from .sql import SQL +# Import cs50_* +from .cs50 import get_char, get_float, get_int, get_string +try: + from .cs50 import get_long +except ImportError: + pass -finally: +# Hook into flask importing +from . import flask - # Restore student's sys.path (just in case library raised an exception that caller caught) - sys.path = _path +# Wrap SQLAlchemy +from .sql import SQL diff --git a/src/cs50/flask.py b/src/cs50/flask.py index 7ce48ed..1d59064 100644 --- a/src/cs50/flask.py +++ b/src/cs50/flask.py @@ -1,60 +1,40 @@ -import logging - -from distutils.version import StrictVersion -from pkg_resources import get_distribution - -from .cs50 import _formatException - -# Try to monkey-patch Flask, if installed -try: - - # Only patch >= 1.0 - _version = StrictVersion(get_distribution("flask").version) - assert _version >= StrictVersion("1.0") - - # Reformat logger's exceptions - # http://flask.pocoo.org/docs/1.0/logging/ - # https://docs.python.org/3/library/logging.html#logging.Formatter.formatException - try: - import flask.logging - flask.logging.default_handler.formatter.formatException = lambda exc_info: _formatException(*exc_info) - except Exception: - pass - - # Enable logging when Flask is in use, - # monkey-patching own SQL module, which shouldn't need to know about Flask - logging.getLogger("cs50").disabled = True - try: - import flask - from .sql import SQL - except ImportError: - pass - else: - _execute_before = SQL.execute - def _execute_after(*args, **kwargs): - disabled = logging.getLogger("cs50").disabled - if flask.current_app: - logging.getLogger("cs50").disabled = False - try: - return _execute_before(*args, **kwargs) - finally: - logging.getLogger("cs50").disabled = disabled - SQL.execute = _execute_after - - # When behind CS50 IDE's proxy, ensure that flask.redirect doesn't redirect from HTTPS to HTTP - # https://werkzeug.palletsprojects.com/en/0.15.x/middleware/proxy_fix/#module-werkzeug.middleware.proxy_fix - from os import getenv - if getenv("CS50_IDE_TYPE") == "online": - try: - import flask - from werkzeug.middleware.proxy_fix import ProxyFix - _flask_init_before = flask.Flask.__init__ - def _flask_init_after(self, *args, **kwargs): - _flask_init_before(self, *args, **kwargs) - self.wsgi_app = ProxyFix(self.wsgi_app, x_proto=1) - flask.Flask.__init__ = _flask_init_after - except: - pass - -except Exception: - pass +import os +import pkgutil +import sys + +def _wrap_flask(f): + if f is None: + return + + from distutils.version import StrictVersion + from .cs50 import _formatException + + if f.__version__ < StrictVersion("1.0"): + return + + f.logging.default_handler.formatter.formatException = lambda exc_info: _formatException(*exc_info) + + if os.getenv("CS50_IDE_TYPE") == "online": + from werkzeug.middleware.proxy_fix import ProxyFix + _flask_init_before = f.Flask.__init__ + def _flask_init_after(self, *args, **kwargs): + _flask_init_before(self, *args, **kwargs) + self.wsgi_app = ProxyFix(self.wsgi_app, x_proto=1) + f.Flask.__init__ = _flask_init_after + + +# Flask was imported before cs50 +if "flask" in sys.modules: + _wrap_flask(sys.modules["flask"]) + +# Flask wasn't imported +else: + flask_loader = pkgutil.get_loader('flask') + if flask_loader: + _exec_module_before = flask_loader.exec_module + + def _exec_module_after(*args, **kwargs): + _exec_module_before(*args, **kwargs) + _wrap_flask(sys.modules["flask"]) + + flask_loader.exec_module = _exec_module_after diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 3bc52d6..484f5ad 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -1,3 +1,30 @@ +def _enable_logging(f): + """Enable logging of SQL statements when Flask is in use.""" + + import logging + import functools + + @functools.wraps(f) + def decorator(*args, **kwargs): + + # Infer whether Flask is installed + try: + import flask + except ModuleNotFoundError: + return f(*args, **kwargs) + + # Enable logging + disabled = logging.getLogger("cs50").disabled + if flask.current_app: + logging.getLogger("cs50").disabled = False + try: + return f(*args, **kwargs) + finally: + logging.getLogger("cs50").disabled = disabled + + return decorator + + class SQL(object): """Wrap SQLAlchemy to provide a simple SQL API.""" @@ -29,25 +56,23 @@ def __init__(self, url, **kwargs): if not os.path.isfile(matches.group(1)): raise RuntimeError("not a file: {}".format(matches.group(1))) - # Remember foreign_keys and remove it from kwargs - foreign_keys = kwargs.pop("foreign_keys", False) + # Create engine, disabling SQLAlchemy's own autocommit mode, raising exception if back end's module not installed + self._engine = sqlalchemy.create_engine(url, **kwargs).execution_options(autocommit=False) - # Create engine, raising exception if back end's module not installed - self.engine = sqlalchemy.create_engine(url, **kwargs) + # Listener for connections + def connect(dbapi_connection, connection_record): - # Enable foreign key constraints - if foreign_keys: - def connect(dbapi_connection, connection_record): - if type(dbapi_connection) is sqlite3.Connection: # If back end is sqlite - cursor = dbapi_connection.cursor() - cursor.execute("PRAGMA foreign_keys=ON") - cursor.close() - sqlalchemy.event.listen(self.engine, "connect", connect) + # Disable underlying API's own emitting of BEGIN and COMMIT + dbapi_connection.isolation_level = None - else: + # Enable foreign key constraints + if type(dbapi_connection) is sqlite3.Connection: # If back end is sqlite + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() - # Create engine, raising exception if back end's module not installed - self.engine = sqlalchemy.create_engine(url, **kwargs) + # Register listener + sqlalchemy.event.listen(self._engine, "connect", connect) # Log statements to standard error logging.basicConfig(level=logging.DEBUG) @@ -64,6 +89,12 @@ def connect(dbapi_connection, connection_record): finally: self._logger.disabled = disabled + def __del__(self): + """Close database connection.""" + if hasattr(self, "_connection"): + self._connection.close() + + @_enable_logging def execute(self, sql, *args, **kwargs): """Execute a SQL statement.""" @@ -232,6 +263,38 @@ def execute(self, sql, *args, **kwargs): # Join tokens into statement statement = "".join([str(token) for token in tokens]) + # Connect to database (for transactions' sake) + try: + + # Infer whether Flask is installed + import flask + + # Infer whether app is defined + assert flask.current_app + + # If no connection for app's current request yet + if not hasattr(flask.g, "_connection"): + + # Connect now + flask.g._connection = self._engine.connect() + + # Disconnect later + @flask.current_app.teardown_appcontext + def shutdown_session(exception=None): + flask.g._connection.close() + + # Use this connection + connection = flask.g._connection + + except (ModuleNotFoundError, AssertionError): + + # If no connection yet + if not hasattr(self, "_connection"): + self._connection = self._engine.connect() + + # Use this connection + connection = self._connection + # Catch SQLAlchemy warnings with warnings.catch_warnings(): @@ -241,8 +304,11 @@ def execute(self, sql, *args, **kwargs): # Prepare, execute statement try: + # Join tokens into statement, abbreviating binary data as + _statement = "".join([str(bytes) if token.ttype == sqlparse.tokens.Other else str(token) for token in tokens]) + # Execute statement - result = self.engine.execute(sqlalchemy.text(statement)) + result = connection.execute(sqlalchemy.text(statement)) # Return value ret = True @@ -254,42 +320,55 @@ def execute(self, sql, *args, **kwargs): # If SELECT, return result set as list of dict objects if value == "SELECT": - # Coerce any decimal.Decimal objects to float objects - # https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ + # Coerce types rows = [dict(row) for row in result.fetchall()] for row in rows: for column in row: + + # Coerce decimal.Decimal objects to float objects + # https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ if type(row[column]) is decimal.Decimal: row[column] = float(row[column]) + + # Coerce memoryview objects (as from PostgreSQL's bytea columns) to bytes + elif type(row[column]) is memoryview: + row[column] = bytes(row[column]) + + # Rows to be returned ret = rows - # If INSERT, return primary key value for a newly inserted row + # If INSERT, return primary key value for a newly inserted row (or None if none) elif value == "INSERT": - if self.engine.url.get_backend_name() in ["postgres", "postgresql"]: - result = self.engine.execute("SELECT LASTVAL()") - ret = result.first()[0] + if self._engine.url.get_backend_name() in ["postgres", "postgresql"]: + try: + result = connection.execute("SELECT LASTVAL()") + ret = result.first()[0] + except sqlalchemy.exc.OperationalError: # If lastval is not yet defined in this session + ret = None else: - ret = result.lastrowid + ret = result.lastrowid if result.rowcount == 1 else None # If DELETE or UPDATE, return number of rows matched elif value in ["DELETE", "UPDATE"]: ret = result.rowcount # If constraint violated, return None - except sqlalchemy.exc.IntegrityError: + except sqlalchemy.exc.IntegrityError as e: self._logger.debug(termcolor.colored(statement, "yellow")) - return None + e = RuntimeError(e.orig) + e.__cause__ = None + raise e # If user errror except sqlalchemy.exc.OperationalError as e: self._logger.debug(termcolor.colored(statement, "red")) - e = RuntimeError(_parse_exception(e)) + e = RuntimeError(e.orig) e.__cause__ = None raise e # Return value else: - self._logger.debug(termcolor.colored(statement, "green")) + self._logger.debug(termcolor.colored(_statement, "green")) return ret def _escape(self, value): @@ -312,53 +391,58 @@ def __escape(value): if type(value) is bool: return sqlparse.sql.Token( sqlparse.tokens.Number, - sqlalchemy.types.Boolean().literal_processor(self.engine.dialect)(value)) - - # bytearray, bytes - elif type(value) in [bytearray, bytes]: - raise RuntimeError("unsupported value") # TODO + sqlalchemy.types.Boolean().literal_processor(self._engine.dialect)(value)) + + # bytes + elif type(value) is bytes: + if self._engine.url.get_backend_name() in ["mysql", "sqlite"]: + return sqlparse.sql.Token(sqlparse.tokens.Other, f"x'{value.hex()}'") # https://dev.mysql.com/doc/refman/8.0/en/hexadecimal-literals.html + elif self._engine.url.get_backend_name() == "postgresql": + return sqlparse.sql.Token(sqlparse.tokens.Other, f"'\\x{value.hex()}'") # https://dba.stackexchange.com/a/203359 + else: + raise RuntimeError("unsupported value: {}".format(value)) # datetime.date elif type(value) is datetime.date: return sqlparse.sql.Token( sqlparse.tokens.String, - sqlalchemy.types.String().literal_processor(self.engine.dialect)(value.strftime("%Y-%m-%d"))) + sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%Y-%m-%d"))) # datetime.datetime elif type(value) is datetime.datetime: return sqlparse.sql.Token( sqlparse.tokens.String, - sqlalchemy.types.String().literal_processor(self.engine.dialect)(value.strftime("%Y-%m-%d %H:%M:%S"))) + sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%Y-%m-%d %H:%M:%S"))) # datetime.time elif type(value) is datetime.time: return sqlparse.sql.Token( sqlparse.tokens.String, - sqlalchemy.types.String().literal_processor(self.engine.dialect)(value.strftime("%H:%M:%S"))) + sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%H:%M:%S"))) # float elif type(value) is float: return sqlparse.sql.Token( sqlparse.tokens.Number, - sqlalchemy.types.Float().literal_processor(self.engine.dialect)(value)) + sqlalchemy.types.Float().literal_processor(self._engine.dialect)(value)) # int elif type(value) is int: return sqlparse.sql.Token( sqlparse.tokens.Number, - sqlalchemy.types.Integer().literal_processor(self.engine.dialect)(value)) + sqlalchemy.types.Integer().literal_processor(self._engine.dialect)(value)) # str elif type(value) is str: return sqlparse.sql.Token( sqlparse.tokens.String, - sqlalchemy.types.String().literal_processor(self.engine.dialect)(value)) + sqlalchemy.types.String().literal_processor(self._engine.dialect)(value)) # None elif value is None: return sqlparse.sql.Token( sqlparse.tokens.Keyword, - sqlalchemy.types.NullType().literal_processor(self.engine.dialect)(value)) + sqlalchemy.types.NullType().literal_processor(self._engine.dialect)(value)) # Unsupported value else: @@ -366,11 +450,9 @@ def __escape(value): # Escape value(s), separating with commas as needed if type(value) in [list, tuple]: - return sqlparse.sql.TokenList(sqlparse.parse(", ".join([str(__escape(v)) for v in value]))) + return sqlparse.sql.TokenList([__escape(v) for v in value]) else: - return sqlparse.sql.Token( - sqlparse.tokens.String, - __escape(value)) + return __escape(value) def _parse_exception(e): diff --git a/tests/sql.py b/tests/sql.py index ceee9dc..ab70f21 100644 --- a/tests/sql.py +++ b/tests/sql.py @@ -32,9 +32,9 @@ def test_select_all(self): self.assertEqual(self.db.execute("SELECT * FROM cs50"), []) rows = [ - {"id": 1, "val": "foo"}, - {"id": 2, "val": "bar"}, - {"id": 3, "val": "baz"} + {"id": 1, "val": "foo", "bin": None}, + {"id": 2, "val": "bar", "bin": None}, + {"id": 3, "val": "baz", "bin": None} ] for row in rows: self.db.execute("INSERT INTO cs50(val) VALUES(:val)", val=row["val"]) @@ -61,7 +61,13 @@ def test_select_where(self): for row in rows: self.db.execute("INSERT INTO cs50(val) VALUES(:val)", val=row["val"]) - self.assertEqual(self.db.execute("SELECT * FROM cs50 WHERE id = :id OR val = :val", id=rows[1]["id"], val=rows[2]["val"]), rows[1:3]) + self.assertEqual(self.db.execute("SELECT id, val FROM cs50 WHERE id = :id OR val = :val", id=rows[1]["id"], val=rows[2]["val"]), rows[1:3]) + + def test_select_with_comments(self): + self.assertEqual(self.db.execute("--comment\nSELECT * FROM cs50;\n--comment"), []) + + def test_select_with_semicolon(self): + self.assertEqual(self.db.execute("SELECT * FROM cs50;\n--comment"), []) def test_select_with_comments(self): self.assertEqual(self.db.execute("--comment\nSELECT * FROM cs50;\n--comment"), []) @@ -99,8 +105,33 @@ def test_string_literal_with_colon(self): self.assertEqual(self.db.execute("SELECT val FROM cs50 WHERE val = ':bar :baz'"), [{"val": ":bar :baz"}]) self.assertEqual(self.db.execute("SELECT val FROM cs50 WHERE val = ' :bar :baz'"), [{"val": " :bar :baz"}]) + def test_blob(self): + rows = [ + {"id": 1, "bin": b"\0"}, + {"id": 2, "bin": b"\1"}, + {"id": 3, "bin": b"\2"} + ] + for row in rows: + self.db.execute("INSERT INTO cs50(bin) VALUES(:bin)", bin=row["bin"]) + self.assertEqual(self.db.execute("SELECT id, bin FROM cs50"), rows) + + def test_commit(self): + self.db.execute("BEGIN") + self.db.execute("INSERT INTO cs50 (val) VALUES('foo')") + self.db.execute("COMMIT") + self.assertEqual(self.db.execute("SELECT val FROM cs50"), [{"val": "foo"}]) + + def test_rollback(self): + self.db.execute("BEGIN") + self.db.execute("INSERT INTO cs50 (val) VALUES('foo')") + self.db.execute("INSERT INTO cs50 (val) VALUES('bar')") + self.db.execute("ROLLBACK") + self.assertEqual(self.db.execute("SELECT val FROM cs50"), []) + def tearDown(self): self.db.execute("DROP TABLE cs50") + self.db.execute("DROP TABLE IF EXISTS foo") + self.db.execute("DROP TABLE IF EXISTS bar") @classmethod def tearDownClass(self): @@ -117,7 +148,7 @@ def setUpClass(self): self.db = SQL("mysql://root@localhost/test") def setUp(self): - self.db.execute("CREATE TABLE cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), PRIMARY KEY (id))") + self.db.execute("CREATE TABLE cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), bin BLOB, PRIMARY KEY (id))") class PostgresTests(SQLTests): @classmethod @@ -125,36 +156,34 @@ def setUpClass(self): self.db = SQL("postgresql://postgres@localhost/test") def setUp(self): - self.db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16))") + self.db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)") class SQLiteTests(SQLTests): @classmethod def setUpClass(self): open("test.db", "w").close() self.db = SQL("sqlite:///test.db") - open("test1.db", "w").close() - self.db1 = SQL("sqlite:///test1.db", foreign_keys=True) def setUp(self): - self.db.execute("DROP TABLE IF EXISTS cs50") - self.db.execute("CREATE TABLE cs50(id INTEGER PRIMARY KEY, val TEXT)") + self.db.execute("CREATE TABLE cs50(id INTEGER PRIMARY KEY, val TEXT, bin BLOB)") + + def test_lastrowid(self): + self.db.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY AUTOINCREMENT, firstname TEXT, lastname TEXT)") + self.assertEqual(self.db.execute("INSERT INTO foo (firstname, lastname) VALUES('firstname', 'lastname')"), 1) + self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo (id, firstname, lastname) VALUES(1, 'firstname', 'lastname')") + self.assertEqual(self.db.execute("INSERT OR IGNORE INTO foo (id, firstname, lastname) VALUES(1, 'firstname', 'lastname')"), None) + + def test_integrity_constraints(self): + self.db.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY)") + self.assertEqual(self.db.execute("INSERT INTO foo VALUES(1)"), 1) + self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES(1)") def test_foreign_key_support(self): - self.db.execute("DROP TABLE IF EXISTS foo") self.db.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY)") - self.db.execute("DROP TABLE IF EXISTS bar") self.db.execute("CREATE TABLE bar(foo_id INTEGER, FOREIGN KEY (foo_id) REFERENCES foo(id))") - self.assertEqual(self.db.execute("INSERT INTO bar VALUES(50)"), 1) - - self.db1.execute("DROP TABLE IF EXISTS foo") - self.db1.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY)") - self.db1.execute("DROP TABLE IF EXISTS bar") - self.db1.execute("CREATE TABLE bar(foo_id INTEGER, FOREIGN KEY (foo_id) REFERENCES foo(id))") - self.assertEqual(self.db1.execute("INSERT INTO bar VALUES(50)"), None) - + self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO bar VALUES(50)") def test_qmark(self): - self.db.execute("DROP TABLE IF EXISTS foo") self.db.execute("CREATE TABLE foo (firstname STRING, lastname STRING)") self.db.execute("INSERT INTO foo VALUES (?, 'bar')", "baz") @@ -188,7 +217,6 @@ def test_qmark(self): self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) self.db.execute("DELETE FROM foo") - self.db.execute("DROP TABLE IF EXISTS bar") self.db.execute("CREATE TABLE bar (firstname STRING)") self.db.execute("INSERT INTO bar VALUES (?)", "baz") self.assertEqual(self.db.execute("SELECT * FROM bar"), [{"firstname": "baz"}]) @@ -203,7 +231,6 @@ def test_qmark(self): self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (?, ?)", 'bar', baz='baz') def test_named(self): - self.db.execute("DROP TABLE IF EXISTS foo") self.db.execute("CREATE TABLE foo (firstname STRING, lastname STRING)") self.db.execute("INSERT INTO foo VALUES (:baz, 'bar')", baz="baz") @@ -226,7 +253,6 @@ def test_named(self): self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) self.db.execute("DELETE FROM foo") - self.db.execute("DROP TABLE IF EXISTS bar") self.db.execute("CREATE TABLE bar (firstname STRING)") self.db.execute("INSERT INTO bar VALUES (:baz)", baz="baz") self.assertEqual(self.db.execute("SELECT * FROM bar"), [{"firstname": "baz"}]) @@ -238,7 +264,6 @@ def test_named(self): def test_numeric(self): - self.db.execute("DROP TABLE IF EXISTS foo") self.db.execute("CREATE TABLE foo (firstname STRING, lastname STRING)") self.db.execute("INSERT INTO foo VALUES (:1, 'bar')", "baz") @@ -272,7 +297,6 @@ def test_numeric(self): self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) self.db.execute("DELETE FROM foo") - self.db.execute("DROP TABLE IF EXISTS bar") self.db.execute("CREATE TABLE bar (firstname STRING)") self.db.execute("INSERT INTO bar VALUES (:1)", "baz") self.assertEqual(self.db.execute("SELECT * FROM bar"), [{"firstname": "baz"}])