From c51a9df52c8c50c51b9ba34cb8edf91aa227e194 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Fri, 12 Jun 2020 10:11:40 -0400 Subject: [PATCH 001/159] style tweaks --- src/cs50/sql.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index f47e2b6..3455fcb 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -131,11 +131,10 @@ def execute(self, sql, *args, **kwargs): command = token.value.upper() break - # Begin a new session, if transaction started by caller (not using autocommit) + # Begin a new session, if transaction opened by caller (not using autocommit) elif token.value.upper() in ["BEGIN", "START"]: if self._in_transaction: raise RuntimeError("transaction already open") - self._in_transaction = True else: command = None @@ -323,8 +322,7 @@ def shutdown_session(exception=None): # If COMMIT or ROLLBACK, turn on autocommit mode if command in ["COMMIT", "ROLLBACK"] and "TO" not in (token.value for token in tokens): if not self._in_transaction: - raise RuntimeError("transactions must be initiated with BEGIN or START TRANSACTION") - + raise RuntimeError("transactions must be opened with BEGIN or START TRANSACTION") self._in_transaction = False # Execute statement @@ -395,7 +393,6 @@ def _close_session(self): """Closes any existing session and resets instance variables.""" if self._session is not None: self._session.close() - self._session = None self._in_transaction = False From 4924fa10ad7e4a419be6f7818252f00b60835bda Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Fri, 12 Jun 2020 10:13:03 -0400 Subject: [PATCH 002/159] version++ Mostly bug fixes, but arguably adds support for multiple connections too --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 87a44b0..ae4b30a 100644 --- a/setup.py +++ b/setup.py @@ -16,5 +16,5 @@ package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="5.0.4" + version="5.1.0" ) From 74c74ccb9c4b264c3fa2f21613430b816657c468 Mon Sep 17 00:00:00 2001 From: Chris Bradfield <cb@scribe.net> Date: Thu, 27 Aug 2020 21:03:37 -0700 Subject: [PATCH 003/159] Update README.md Remove non-Pythonic semicolons. --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index fb37280..5df22a2 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ import cs50 ... -f = cs50.get_float(); -i = cs50.get_int(); -s = cs50.get_string(); +f = cs50.get_float() +i = cs50.get_int() +s = cs50.get_string() ``` From e7c0df8ad81ad8c1e8e543c72ebc8fbf3b3a5f1f Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Sat, 21 Nov 2020 21:31:08 -0500 Subject: [PATCH 004/159] added tests for IN --- tests/sql.py | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/tests/sql.py b/tests/sql.py index 95301eb..106b4a1 100644 --- a/tests/sql.py +++ b/tests/sql.py @@ -138,11 +138,13 @@ def test_commit(self): db2 = SQL(self.db_url) self.assertEqual(db2.execute("SELECT val FROM cs50"), [{"val": "foo"}]) + """ def test_double_begin(self): self.db.execute("BEGIN") with self.assertRaises(RuntimeError): self.db.execute("BEGIN") self.db.execute("ROLLBACK") + """ def test_rollback(self): self.db.execute("BEGIN") @@ -176,6 +178,7 @@ def tearDownClass(self): if not str(e).startswith("(1051"): raise e + class MySQLTests(SQLTests): @classmethod def setUpClass(self): @@ -186,6 +189,7 @@ def setUpClass(self): def setUp(self): self.db.execute("CREATE TABLE cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), bin BLOB, PRIMARY KEY (id))") + class PostgresTests(SQLTests): @classmethod def setUpClass(self): @@ -199,7 +203,9 @@ def setUp(self): def test_cte(self): self.assertEqual(self.db.execute("WITH foo AS ( SELECT 1 AS bar ) SELECT bar FROM foo"), [{"bar": 1}]) + class SQLiteTests(SQLTests): + @classmethod def setUpClass(self): open("test.db", "w").close() @@ -251,23 +257,39 @@ def test_qmark(self): self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) self.db.execute("DELETE FROM foo") - self.db.execute("INSERT INTO foo VALUES (?, ?)", ["bar", "baz"]) + self.db.execute("INSERT INTO foo VALUES (?)", ("bar", "baz")) self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) self.db.execute("DELETE FROM foo") + self.db.execute("INSERT INTO foo VALUES (?)", ["bar", "baz"]) + self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) + self.db.execute("DELETE FROM foo") - self.db.execute("INSERT INTO foo VALUES (?,?)", "bar", "baz") + self.db.execute("INSERT INTO foo VALUES (?, ?)", ["bar", "baz"]) + self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) + self.db.execute("DELETE FROM foo") + + self.db.execute("INSERT INTO foo VALUES (?, ?)", "bar", "baz") self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) self.db.execute("DELETE FROM foo") self.db.execute("CREATE TABLE bar (firstname STRING)") + + self.db.execute("INSERT INTO bar VALUES (?)", "baz") + self.db.execute("INSERT INTO bar VALUES (?)", "qux") + self.assertEqual(self.db.execute("SELECT * FROM bar WHERE firstname IN (?)", ("baz", "qux")), [{"firstname": "baz"}, {"firstname": "qux"}]) + self.db.execute("DELETE FROM bar") + + self.db.execute("INSERT INTO bar VALUES (?)", "baz") + self.db.execute("INSERT INTO bar VALUES (?)", "qux") + self.assertEqual(self.db.execute("SELECT * FROM bar WHERE firstname IN (?)", ["baz", "qux"]), [{"firstname": "baz"}, {"firstname": "qux"}]) + self.db.execute("DELETE FROM bar") + self.db.execute("INSERT INTO bar VALUES (?)", "baz") self.assertEqual(self.db.execute("SELECT * FROM bar"), [{"firstname": "baz"}]) self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (?)") self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (?, ?)") - # self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (?)", ('bar', 'baz')) - # self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (?)", ['bar', 'baz']) self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (?, ?)", 'bar', 'baz', 'qux') self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (?, ?)", ('bar', 'baz', 'qux')) self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (?, ?)", ['bar', 'baz', 'qux']) From cf34af75be81360caf9920170b6b1bfdc0653896 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Sat, 21 Nov 2020 21:51:02 -0500 Subject: [PATCH 005/159] reverting to 5.0.5 --- .travis.yml | 4 +- README.md | 6 +-- setup.py | 2 +- src/cs50/sql.py | 68 ++++++++++++--------------------- tests/flask/application.py | 60 ++--------------------------- tests/flask/test.py | 49 ------------------------ tests/sql.py | 78 ++++---------------------------------- 7 files changed, 41 insertions(+), 226 deletions(-) delete mode 100644 tests/flask/test.py diff --git a/.travis.yml b/.travis.yml index b13ffd6..4af42c0 100644 --- a/.travis.yml +++ b/.travis.yml @@ -20,11 +20,11 @@ deploy: \"target_commitish\": \"$TRAVIS_COMMIT\", \"name\": \"v$(python setup.py --version)\" }" --user bot50:$GITHUB_TOKEN https://api.github.com/repos/$TRAVIS_REPO_SLUG/releases' on: - branch: master + branch: fixing-lists - provider: pypi user: "$PYPI_USERNAME" password: "$PYPI_PASSWORD" - on: master + on: fixing-lists notifications: slack: secure: lJklhcBVjDT6KzUNa3RFHXdXSeH7ytuuGrkZ5ZcR72CXMoTf2pMJTzPwRLWOp6lCSdDC9Y8MWLrcg/e33dJga4Jlp9alOmWqeqesaFjfee4st8vAsgNbv8/RajPH1gD2bnkt8oIwUzdHItdb5AucKFYjbH2g0d8ndoqYqUeBLrnsT1AP5G/Vi9OHC9OWNpR0FKaZIJE0Wt52vkPMH3sV2mFeIskByPB+56U5y547mualKxn61IVR/dhYBEtZQJuSvnwKHPOn9Pkk7cCa+SSSeTJ4w5LboY8T17otaYNauXo46i1bKIoGiBcCcrJyQHHiPQmcq/YU540MC5Wzt9YXUycmJzRi347oyQeDee27wV3XJlWMXuuhbtJiKCFny7BTQ160VATlj/dbwIzN99Ra6/BtTumv/6LyTdKIuVjdAkcN8dtdDW1nlrQ29zuPNCcXXzJ7zX7kQaOCUV1c2OrsbiH/0fE9nknUORn97txqhlYVi0QMS7764wFo6kg0vpmFQRkkQySsJl+TmgcZ01AlsJc2EMMWVuaj9Af9JU4/4yalqDiXIh1fOYYUZnLfOfWS+MsnI+/oLfqJFyMbrsQQTIjs+kTzbiEdhd2R4EZgusU/xRFWokS2NAvahexrRhRQ6tpAI+LezPrkNOR3aHiykBf+P9BkUa0wPp6V2Ayc6q0= diff --git a/README.md b/README.md index 5df22a2..fb37280 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ import cs50 ... -f = cs50.get_float() -i = cs50.get_int() -s = cs50.get_string() +f = cs50.get_float(); +i = cs50.get_int(); +s = cs50.get_string(); ``` diff --git a/setup.py b/setup.py index ae4b30a..7108b87 100644 --- a/setup.py +++ b/setup.py @@ -16,5 +16,5 @@ package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="5.1.0" + version="5.0.5" ) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 3455fcb..b9675d3 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -43,7 +43,6 @@ def __init__(self, url, **kwargs): import os import re import sqlalchemy - import sqlalchemy.orm import sqlite3 # Get logger @@ -60,11 +59,6 @@ def __init__(self, url, **kwargs): # Create engine, disabling SQLAlchemy's own autocommit mode, raising exception if back end's module not installed self._engine = sqlalchemy.create_engine(url, **kwargs).execution_options(autocommit=False) - # Create a variable to hold the session. If None, autocommit is on. - self._Session = sqlalchemy.orm.session.sessionmaker(bind=self._engine) - self._session = None - self._in_transaction = False - # Listener for connections def connect(dbapi_connection, connection_record): @@ -96,8 +90,9 @@ def connect(dbapi_connection, connection_record): self._logger.disabled = disabled def __del__(self): - """Close database session and connection.""" - self._close_session() + """Close database connection.""" + if hasattr(self, "_connection"): + self._connection.close() @_enable_logging def execute(self, sql, *args, **kwargs): @@ -130,12 +125,6 @@ def execute(self, sql, *args, **kwargs): if token.ttype in [sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML]: command = token.value.upper() break - - # Begin a new session, if transaction opened by caller (not using autocommit) - elif token.value.upper() in ["BEGIN", "START"]: - if self._in_transaction: - raise RuntimeError("transaction already open") - self._in_transaction = True else: command = None @@ -283,10 +272,6 @@ def execute(self, sql, *args, **kwargs): statement = "".join([str(token) for token in tokens]) # Connect to database (for transactions' sake) - if self._session is None: - self._session = self._Session() - - # Set up a Flask app teardown function to close session at teardown try: # Infer whether Flask is installed @@ -295,17 +280,29 @@ def execute(self, sql, *args, **kwargs): # Infer whether app is defined assert flask.current_app - # Disconnect later - but only once - if not hasattr(self, "_teardown_appcontext_added"): - self._teardown_appcontext_added = True + # If no connection for app's current request yet + if not hasattr(flask.g, "_connection"): + # Connect now + flask.g._connection = self._engine.connect() + + # Disconnect later @flask.current_app.teardown_appcontext def shutdown_session(exception=None): - """Close any existing session on app context teardown.""" - self._close_session() + if hasattr(flask.g, "_connection"): + flask.g._connection.close() + + # Use this connection + connection = flask.g._connection except (ModuleNotFoundError, AssertionError): - pass + + # If no connection yet + if not hasattr(self, "_connection"): + self._connection = self._engine.connect() + + # Use this connection + connection = self._connection # Catch SQLAlchemy warnings with warnings.catch_warnings(): @@ -319,14 +316,8 @@ def shutdown_session(exception=None): # Join tokens into statement, abbreviating binary data as <class 'bytes'> _statement = "".join([str(bytes) if token.ttype == sqlparse.tokens.Other else str(token) for token in tokens]) - # If COMMIT or ROLLBACK, turn on autocommit mode - if command in ["COMMIT", "ROLLBACK"] and "TO" not in (token.value for token in tokens): - if not self._in_transaction: - raise RuntimeError("transactions must be opened with BEGIN or START TRANSACTION") - self._in_transaction = False - # Execute statement - result = self._session.execute(sqlalchemy.text(statement)) + result = connection.execute(sqlalchemy.text(statement)) # Return value ret = True @@ -355,7 +346,7 @@ def shutdown_session(exception=None): elif command == "INSERT": if self._engine.url.get_backend_name() in ["postgres", "postgresql"]: try: - result = self._session.execute("SELECT LASTVAL()") + result = connection.execute("SELECT LASTVAL()") ret = result.first()[0] except sqlalchemy.exc.OperationalError: # If lastval is not yet defined in this session ret = None @@ -366,10 +357,6 @@ def shutdown_session(exception=None): elif command in ["DELETE", "UPDATE"]: ret = result.rowcount - # If autocommit is on, commit - if not self._in_transaction: - self._session.commit() - # If constraint violated, return None except sqlalchemy.exc.IntegrityError as e: self._logger.debug(termcolor.colored(statement, "yellow")) @@ -389,13 +376,6 @@ def shutdown_session(exception=None): self._logger.debug(termcolor.colored(_statement, "green")) return ret - def _close_session(self): - """Closes any existing session and resets instance variables.""" - if self._session is not None: - self._session.close() - self._session = None - self._in_transaction = False - def _escape(self, value): """ Escapes value using engine's conversion function. @@ -475,7 +455,7 @@ def __escape(value): # Escape value(s), separating with commas as needed if type(value) in [list, tuple]: - return sqlparse.sql.TokenList([__escape(v) for v in value]) + return sqlparse.sql.TokenList(sqlparse.parse(", ".join([str(__escape(v)) for v in value]))) else: return __escape(value) diff --git a/tests/flask/application.py b/tests/flask/application.py index 404b1d4..939a8f9 100644 --- a/tests/flask/application.py +++ b/tests/flask/application.py @@ -1,76 +1,22 @@ -import logging -import os import requests import sys +from flask import Flask, render_template sys.path.insert(0, "../../src") import cs50 import cs50.flask -from flask import Flask, render_template - app = Flask(__name__) -logging.disable(logging.CRITICAL) -os.environ["WERKZEUG_RUN_MAIN"] = "true" - -db_url = "sqlite:///../test.db" -db = cs50.SQL(db_url) +db = cs50.SQL("sqlite:///../sqlite.db") @app.route("/") def index(): + db.execute("SELECT 1") """ def f(): res = requests.get("cs50.harvard.edu") f() """ return render_template("index.html") - -@app.route("/autocommit") -def autocommit(): - db.execute("INSERT INTO test (val) VALUES (?)", "def") - db2 = cs50.SQL(db_url) - ret = db2.execute("SELECT val FROM test WHERE val=?", "def") - return str(ret == [{"val": "def"}]) - -@app.route("/create") -def create(): - ret = db.execute("CREATE TABLE test (id INTEGER PRIMARY KEY AUTOINCREMENT, val VARCHAR(16))") - return str(ret) - -@app.route("/delete") -def delete(): - ret = db.execute("DELETE FROM test") - return str(ret > 0) - -@app.route("/drop") -def drop(): - ret = db.execute("DROP TABLE test") - return str(ret) - -@app.route("/insert") -def insert(): - ret = db.execute("INSERT INTO test (val) VALUES (?)", "abc") - return str(ret > 0) - -@app.route("/multiple_connections") -def multiple_connections(): - ctx = len(app.teardown_appcontext_funcs) - db1 = cs50.SQL(db_url) - td1 = (len(app.teardown_appcontext_funcs) == ctx + 1) - db2 = cs50.SQL(db_url) - td2 = (len(app.teardown_appcontext_funcs) == ctx + 2) - return str(td1 and td2) - -@app.route("/select") -def select(): - ret = db.execute("SELECT val FROM test") - return str(ret == [{"val": "abc"}]) - -@app.route("/single_teardown") -def single_teardown(): - db.execute("SELECT * FROM test") - ctx = len(app.teardown_appcontext_funcs) - db.execute("SELECT COUNT(id) FROM test") - return str(ctx == len(app.teardown_appcontext_funcs)) diff --git a/tests/flask/test.py b/tests/flask/test.py deleted file mode 100644 index 0b084d6..0000000 --- a/tests/flask/test.py +++ /dev/null @@ -1,49 +0,0 @@ -import logging -import requests -import sys -import threading -import time -import unittest - -from application import app - -def request(route): - r = requests.get("http://localhost:5000/{}".format(route)) - return r.text == "True" - -class FlaskTests(unittest.TestCase): - - def test__create(self): - self.assertTrue(request("create")) - - def test_autocommit(self): - self.assertTrue(request("autocommit")) - - def test_delete(self): - self.assertTrue(request("delete")) - - def test_insert(self): - self.assertTrue(request("insert")) - - def test_multiple_connections(self): - self.assertTrue(request("multiple_connections")) - - def test_select(self): - self.assertTrue(request("select")) - - def test_single_teardown(self): - self.assertTrue(request("single_teardown")) - - def test_zdrop(self): - self.assertTrue(request("drop")) - - -if __name__ == "__main__": - t = threading.Thread(target=app.run, daemon=True) - t.start() - - suite = unittest.TestSuite([ - unittest.TestLoader().loadTestsFromTestCase(FlaskTests) - ]) - - sys.exit(not unittest.TextTestRunner(verbosity=2).run(suite).wasSuccessful()) diff --git a/tests/sql.py b/tests/sql.py index 106b4a1..9ad463f 100644 --- a/tests/sql.py +++ b/tests/sql.py @@ -115,36 +115,11 @@ def test_blob(self): self.db.execute("INSERT INTO cs50(bin) VALUES(:bin)", bin=row["bin"]) self.assertEqual(self.db.execute("SELECT id, bin FROM cs50"), rows) - def test_autocommit(self): - self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('foo')"), 1) - self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('bar')"), 2) - - # Load a new database instance to confirm the INSERTs were committed - db2 = SQL(self.db_url) - self.assertEqual(db2.execute("DELETE FROM cs50 WHERE id < 3"), 2) - - def test_commit_no_transaction(self): - with self.assertRaises(RuntimeError): - self.db.execute("COMMIT") - with self.assertRaises(RuntimeError): - self.db.execute("ROLLBACK") - def test_commit(self): self.db.execute("BEGIN") self.db.execute("INSERT INTO cs50 (val) VALUES('foo')") self.db.execute("COMMIT") - - # Load a new database instance to confirm the INSERT was committed - db2 = SQL(self.db_url) - self.assertEqual(db2.execute("SELECT val FROM cs50"), [{"val": "foo"}]) - - """ - def test_double_begin(self): - self.db.execute("BEGIN") - with self.assertRaises(RuntimeError): - self.db.execute("BEGIN") - self.db.execute("ROLLBACK") - """ + self.assertEqual(self.db.execute("SELECT val FROM cs50"), [{"val": "foo"}]) def test_rollback(self): self.db.execute("BEGIN") @@ -153,17 +128,6 @@ def test_rollback(self): self.db.execute("ROLLBACK") self.assertEqual(self.db.execute("SELECT val FROM cs50"), []) - def test_savepoint(self): - self.db.execute("BEGIN") - self.db.execute("INSERT INTO cs50 (val) VALUES('foo')") - self.db.execute("SAVEPOINT sp1") - self.db.execute("INSERT INTO cs50 (val) VALUES('bar')") - self.assertEqual(self.db.execute("SELECT val FROM cs50"), [{"val": "foo"}, {"val": "bar"}]) - self.db.execute("ROLLBACK TO sp1") - self.assertEqual(self.db.execute("SELECT val FROM cs50"), [{"val": "foo"}]) - self.db.execute("ROLLBACK") - self.assertEqual(self.db.execute("SELECT val FROM cs50"), []) - def tearDown(self): self.db.execute("DROP TABLE cs50") self.db.execute("DROP TABLE IF EXISTS foo") @@ -178,24 +142,18 @@ def tearDownClass(self): if not str(e).startswith("(1051"): raise e - class MySQLTests(SQLTests): @classmethod def setUpClass(self): - self.db_url = "mysql://root@localhost/test" - self.db = SQL(self.db_url) - print("\nMySQL tests") + self.db = SQL("mysql://root@localhost/test") def setUp(self): self.db.execute("CREATE TABLE cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), bin BLOB, PRIMARY KEY (id))") - class PostgresTests(SQLTests): @classmethod def setUpClass(self): - self.db_url = "postgresql://postgres@localhost/test" - self.db = SQL(self.db_url) - print("\nPOSTGRES tests") + self.db = SQL("postgresql://postgres@localhost/test") def setUp(self): self.db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)") @@ -203,15 +161,11 @@ def setUp(self): def test_cte(self): self.assertEqual(self.db.execute("WITH foo AS ( SELECT 1 AS bar ) SELECT bar FROM foo"), [{"bar": 1}]) - class SQLiteTests(SQLTests): - @classmethod def setUpClass(self): open("test.db", "w").close() - self.db_url = "sqlite:///test.db" - self.db = SQL(self.db_url) - print("\nSQLite tests") + self.db = SQL("sqlite:///test.db") def setUp(self): self.db.execute("CREATE TABLE cs50(id INTEGER PRIMARY KEY, val TEXT, bin BLOB)") @@ -257,39 +211,23 @@ def test_qmark(self): self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) self.db.execute("DELETE FROM foo") - self.db.execute("INSERT INTO foo VALUES (?)", ("bar", "baz")) - self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) - self.db.execute("DELETE FROM foo") - - self.db.execute("INSERT INTO foo VALUES (?)", ["bar", "baz"]) - self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) - self.db.execute("DELETE FROM foo") - self.db.execute("INSERT INTO foo VALUES (?, ?)", ["bar", "baz"]) self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) self.db.execute("DELETE FROM foo") - self.db.execute("INSERT INTO foo VALUES (?, ?)", "bar", "baz") + + self.db.execute("INSERT INTO foo VALUES (?,?)", "bar", "baz") self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) self.db.execute("DELETE FROM foo") self.db.execute("CREATE TABLE bar (firstname STRING)") - - self.db.execute("INSERT INTO bar VALUES (?)", "baz") - self.db.execute("INSERT INTO bar VALUES (?)", "qux") - self.assertEqual(self.db.execute("SELECT * FROM bar WHERE firstname IN (?)", ("baz", "qux")), [{"firstname": "baz"}, {"firstname": "qux"}]) - self.db.execute("DELETE FROM bar") - - self.db.execute("INSERT INTO bar VALUES (?)", "baz") - self.db.execute("INSERT INTO bar VALUES (?)", "qux") - self.assertEqual(self.db.execute("SELECT * FROM bar WHERE firstname IN (?)", ["baz", "qux"]), [{"firstname": "baz"}, {"firstname": "qux"}]) - self.db.execute("DELETE FROM bar") - self.db.execute("INSERT INTO bar VALUES (?)", "baz") self.assertEqual(self.db.execute("SELECT * FROM bar"), [{"firstname": "baz"}]) self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (?)") self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (?, ?)") + # self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (?)", ('bar', 'baz')) + # self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (?)", ['bar', 'baz']) self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (?, ?)", 'bar', 'baz', 'qux') self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (?, ?)", ('bar', 'baz', 'qux')) self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (?, ?)", ['bar', 'baz', 'qux']) From 2be1e2438cc57147d33fe41ece6ce86047e132bc Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Sat, 21 Nov 2020 21:56:25 -0500 Subject: [PATCH 006/159] adds tests for IN --- tests/sql.py | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/tests/sql.py b/tests/sql.py index 9ad463f..cbad470 100644 --- a/tests/sql.py +++ b/tests/sql.py @@ -7,6 +7,7 @@ from cs50.sql import SQL + class SQLTests(unittest.TestCase): def test_multiple_statements(self): @@ -142,6 +143,7 @@ def tearDownClass(self): if not str(e).startswith("(1051"): raise e + class MySQLTests(SQLTests): @classmethod def setUpClass(self): @@ -150,6 +152,7 @@ def setUpClass(self): def setUp(self): self.db.execute("CREATE TABLE cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), bin BLOB, PRIMARY KEY (id))") + class PostgresTests(SQLTests): @classmethod def setUpClass(self): @@ -161,7 +164,9 @@ def setUp(self): def test_cte(self): self.assertEqual(self.db.execute("WITH foo AS ( SELECT 1 AS bar ) SELECT bar FROM foo"), [{"bar": 1}]) + class SQLiteTests(SQLTests): + @classmethod def setUpClass(self): open("test.db", "w").close() @@ -207,27 +212,44 @@ def test_qmark(self): self.assertEqual(self.db.execute("SELECT * FROM foo WHERE firstname = ? AND lastname = ?", ["qux", "quux"]), [{"firstname": "qux", "lastname": "quux"}]) self.db.execute("DELETE FROM foo") + self.db.execute("INSERT INTO foo VALUES (?)", ("bar", "baz")) + self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) + self.db.execute("DELETE FROM foo") + self.db.execute("INSERT INTO foo VALUES (?, ?)", ("bar", "baz")) self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) self.db.execute("DELETE FROM foo") - self.db.execute("INSERT INTO foo VALUES (?, ?)", ["bar", "baz"]) + self.db.execute("INSERT INTO foo VALUES (?)", ["bar", "baz"]) self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) self.db.execute("DELETE FROM foo") + self.db.execute("INSERT INTO foo VALUES (?, ?)", ["bar", "baz"]) + self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) + self.db.execute("DELETE FROM foo") self.db.execute("INSERT INTO foo VALUES (?,?)", "bar", "baz") self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) self.db.execute("DELETE FROM foo") self.db.execute("CREATE TABLE bar (firstname STRING)") + self.db.execute("INSERT INTO bar VALUES (?)", "baz") self.assertEqual(self.db.execute("SELECT * FROM bar"), [{"firstname": "baz"}]) + self.db.execute("DELETE FROM bar") + + self.db.execute("INSERT INTO bar VALUES (?)", "baz") + self.db.execute("INSERT INTO bar VALUES (?)", "qux") + self.assertEqual(self.db.execute("SELECT * FROM bar WHERE firstname IN (?)", ("baz", "qux")), [{"firstname": "baz"}, {"firstname": "qux"}]) + self.db.execute("DELETE FROM bar") + + self.db.execute("INSERT INTO bar VALUES (?)", "baz") + self.db.execute("INSERT INTO bar VALUES (?)", "qux") + self.assertEqual(self.db.execute("SELECT * FROM bar WHERE firstname IN (?)", ["baz", "qux"]), [{"firstname": "baz"}, {"firstname": "qux"}]) + self.db.execute("DELETE FROM bar") self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (?)") self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (?, ?)") - # self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (?)", ('bar', 'baz')) - # self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (?)", ['bar', 'baz']) self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (?, ?)", 'bar', 'baz', 'qux') self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (?, ?)", ('bar', 'baz', 'qux')) self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (?, ?)", ['bar', 'baz', 'qux']) From a598334cefd66466629b1dd71e00eddd314095b4 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Sat, 21 Nov 2020 21:57:20 -0500 Subject: [PATCH 007/159] updated .travis.yml --- .travis.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index 4af42c0..b13ffd6 100644 --- a/.travis.yml +++ b/.travis.yml @@ -20,11 +20,11 @@ deploy: \"target_commitish\": \"$TRAVIS_COMMIT\", \"name\": \"v$(python setup.py --version)\" }" --user bot50:$GITHUB_TOKEN https://api.github.com/repos/$TRAVIS_REPO_SLUG/releases' on: - branch: fixing-lists + branch: master - provider: pypi user: "$PYPI_USERNAME" password: "$PYPI_PASSWORD" - on: fixing-lists + on: master notifications: slack: secure: lJklhcBVjDT6KzUNa3RFHXdXSeH7ytuuGrkZ5ZcR72CXMoTf2pMJTzPwRLWOp6lCSdDC9Y8MWLrcg/e33dJga4Jlp9alOmWqeqesaFjfee4st8vAsgNbv8/RajPH1gD2bnkt8oIwUzdHItdb5AucKFYjbH2g0d8ndoqYqUeBLrnsT1AP5G/Vi9OHC9OWNpR0FKaZIJE0Wt52vkPMH3sV2mFeIskByPB+56U5y547mualKxn61IVR/dhYBEtZQJuSvnwKHPOn9Pkk7cCa+SSSeTJ4w5LboY8T17otaYNauXo46i1bKIoGiBcCcrJyQHHiPQmcq/YU540MC5Wzt9YXUycmJzRi347oyQeDee27wV3XJlWMXuuhbtJiKCFny7BTQ160VATlj/dbwIzN99Ra6/BtTumv/6LyTdKIuVjdAkcN8dtdDW1nlrQ29zuPNCcXXzJ7zX7kQaOCUV1c2OrsbiH/0fE9nknUORn97txqhlYVi0QMS7764wFo6kg0vpmFQRkkQySsJl+TmgcZ01AlsJc2EMMWVuaj9Af9JU4/4yalqDiXIh1fOYYUZnLfOfWS+MsnI+/oLfqJFyMbrsQQTIjs+kTzbiEdhd2R4EZgusU/xRFWokS2NAvahexrRhRQ6tpAI+LezPrkNOR3aHiykBf+P9BkUa0wPp6V2Ayc6q0= From b2fc9694c5693e0fd3435976e08025893dc65251 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Sat, 21 Nov 2020 23:39:39 -0500 Subject: [PATCH 008/159] working on transactions --- src/cs50/sql.py | 4 ++-- tests/foo.py | 51 +++++++++++++++++++++++++++++++++++++++++++++++++ tests/sql.py | 9 ++++++--- 3 files changed, 59 insertions(+), 5 deletions(-) create mode 100644 tests/foo.py diff --git a/src/cs50/sql.py b/src/cs50/sql.py index b9675d3..c0dc747 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -56,8 +56,8 @@ def __init__(self, url, **kwargs): if not os.path.isfile(matches.group(1)): raise RuntimeError("not a file: {}".format(matches.group(1))) - # Create engine, disabling SQLAlchemy's own autocommit mode, raising exception if back end's module not installed - self._engine = sqlalchemy.create_engine(url, **kwargs).execution_options(autocommit=False) + # Create engine, raising exception if back end's module not installed + self._engine = sqlalchemy.create_engine(url, **kwargs).execution_options(autocommit=True) # Listener for connections def connect(dbapi_connection, connection_record): diff --git a/tests/foo.py b/tests/foo.py new file mode 100644 index 0000000..11fda4d --- /dev/null +++ b/tests/foo.py @@ -0,0 +1,51 @@ +import logging +import sys + +sys.path.insert(0, "../src") + +import cs50 + +db = cs50.SQL("sqlite:///foo.db") + +logging.getLogger("cs50").disabled = False + +""" +#db.execute("SELECT ? FROM ? ORDER BY ?", "a", "tbl", "c") +db.execute("CREATE TABLE IF NOT EXISTS bar (firstname STRING)") + +db.execute("INSERT INTO bar VALUES (?)", "baz") +db.execute("INSERT INTO bar VALUES (?)", "qux") +db.execute("SELECT * FROM bar WHERE firstname IN (?)", ("baz", "qux")) +db.execute("DELETE FROM bar") +""" + +db = cs50.SQL("postgresql://postgres@localhost/test") + +""" +print(db.execute("DROP TABLE IF EXISTS cs50")) +print(db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)")) +print(db.execute("INSERT INTO cs50 (val) VALUES('foo')")) +print(db.execute("SELECT * FROM cs50")) + +print(db.execute("DROP TABLE IF EXISTS cs50")) +print(db.execute("CREATE TABLE cs50 (val VARCHAR(16), bin BYTEA)")) +print(db.execute("INSERT INTO cs50 (val) VALUES('foo')")) +print(db.execute("SELECT * FROM cs50")) +""" + +print(db.execute("DROP TABLE IF EXISTS cs50")) +print(db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)")) +print(db.execute("INSERT INTO cs50 (val) VALUES('foo')")) +print(db.execute("INSERT INTO cs50 (val) VALUES('foo')")) +print(db.execute("INSERT INTO cs50 (val) VALUES('foo')")) +print(db.execute("SELECT * FROM cs50")) +print(db.execute("COMMIT")) +""" +try: + print(db.execute("INSERT INTO cs50 (id, val) VALUES(1, 'bar')")) +except Exception as e: + print(e) + pass +print(db.execute("INSERT INTO cs50 (val) VALUES('foo')")) +print(db.execute("DELETE FROM cs50")) +""" diff --git a/tests/sql.py b/tests/sql.py index cbad470..fbfae61 100644 --- a/tests/sql.py +++ b/tests/sql.py @@ -150,7 +150,8 @@ def setUpClass(self): self.db = SQL("mysql://root@localhost/test") def setUp(self): - self.db.execute("CREATE TABLE cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), bin BLOB, PRIMARY KEY (id))") + self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), bin BLOB, PRIMARY KEY (id))") + self.db.execute("DELETE FROM cs50") class PostgresTests(SQLTests): @@ -159,7 +160,8 @@ def setUpClass(self): self.db = SQL("postgresql://postgres@localhost/test") def setUp(self): - self.db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)") + self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)") + self.db.execute("DELETE FROM cs50") def test_cte(self): self.assertEqual(self.db.execute("WITH foo AS ( SELECT 1 AS bar ) SELECT bar FROM foo"), [{"bar": 1}]) @@ -173,7 +175,8 @@ def setUpClass(self): self.db = SQL("sqlite:///test.db") def setUp(self): - self.db.execute("CREATE TABLE cs50(id INTEGER PRIMARY KEY, val TEXT, bin BLOB)") + self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id INTEGER PRIMARY KEY, val TEXT, bin BLOB)") + self.db.execute("DELETE FROM cs50") def test_lastrowid(self): self.db.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY AUTOINCREMENT, firstname TEXT, lastname TEXT)") From 9d007a4981fb83f6e08b4e78df584b321ea8625b Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Sun, 22 Nov 2020 11:24:43 -0500 Subject: [PATCH 009/159] fixed support for transactions --- src/cs50/sql.py | 41 +++++++++++++++++++++++++++++++++-------- tests/foo.py | 13 +++++-------- 2 files changed, 38 insertions(+), 16 deletions(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index c0dc747..e6e3f06 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -56,13 +56,14 @@ def __init__(self, url, **kwargs): if not os.path.isfile(matches.group(1)): raise RuntimeError("not a file: {}".format(matches.group(1))) - # Create engine, raising exception if back end's module not installed - self._engine = sqlalchemy.create_engine(url, **kwargs).execution_options(autocommit=True) + # Create engine, disabling SQLAlchemy's own autocommit mode, raising exception if back end's module not installed + self._engine = sqlalchemy.create_engine(url, **kwargs).execution_options(autocommit=False) # Listener for connections def connect(dbapi_connection, connection_record): - # Disable underlying API's own emitting of BEGIN and COMMIT + # Disable underlying API's own emitting of BEGIN and COMMIT so we can ourselves + # https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl dbapi_connection.isolation_level = None # Enable foreign key constraints @@ -71,6 +72,9 @@ def connect(dbapi_connection, connection_record): cursor.execute("PRAGMA foreign_keys=ON") cursor.close() + # Autocommit by default + self._autocommit = True + # Register listener sqlalchemy.event.listen(self._engine, "connect", connect) @@ -90,9 +94,14 @@ def connect(dbapi_connection, connection_record): self._logger.disabled = disabled def __del__(self): + """Disconnect from database.""" + self._disconnect() + + def _disconnect(self): """Close database connection.""" if hasattr(self, "_connection"): self._connection.close() + delattr(self, "_connection") @_enable_logging def execute(self, sql, *args, **kwargs): @@ -107,7 +116,7 @@ def execute(self, sql, *args, **kwargs): import warnings # Parse statement, stripping comments and then leading/trailing whitespace - statements = sqlparse.parse(sqlparse.format(sql, strip_comments=True).strip()) + statements = sqlparse.parse(sqlparse.format(sql, keyword_case="upper", strip_comments=True).strip()) # Allow only one statement at a time, since SQLite doesn't support multiple # https://docs.python.org/3/library/sqlite3.html#sqlite3.Cursor.execute @@ -122,9 +131,10 @@ def execute(self, sql, *args, **kwargs): # Infer command from (unflattened) statement for token in statements[0]: - if token.ttype in [sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML]: - command = token.value.upper() - break + if token.ttype in [sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML]: + if token.value in ["BEGIN", "DELETE", "INSERT", "SELECT", "START", "UPDATE"]: + command = token.value + break else: command = None @@ -316,8 +326,21 @@ def shutdown_session(exception=None): # Join tokens into statement, abbreviating binary data as <class 'bytes'> _statement = "".join([str(bytes) if token.ttype == sqlparse.tokens.Other else str(token) for token in tokens]) + # Check for start of transaction + if command in ["BEGIN", "START"]: + self._autocommit = False + # Execute statement - result = connection.execute(sqlalchemy.text(statement)) + if self._autocommit: + connection.execute(sqlalchemy.text("BEGIN")) + result = connection.execute(sqlalchemy.text(statement)) + connection.execute(sqlalchemy.text("COMMIT")) + else: + result = connection.execute(sqlalchemy.text(statement)) + + # Check for end of transaction + if command in ["COMMIT", "ROLLBACK"]: + self._autocommit = True # Return value ret = True @@ -359,6 +382,7 @@ def shutdown_session(exception=None): # If constraint violated, return None except sqlalchemy.exc.IntegrityError as e: + self._disconnect() self._logger.debug(termcolor.colored(statement, "yellow")) e = RuntimeError(e.orig) e.__cause__ = None @@ -366,6 +390,7 @@ def shutdown_session(exception=None): # If user errror except sqlalchemy.exc.OperationalError as e: + self._disconnect() self._logger.debug(termcolor.colored(statement, "red")) e = RuntimeError(e.orig) e.__cause__ = None diff --git a/tests/foo.py b/tests/foo.py index 11fda4d..7f32a00 100644 --- a/tests/foo.py +++ b/tests/foo.py @@ -5,11 +5,11 @@ import cs50 +""" db = cs50.SQL("sqlite:///foo.db") logging.getLogger("cs50").disabled = False -""" #db.execute("SELECT ? FROM ? ORDER BY ?", "a", "tbl", "c") db.execute("CREATE TABLE IF NOT EXISTS bar (firstname STRING)") @@ -36,16 +36,13 @@ print(db.execute("DROP TABLE IF EXISTS cs50")) print(db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)")) print(db.execute("INSERT INTO cs50 (val) VALUES('foo')")) -print(db.execute("INSERT INTO cs50 (val) VALUES('foo')")) -print(db.execute("INSERT INTO cs50 (val) VALUES('foo')")) +print(db.execute("INSERT INTO cs50 (val) VALUES('bar')")) +print(db.execute("INSERT INTO cs50 (val) VALUES('baz')")) print(db.execute("SELECT * FROM cs50")) -print(db.execute("COMMIT")) -""" try: print(db.execute("INSERT INTO cs50 (id, val) VALUES(1, 'bar')")) except Exception as e: print(e) pass -print(db.execute("INSERT INTO cs50 (val) VALUES('foo')")) -print(db.execute("DELETE FROM cs50")) -""" +print(db.execute("INSERT INTO cs50 (val) VALUES('qux')")) +#print(db.execute("DELETE FROM cs50")) From f44c56235a955807e156f25c6f94acdab95f8000 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Sun, 22 Nov 2020 11:45:47 -0500 Subject: [PATCH 010/159] updated README --- README.md | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/README.md b/README.md index fb37280..c73641e 100644 --- a/README.md +++ b/README.md @@ -19,3 +19,27 @@ f = cs50.get_float(); i = cs50.get_int(); s = cs50.get_string(); ``` + +## Testing + +1. Run `cli50` in `python-cs50`. +1. Run `sudo su -`. +1. Run `apt install -y libmysqlclient-dev mysql-server postgresql`. +1. Run `pip3 install mysqlclient psycopg2-binary`. +1. In `/etc/mysql/mysql.conf.d/mysqld.conf`, add `skip-grant-tables` under `[mysqld]`. +1. In `/etc/profile.d/cli.sh`, remove `valgrind` function for now. +1. Run `service mysql start`. +1. Run `mysql -e 'CREATE DATABASE IF NOT EXISTS test;'`. +1. In `/etc/postgresql/10/main/pg_hba.conf, change: + ``` + local all postgres peer + host all all 127.0.0.1/32 md5 + ``` + to: + ``` + local all postgres trust + host all all 127.0.0.1/32 trust + ``` +1. Run `service postgresql start`. +1. Run `psql -c 'create database test;' -U postgres. +1. Run `touch test.db`. From 87dd1616e58e04a22d42cbbfb031e064f0676087 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Sun, 22 Nov 2020 11:51:52 -0500 Subject: [PATCH 011/159] added sample tests --- README.md | 47 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/README.md b/README.md index c73641e..0780948 100644 --- a/README.md +++ b/README.md @@ -43,3 +43,50 @@ s = cs50.get_string(); 1. Run `service postgresql start`. 1. Run `psql -c 'create database test;' -U postgres. 1. Run `touch test.db`. + +### Sample Tests + +``` +import cs50 +db = cs50.SQL("sqlite:///foo.db") +db.execute("CREATE TABLE IF NOT EXISTS cs50 (id INTEGER PRIMARY KEY, val TEXT, bin BLOB)") +db.execute("INSERT INTO cs50 (val) VALUES('a')") +db.execute("INSERT INTO cs50 (val) VALUES('b')") +db.execute("BEGIN") +db.execute("INSERT INTO cs50 (val) VALUES('c')") +db.execute("INSERT INTO cs50 (val) VALUES('x')") +db.execute("INSERT INTO cs50 (val) VALUES('y')") +db.execute("ROLLBACK") +db.execute("INSERT INTO cs50 (val) VALUES('z')") +db.execute("COMMIT") + +--- + +import cs50 +db = cs50.SQL("mysql://root@localhost/test") +db.execute("CREATE TABLE IF NOT EXISTS cs50 (id INTEGER PRIMARY KEY, val TEXT, bin BLOB)") +db.execute("INSERT INTO cs50 (val) VALUES('a')") +db.execute("INSERT INTO cs50 (val) VALUES('b')") +db.execute("BEGIN") +db.execute("INSERT INTO cs50 (val) VALUES('c')") +db.execute("INSERT INTO cs50 (val) VALUES('x')") +db.execute("INSERT INTO cs50 (val) VALUES('y')") +db.execute("ROLLBACK") +db.execute("INSERT INTO cs50 (val) VALUES('z')") +db.execute("COMMIT") + +--- + +import cs50 +db = cs50.SQL("postgresql://postgres@localhost/test") +db.execute("CREATE TABLE IF NOT EXISTS cs50 (id INTEGER PRIMARY KEY, val TEXT, bin BLOB)") +db.execute("INSERT INTO cs50 (val) VALUES('a')") +db.execute("INSERT INTO cs50 (val) VALUES('b')") +db.execute("BEGIN") +db.execute("INSERT INTO cs50 (val) VALUES('c')") +db.execute("INSERT INTO cs50 (val) VALUES('x')") +db.execute("INSERT INTO cs50 (val) VALUES('y')") +db.execute("ROLLBACK") +db.execute("INSERT INTO cs50 (val) VALUES('z')") +db.execute("COMMIT") +``` From a14e5493e5d129f44e24edfd5f5229cd158dd1ba Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Sun, 22 Nov 2020 14:41:17 -0500 Subject: [PATCH 012/159] tidied code --- src/cs50/flask.py | 9 +++++---- src/cs50/sql.py | 5 ++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/cs50/flask.py b/src/cs50/flask.py index 1d59064..538d32a 100644 --- a/src/cs50/flask.py +++ b/src/cs50/flask.py @@ -14,20 +14,21 @@ def _wrap_flask(f): f.logging.default_handler.formatter.formatException = lambda exc_info: _formatException(*exc_info) - if os.getenv("CS50_IDE_TYPE") == "online": + if os.getenv("CS50_IDE_TYPE"): from werkzeug.middleware.proxy_fix import ProxyFix _flask_init_before = f.Flask.__init__ def _flask_init_after(self, *args, **kwargs): _flask_init_before(self, *args, **kwargs) - self.wsgi_app = ProxyFix(self.wsgi_app, x_proto=1) + self.config["TEMPLATES_AUTO_RELOAD"] = True # Automatically reload templates + self.wsgi_app = ProxyFix(self.wsgi_app, x_proto=1) # For HTTPS-to-HTTP proxy f.Flask.__init__ = _flask_init_after -# Flask was imported before cs50 +# If Flask was imported before cs50 if "flask" in sys.modules: _wrap_flask(sys.modules["flask"]) -# Flask wasn't imported +# If Flask wasn't imported else: flask_loader = pkgutil.get_loader('flask') if flask_loader: diff --git a/src/cs50/sql.py b/src/cs50/sql.py index e6e3f06..b3f25ce 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -333,10 +333,9 @@ def shutdown_session(exception=None): # Execute statement if self._autocommit: connection.execute(sqlalchemy.text("BEGIN")) - result = connection.execute(sqlalchemy.text(statement)) + result = connection.execute(sqlalchemy.text(statement)) + if self._autocommit: connection.execute(sqlalchemy.text("COMMIT")) - else: - result = connection.execute(sqlalchemy.text(statement)) # Check for end of transaction if command in ["COMMIT", "ROLLBACK"]: From c68324df22850516aa1a54507052505072df90d7 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Mon, 23 Nov 2020 13:16:06 -0500 Subject: [PATCH 013/159] fixed README --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 0780948..7705fd6 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ s = cs50.get_string(); 1. Run `sudo su -`. 1. Run `apt install -y libmysqlclient-dev mysql-server postgresql`. 1. Run `pip3 install mysqlclient psycopg2-binary`. -1. In `/etc/mysql/mysql.conf.d/mysqld.conf`, add `skip-grant-tables` under `[mysqld]`. +1. In `/etc/mysql/mysql.conf.d/mysqld.cnf`, add `skip-grant-tables` under `[mysqld]`. 1. In `/etc/profile.d/cli.sh`, remove `valgrind` function for now. 1. Run `service mysql start`. 1. Run `mysql -e 'CREATE DATABASE IF NOT EXISTS test;'`. @@ -41,7 +41,7 @@ s = cs50.get_string(); host all all 127.0.0.1/32 trust ``` 1. Run `service postgresql start`. -1. Run `psql -c 'create database test;' -U postgres. +1. Run `psql -c 'create database test;' -U postgres`. 1. Run `touch test.db`. ### Sample Tests From 083bf5e887eeb06790d0ed0049ac55b3f83251a3 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Mon, 23 Nov 2020 13:41:40 -0500 Subject: [PATCH 014/159] corrected syntax in README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 7705fd6..85d5172 100644 --- a/README.md +++ b/README.md @@ -79,7 +79,7 @@ db.execute("COMMIT") import cs50 db = cs50.SQL("postgresql://postgres@localhost/test") -db.execute("CREATE TABLE IF NOT EXISTS cs50 (id INTEGER PRIMARY KEY, val TEXT, bin BLOB)") +db.execute("CREATE TABLE IF NOT EXISTS cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)") db.execute("INSERT INTO cs50 (val) VALUES('a')") db.execute("INSERT INTO cs50 (val) VALUES('b')") db.execute("BEGIN") From a11c9b9912c039d4c5a0a1a6622fa6a80be153cc Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Mon, 23 Nov 2020 13:41:56 -0500 Subject: [PATCH 015/159] detecting PostgreSQL syntax errors --- src/cs50/sql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index b3f25ce..38b60db 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -388,7 +388,7 @@ def shutdown_session(exception=None): raise e # If user errror - except sqlalchemy.exc.OperationalError as e: + except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as e: self._disconnect() self._logger.debug(termcolor.colored(statement, "red")) e = RuntimeError(e.orig) From 727df506e0c1731c71729515459deecc96d440e7 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Mon, 23 Nov 2020 14:46:25 -0500 Subject: [PATCH 016/159] only calling teardown_appcontext once --- src/cs50/sql.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 38b60db..d92ede3 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -281,7 +281,7 @@ def execute(self, sql, *args, **kwargs): # Join tokens into statement statement = "".join([str(token) for token in tokens]) - # Connect to database (for transactions' sake) + # Connect to database try: # Infer whether Flask is installed @@ -290,19 +290,23 @@ def execute(self, sql, *args, **kwargs): # Infer whether app is defined assert flask.current_app - # If no connection for app's current request yet + # If new context if not hasattr(flask.g, "_connection"): - # Connect now - flask.g._connection = self._engine.connect() + # Ready to connect + flask.g._connection = None # Disconnect later @flask.current_app.teardown_appcontext def shutdown_session(exception=None): - if hasattr(flask.g, "_connection"): + if flask.g._connection: flask.g._connection.close() - # Use this connection + # If no connection for context yet + if not flask.g._connection: + flas.g._connection = self._engine.connect() + + # Use context's connection connection = flask.g._connection except (ModuleNotFoundError, AssertionError): From 31305e173a25366e45468af1f8767547f50e7e31 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Mon, 23 Nov 2020 15:03:09 -0500 Subject: [PATCH 017/159] raising ValueError, no longer disconnecting on IntegrityError --- src/cs50/sql.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index d92ede3..f6da366 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -385,13 +385,12 @@ def shutdown_session(exception=None): # If constraint violated, return None except sqlalchemy.exc.IntegrityError as e: - self._disconnect() self._logger.debug(termcolor.colored(statement, "yellow")) - e = RuntimeError(e.orig) + e = ValueError(e.orig) e.__cause__ = None raise e - # If user errror + # If user error except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as e: self._disconnect() self._logger.debug(termcolor.colored(statement, "red")) From e8cca8d28621bd95b2a67ebd0a0c5c4d5b714b53 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Mon, 23 Nov 2020 15:08:23 -0500 Subject: [PATCH 018/159] updated tests for ValueError --- tests/sql.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/sql.py b/tests/sql.py index fbfae61..31b927d 100644 --- a/tests/sql.py +++ b/tests/sql.py @@ -187,12 +187,12 @@ def test_lastrowid(self): def test_integrity_constraints(self): self.db.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY)") self.assertEqual(self.db.execute("INSERT INTO foo VALUES(1)"), 1) - self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES(1)") + self.assertRaises(ValueError, self.db.execute, "INSERT INTO foo VALUES(1)") def test_foreign_key_support(self): self.db.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY)") self.db.execute("CREATE TABLE bar(foo_id INTEGER, FOREIGN KEY (foo_id) REFERENCES foo(id))") - self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO bar VALUES(50)") + self.assertRaises(ValueError, self.db.execute, "INSERT INTO bar VALUES(50)") def test_qmark(self): self.db.execute("CREATE TABLE foo (firstname STRING, lastname STRING)") From 16e52bf3f015d212e090daafa020d641f32aa35d Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Mon, 23 Nov 2020 15:09:06 -0500 Subject: [PATCH 019/159] updated test for ValueError --- tests/sql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/sql.py b/tests/sql.py index 31b927d..f893895 100644 --- a/tests/sql.py +++ b/tests/sql.py @@ -181,7 +181,7 @@ def setUp(self): def test_lastrowid(self): self.db.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY AUTOINCREMENT, firstname TEXT, lastname TEXT)") self.assertEqual(self.db.execute("INSERT INTO foo (firstname, lastname) VALUES('firstname', 'lastname')"), 1) - self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo (id, firstname, lastname) VALUES(1, 'firstname', 'lastname')") + self.assertRaises(ValueError, self.db.execute, "INSERT INTO foo (id, firstname, lastname) VALUES(1, 'firstname', 'lastname')") self.assertEqual(self.db.execute("INSERT OR IGNORE INTO foo (id, firstname, lastname) VALUES(1, 'firstname', 'lastname')"), None) def test_integrity_constraints(self): From 5857d3f10753a726b5085e6754de934f4d610b27 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Mon, 23 Nov 2020 15:25:23 -0500 Subject: [PATCH 020/159] removed support for unpacking dicts and lists in execute --- src/cs50/sql.py | 14 +------------- tests/sql.py | 21 --------------------- 2 files changed, 1 insertion(+), 34 deletions(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index f6da366..148d57f 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -127,7 +127,7 @@ def execute(self, sql, *args, **kwargs): # Ensure named and positional parameters are mutually exclusive if len(args) > 0 and len(kwargs) > 0: - raise RuntimeError("cannot pass both named and positional parameters") + raise RuntimeError("cannot pass both positional and named parameters") # Infer command from (unflattened) statement for token in statements[0]: @@ -163,18 +163,6 @@ def execute(self, sql, *args, **kwargs): # Remember placeholder's index, name placeholders[index] = name - # If more placeholders than arguments - if len(args) == 1 and len(placeholders) > 1: - - # If user passed args as list or tuple, explode values into args - if isinstance(args[0], (list, tuple)): - args = args[0] - - # If user passed kwargs as dict, migrate values from args to kwargs - elif len(kwargs) == 0 and isinstance(args[0], dict): - kwargs = args[0] - args = [] - # If no placeholders if not paramstyle: diff --git a/tests/sql.py b/tests/sql.py index f893895..8cafdde 100644 --- a/tests/sql.py +++ b/tests/sql.py @@ -211,26 +211,16 @@ def test_qmark(self): self.db.execute("INSERT INTO foo VALUES ('qux', 'quux')") self.assertEqual(self.db.execute("SELECT * FROM foo WHERE firstname = ?", 'qux'), [{"firstname": "qux", "lastname": "quux"}]) self.assertEqual(self.db.execute("SELECT * FROM foo WHERE firstname = ? AND lastname = ?", "qux", "quux"), [{"firstname": "qux", "lastname": "quux"}]) - self.assertEqual(self.db.execute("SELECT * FROM foo WHERE firstname = ? AND lastname = ?", ("qux", "quux")), [{"firstname": "qux", "lastname": "quux"}]) - self.assertEqual(self.db.execute("SELECT * FROM foo WHERE firstname = ? AND lastname = ?", ["qux", "quux"]), [{"firstname": "qux", "lastname": "quux"}]) self.db.execute("DELETE FROM foo") self.db.execute("INSERT INTO foo VALUES (?)", ("bar", "baz")) self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) self.db.execute("DELETE FROM foo") - self.db.execute("INSERT INTO foo VALUES (?, ?)", ("bar", "baz")) - self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) - self.db.execute("DELETE FROM foo") - self.db.execute("INSERT INTO foo VALUES (?)", ["bar", "baz"]) self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) self.db.execute("DELETE FROM foo") - self.db.execute("INSERT INTO foo VALUES (?, ?)", ["bar", "baz"]) - self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) - self.db.execute("DELETE FROM foo") - self.db.execute("INSERT INTO foo VALUES (?,?)", "bar", "baz") self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) self.db.execute("DELETE FROM foo") @@ -308,19 +298,8 @@ def test_numeric(self): self.db.execute("INSERT INTO foo VALUES ('qux', 'quux')") self.assertEqual(self.db.execute("SELECT * FROM foo WHERE firstname = :1", 'qux'), [{"firstname": "qux", "lastname": "quux"}]) self.assertEqual(self.db.execute("SELECT * FROM foo WHERE firstname = :1 AND lastname = :2", "qux", "quux"), [{"firstname": "qux", "lastname": "quux"}]) - self.assertEqual(self.db.execute("SELECT * FROM foo WHERE firstname = :1 AND lastname = :2", ("qux", "quux")), [{"firstname": "qux", "lastname": "quux"}]) - self.assertEqual(self.db.execute("SELECT * FROM foo WHERE firstname = :1 AND lastname = :2", ["qux", "quux"]), [{"firstname": "qux", "lastname": "quux"}]) self.db.execute("DELETE FROM foo") - self.db.execute("INSERT INTO foo VALUES (:1, :2)", ("bar", "baz")) - self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) - self.db.execute("DELETE FROM foo") - - self.db.execute("INSERT INTO foo VALUES (:1, :2)", ["bar", "baz"]) - self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) - self.db.execute("DELETE FROM foo") - - self.db.execute("INSERT INTO foo VALUES (:1,:2)", "bar", "baz") self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) self.db.execute("DELETE FROM foo") From 0a63c892f47ef677376e623f623a71f5037b39ba Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Mon, 23 Nov 2020 15:26:13 -0500 Subject: [PATCH 021/159] version++ --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 7108b87..db44134 100644 --- a/setup.py +++ b/setup.py @@ -16,5 +16,5 @@ package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="5.0.5" + version="6.0.0" ) From ead20a7ac93dd3f58bca7e60c83a0840ce14fa2f Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Mon, 23 Nov 2020 15:36:19 -0500 Subject: [PATCH 022/159] removed TEMPLATES_AUTO_RELOAD --- src/cs50/flask.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/cs50/flask.py b/src/cs50/flask.py index 538d32a..23e1b0a 100644 --- a/src/cs50/flask.py +++ b/src/cs50/flask.py @@ -14,12 +14,11 @@ def _wrap_flask(f): f.logging.default_handler.formatter.formatException = lambda exc_info: _formatException(*exc_info) - if os.getenv("CS50_IDE_TYPE"): + if os.getenv("CS50_IDE_TYPE") == "online": from werkzeug.middleware.proxy_fix import ProxyFix _flask_init_before = f.Flask.__init__ def _flask_init_after(self, *args, **kwargs): _flask_init_before(self, *args, **kwargs) - self.config["TEMPLATES_AUTO_RELOAD"] = True # Automatically reload templates self.wsgi_app = ProxyFix(self.wsgi_app, x_proto=1) # For HTTPS-to-HTTP proxy f.Flask.__init__ = _flask_init_after From 87394fd946c226dfdfdd443535edc8c295303536 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Tue, 24 Nov 2020 10:27:56 -0500 Subject: [PATCH 023/159] fix typo --- src/cs50/sql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 148d57f..fd66b9e 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -292,7 +292,7 @@ def shutdown_session(exception=None): # If no connection for context yet if not flask.g._connection: - flas.g._connection = self._engine.connect() + flask.g._connection = self._engine.connect() # Use context's connection connection = flask.g._connection From a80dfeb0b58bb95f576cdce6d93aa7852dd4ab91 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Wed, 25 Nov 2020 09:31:36 -0500 Subject: [PATCH 024/159] configure root logger --- src/cs50/cs50.py | 8 ++++++++ src/cs50/flask.py | 2 -- src/cs50/sql.py | 45 +++++++-------------------------------------- 3 files changed, 15 insertions(+), 40 deletions(-) diff --git a/src/cs50/cs50.py b/src/cs50/cs50.py index 6835a62..54adcd1 100644 --- a/src/cs50/cs50.py +++ b/src/cs50/cs50.py @@ -1,6 +1,7 @@ from __future__ import print_function import inspect +import logging import os import re import sys @@ -11,6 +12,13 @@ from traceback import format_exception +# Configure default logging handler and formatter +logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.DEBUG) + +# Patch formatException +logging.root.handlers[0].formatter.formatException = lambda exc_info: _formatException(*exc_info) + + class _flushfile(): """ Disable buffering for standard output and standard error. diff --git a/src/cs50/flask.py b/src/cs50/flask.py index 23e1b0a..324ec30 100644 --- a/src/cs50/flask.py +++ b/src/cs50/flask.py @@ -12,8 +12,6 @@ def _wrap_flask(f): if f.__version__ < StrictVersion("1.0"): return - f.logging.default_handler.formatter.formatException = lambda exc_info: _formatException(*exc_info) - if os.getenv("CS50_IDE_TYPE") == "online": from werkzeug.middleware.proxy_fix import ProxyFix _flask_init_before = f.Flask.__init__ diff --git a/src/cs50/sql.py b/src/cs50/sql.py index fd66b9e..3a26774 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -1,28 +1,4 @@ -def _enable_logging(f): - """Enable logging of SQL statements when Flask is in use.""" - - import logging - import functools - - @functools.wraps(f) - def decorator(*args, **kwargs): - - # Infer whether Flask is installed - try: - import flask - except ModuleNotFoundError: - return f(*args, **kwargs) - - # Enable logging - disabled = logging.getLogger("cs50").disabled - if flask.current_app: - logging.getLogger("cs50").disabled = False - try: - return f(*args, **kwargs) - finally: - logging.getLogger("cs50").disabled = disabled - - return decorator +import logging class SQL(object): @@ -45,9 +21,6 @@ def __init__(self, url, **kwargs): import sqlalchemy import sqlite3 - # Get logger - self._logger = logging.getLogger("cs50") - # Require that file already exist for SQLite matches = re.search(r"^sqlite:///(.+)$", url) if matches: @@ -78,20 +51,17 @@ def connect(dbapi_connection, connection_record): # Register listener sqlalchemy.event.listen(self._engine, "connect", connect) - # Log statements to standard error - logging.basicConfig(level=logging.DEBUG) - # Test database + disabled = logging.root.disabled + logging.root.disabled = True try: - disabled = self._logger.disabled - self._logger.disabled = True self.execute("SELECT 1") except sqlalchemy.exc.OperationalError as e: e = RuntimeError(_parse_exception(e)) e.__cause__ = None raise e finally: - self._logger.disabled = disabled + logging.root.disabled = disabled def __del__(self): """Disconnect from database.""" @@ -103,7 +73,6 @@ def _disconnect(self): self._connection.close() delattr(self, "_connection") - @_enable_logging def execute(self, sql, *args, **kwargs): """Execute a SQL statement.""" @@ -373,7 +342,7 @@ def shutdown_session(exception=None): # If constraint violated, return None except sqlalchemy.exc.IntegrityError as e: - self._logger.debug(termcolor.colored(statement, "yellow")) + logging.debug(termcolor.colored(statement, "yellow")) e = ValueError(e.orig) e.__cause__ = None raise e @@ -381,14 +350,14 @@ def shutdown_session(exception=None): # If user error except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as e: self._disconnect() - self._logger.debug(termcolor.colored(statement, "red")) + logging.debug(termcolor.colored(statement, "red")) e = RuntimeError(e.orig) e.__cause__ = None raise e # Return value else: - self._logger.debug(termcolor.colored(_statement, "green")) + logging.debug(termcolor.colored(_statement, "green")) return ret def _escape(self, value): From f3750c71d9a81aa00ae1c76d2aa1941bbb264e92 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Wed, 25 Nov 2020 14:45:20 -0500 Subject: [PATCH 025/159] handle no root logger handlers --- src/cs50/cs50.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/cs50/cs50.py b/src/cs50/cs50.py index 54adcd1..bde1ec7 100644 --- a/src/cs50/cs50.py +++ b/src/cs50/cs50.py @@ -15,8 +15,11 @@ # Configure default logging handler and formatter logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.DEBUG) -# Patch formatException -logging.root.handlers[0].formatter.formatException = lambda exc_info: _formatException(*exc_info) +try: + # Patch formatException + logging.root.handlers[0].formatter.formatException = lambda exc_info: _formatException(*exc_info) +except IndexError: + pass class _flushfile(): From db0d14a5d2f3d4b05e2037e911ebe93bcb6d06f7 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Wed, 25 Nov 2020 15:15:43 -0500 Subject: [PATCH 026/159] use separate logger for cs50.sql --- src/cs50/cs50.py | 1 + src/cs50/sql.py | 32 +++++++++++++++++++++++--------- 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/src/cs50/cs50.py b/src/cs50/cs50.py index bde1ec7..77be18b 100644 --- a/src/cs50/cs50.py +++ b/src/cs50/cs50.py @@ -13,6 +13,7 @@ # Configure default logging handler and formatter +# Prevent flask, werkzeug, etc from adding default handler logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.DEBUG) try: diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 3a26774..1d64088 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -1,6 +1,3 @@ -import logging - - class SQL(object): """Wrap SQLAlchemy to provide a simple SQL API.""" @@ -21,6 +18,8 @@ def __init__(self, url, **kwargs): import sqlalchemy import sqlite3 + from .cs50 import _formatException + # Require that file already exist for SQLite matches = re.search(r"^sqlite:///(.+)$", url) if matches: @@ -51,9 +50,24 @@ def connect(dbapi_connection, connection_record): # Register listener sqlalchemy.event.listen(self._engine, "connect", connect) + # Configure logger + self._logger = logging.getLogger(__name__) + self._logger.setLevel(logging.DEBUG) + + # Log messages once + self._logger.propagate = False + + ch = logging.StreamHandler() + ch.setLevel(logging.DEBUG) + + formatter = logging.Formatter("%(levelname)s: %(message)s") + formatter.formatException = lambda exc_info: _formatException(*exc_info) + ch.setFormatter(formatter) + self._logger.addHandler(ch) + # Test database - disabled = logging.root.disabled - logging.root.disabled = True + disabled = self._logger.disabled + self._logger.disabled = True try: self.execute("SELECT 1") except sqlalchemy.exc.OperationalError as e: @@ -61,7 +75,7 @@ def connect(dbapi_connection, connection_record): e.__cause__ = None raise e finally: - logging.root.disabled = disabled + self._logger.disabled = disabled def __del__(self): """Disconnect from database.""" @@ -342,7 +356,7 @@ def shutdown_session(exception=None): # If constraint violated, return None except sqlalchemy.exc.IntegrityError as e: - logging.debug(termcolor.colored(statement, "yellow")) + self._logger.debug(termcolor.colored(statement, "yellow")) e = ValueError(e.orig) e.__cause__ = None raise e @@ -350,14 +364,14 @@ def shutdown_session(exception=None): # If user error except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as e: self._disconnect() - logging.debug(termcolor.colored(statement, "red")) + self._logger.debug(termcolor.colored(statement, "red")) e = RuntimeError(e.orig) e.__cause__ = None raise e # Return value else: - logging.debug(termcolor.colored(_statement, "green")) + self._logger.debug(termcolor.colored(_statement, "green")) return ret def _escape(self, value): From aea0dfe847e8d9aad8b6fff40f604fdca7e4de70 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Wed, 25 Nov 2020 15:25:06 -0500 Subject: [PATCH 027/159] move logger to module level --- src/cs50/sql.py | 50 ++++++++++++++++++++++++++----------------------- 1 file changed, 27 insertions(+), 23 deletions(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 1d64088..bb7cf0f 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -1,3 +1,24 @@ +import logging + +from .cs50 import _formatException + + +# Configure logger +_logger = logging.getLogger(__name__) +_logger.setLevel(logging.DEBUG) + +# Log messages once +_logger.propagate = False + +ch = logging.StreamHandler() +ch.setLevel(logging.DEBUG) + +formatter = logging.Formatter("%(levelname)s: %(message)s") +formatter.formatException = lambda exc_info: _formatException(*exc_info) +ch.setFormatter(formatter) +_logger.addHandler(ch) + + class SQL(object): """Wrap SQLAlchemy to provide a simple SQL API.""" @@ -12,14 +33,11 @@ def __init__(self, url, **kwargs): """ # Lazily import - import logging import os import re import sqlalchemy import sqlite3 - from .cs50 import _formatException - # Require that file already exist for SQLite matches = re.search(r"^sqlite:///(.+)$", url) if matches: @@ -50,24 +68,10 @@ def connect(dbapi_connection, connection_record): # Register listener sqlalchemy.event.listen(self._engine, "connect", connect) - # Configure logger - self._logger = logging.getLogger(__name__) - self._logger.setLevel(logging.DEBUG) - - # Log messages once - self._logger.propagate = False - - ch = logging.StreamHandler() - ch.setLevel(logging.DEBUG) - - formatter = logging.Formatter("%(levelname)s: %(message)s") - formatter.formatException = lambda exc_info: _formatException(*exc_info) - ch.setFormatter(formatter) - self._logger.addHandler(ch) # Test database - disabled = self._logger.disabled - self._logger.disabled = True + disabled = _logger.disabled + _logger.disabled = True try: self.execute("SELECT 1") except sqlalchemy.exc.OperationalError as e: @@ -75,7 +79,7 @@ def connect(dbapi_connection, connection_record): e.__cause__ = None raise e finally: - self._logger.disabled = disabled + _logger.disabled = disabled def __del__(self): """Disconnect from database.""" @@ -356,7 +360,7 @@ def shutdown_session(exception=None): # If constraint violated, return None except sqlalchemy.exc.IntegrityError as e: - self._logger.debug(termcolor.colored(statement, "yellow")) + _logger.debug(termcolor.colored(statement, "yellow")) e = ValueError(e.orig) e.__cause__ = None raise e @@ -364,14 +368,14 @@ def shutdown_session(exception=None): # If user error except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as e: self._disconnect() - self._logger.debug(termcolor.colored(statement, "red")) + _logger.debug(termcolor.colored(statement, "red")) e = RuntimeError(e.orig) e.__cause__ = None raise e # Return value else: - self._logger.debug(termcolor.colored(_statement, "green")) + _logger.debug(termcolor.colored(_statement, "green")) return ret def _escape(self, value): From 74f03f089e96f60dccaf0a3648d30f5a1eed9cf0 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Wed, 25 Nov 2020 15:27:30 -0500 Subject: [PATCH 028/159] s/ch/handler/ --- src/cs50/sql.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index bb7cf0f..8bd72cf 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -10,13 +10,13 @@ # Log messages once _logger.propagate = False -ch = logging.StreamHandler() -ch.setLevel(logging.DEBUG) +handler = logging.StreamHandler() +handler.setLevel(logging.DEBUG) formatter = logging.Formatter("%(levelname)s: %(message)s") formatter.formatException = lambda exc_info: _formatException(*exc_info) -ch.setFormatter(formatter) -_logger.addHandler(ch) +handler.setFormatter(formatter) +_logger.addHandler(handler) class SQL(object): From 9cd038cf95ff6fb040ed411597aa36debfb5cb32 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Wed, 25 Nov 2020 15:46:23 -0500 Subject: [PATCH 029/159] rename logger to cs50 --- src/cs50/sql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 8bd72cf..6f26eb5 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -4,7 +4,7 @@ # Configure logger -_logger = logging.getLogger(__name__) +_logger = logging.getLogger("cs50") _logger.setLevel(logging.DEBUG) # Log messages once From f2ba33009eda856d80730aa1eb0fcf1eb8f82874 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Wed, 25 Nov 2020 15:55:24 -0500 Subject: [PATCH 030/159] configure cs50 logger in cs50 module --- src/cs50/cs50.py | 15 +++++++++++++++ src/cs50/sql.py | 36 +++++++++--------------------------- 2 files changed, 24 insertions(+), 27 deletions(-) diff --git a/src/cs50/cs50.py b/src/cs50/cs50.py index 77be18b..0fc481b 100644 --- a/src/cs50/cs50.py +++ b/src/cs50/cs50.py @@ -22,6 +22,21 @@ except IndexError: pass +# Configure cs50 logger +_logger = logging.getLogger("cs50") +_logger.setLevel(logging.DEBUG) + +# Log messages once +_logger.propagate = False + +handler = logging.StreamHandler() +handler.setLevel(logging.DEBUG) + +formatter = logging.Formatter("%(levelname)s: %(message)s") +formatter.formatException = lambda exc_info: _formatException(*exc_info) +handler.setFormatter(formatter) +_logger.addHandler(handler) + class _flushfile(): """ diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 6f26eb5..71494d3 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -1,24 +1,3 @@ -import logging - -from .cs50 import _formatException - - -# Configure logger -_logger = logging.getLogger("cs50") -_logger.setLevel(logging.DEBUG) - -# Log messages once -_logger.propagate = False - -handler = logging.StreamHandler() -handler.setLevel(logging.DEBUG) - -formatter = logging.Formatter("%(levelname)s: %(message)s") -formatter.formatException = lambda exc_info: _formatException(*exc_info) -handler.setFormatter(formatter) -_logger.addHandler(handler) - - class SQL(object): """Wrap SQLAlchemy to provide a simple SQL API.""" @@ -33,6 +12,7 @@ def __init__(self, url, **kwargs): """ # Lazily import + import logging import os import re import sqlalchemy @@ -49,6 +29,8 @@ def __init__(self, url, **kwargs): # Create engine, disabling SQLAlchemy's own autocommit mode, raising exception if back end's module not installed self._engine = sqlalchemy.create_engine(url, **kwargs).execution_options(autocommit=False) + self._logger = logging.getLogger("cs50") + # Listener for connections def connect(dbapi_connection, connection_record): @@ -70,8 +52,8 @@ def connect(dbapi_connection, connection_record): # Test database - disabled = _logger.disabled - _logger.disabled = True + disabled = self._logger.disabled + self._logger.disabled = True try: self.execute("SELECT 1") except sqlalchemy.exc.OperationalError as e: @@ -79,7 +61,7 @@ def connect(dbapi_connection, connection_record): e.__cause__ = None raise e finally: - _logger.disabled = disabled + self._logger.disabled = disabled def __del__(self): """Disconnect from database.""" @@ -360,7 +342,7 @@ def shutdown_session(exception=None): # If constraint violated, return None except sqlalchemy.exc.IntegrityError as e: - _logger.debug(termcolor.colored(statement, "yellow")) + self._logger.debug(termcolor.colored(statement, "yellow")) e = ValueError(e.orig) e.__cause__ = None raise e @@ -368,14 +350,14 @@ def shutdown_session(exception=None): # If user error except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as e: self._disconnect() - _logger.debug(termcolor.colored(statement, "red")) + self._logger.debug(termcolor.colored(statement, "red")) e = RuntimeError(e.orig) e.__cause__ = None raise e # Return value else: - _logger.debug(termcolor.colored(_statement, "green")) + self._logger.debug(termcolor.colored(_statement, "green")) return ret def _escape(self, value): From 7dd19ad0fad5b15892604be01fe2e86a1c1dcf56 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Wed, 25 Nov 2020 16:08:43 -0500 Subject: [PATCH 031/159] forcibly enable cs50 logger in flask apps --- src/cs50/sql.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 71494d3..3e695d6 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -1,3 +1,30 @@ +def _enable_logging(f): + """Enable logging of SQL statements when Flask is in use.""" + + import logging + import functools + + @functools.wraps(f) + def decorator(*args, **kwargs): + + # Infer whether Flask is installed + try: + import flask + except ModuleNotFoundError: + return f(*args, **kwargs) + + # Enable logging + disabled = logging.getLogger("cs50").disabled + if flask.current_app: + logging.getLogger("cs50").disabled = False + try: + return f(*args, **kwargs) + finally: + logging.getLogger("cs50").disabled = disabled + + return decorator + + class SQL(object): """Wrap SQLAlchemy to provide a simple SQL API.""" @@ -73,6 +100,7 @@ def _disconnect(self): self._connection.close() delattr(self, "_connection") + @_enable_logging def execute(self, sql, *args, **kwargs): """Execute a SQL statement.""" From 5a575e7c985a48072a6262c0a6674028a9b68da9 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Sat, 28 Nov 2020 13:38:41 -0500 Subject: [PATCH 032/159] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 85d5172..0fb6d64 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ ## Installation ``` -pip install cs50 +pip3 install cs50 ``` ## Usage From ddeba083dbcd38ea9a9f7e256980c3dce01d1008 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Sat, 28 Nov 2020 18:15:22 -0500 Subject: [PATCH 033/159] updated Travis [skip ci] --- .travis.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index b13ffd6..0433f6a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -20,11 +20,11 @@ deploy: \"target_commitish\": \"$TRAVIS_COMMIT\", \"name\": \"v$(python setup.py --version)\" }" --user bot50:$GITHUB_TOKEN https://api.github.com/repos/$TRAVIS_REPO_SLUG/releases' on: - branch: master + branch: main - provider: pypi user: "$PYPI_USERNAME" password: "$PYPI_PASSWORD" - on: master + on: main notifications: slack: secure: lJklhcBVjDT6KzUNa3RFHXdXSeH7ytuuGrkZ5ZcR72CXMoTf2pMJTzPwRLWOp6lCSdDC9Y8MWLrcg/e33dJga4Jlp9alOmWqeqesaFjfee4st8vAsgNbv8/RajPH1gD2bnkt8oIwUzdHItdb5AucKFYjbH2g0d8ndoqYqUeBLrnsT1AP5G/Vi9OHC9OWNpR0FKaZIJE0Wt52vkPMH3sV2mFeIskByPB+56U5y547mualKxn61IVR/dhYBEtZQJuSvnwKHPOn9Pkk7cCa+SSSeTJ4w5LboY8T17otaYNauXo46i1bKIoGiBcCcrJyQHHiPQmcq/YU540MC5Wzt9YXUycmJzRi347oyQeDee27wV3XJlWMXuuhbtJiKCFny7BTQ160VATlj/dbwIzN99Ra6/BtTumv/6LyTdKIuVjdAkcN8dtdDW1nlrQ29zuPNCcXXzJ7zX7kQaOCUV1c2OrsbiH/0fE9nknUORn97txqhlYVi0QMS7764wFo6kg0vpmFQRkkQySsJl+TmgcZ01AlsJc2EMMWVuaj9Af9JU4/4yalqDiXIh1fOYYUZnLfOfWS+MsnI+/oLfqJFyMbrsQQTIjs+kTzbiEdhd2R4EZgusU/xRFWokS2NAvahexrRhRQ6tpAI+LezPrkNOR3aHiykBf+P9BkUa0wPp6V2Ayc6q0= From 8e210de466e81515f0099b6769b460796faf496a Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Sun, 29 Nov 2020 12:06:38 -0500 Subject: [PATCH 034/159] fixing teardown_appcontext --- src/cs50/sql.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 3e695d6..d03faee 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -276,21 +276,16 @@ def execute(self, sql, *args, **kwargs): assert flask.current_app # If new context - if not hasattr(flask.g, "_connection"): + # https://flask.palletsprojects.com/en/1.1.x/appcontext/#storing-data + if "_connection" not in flask.g: - # Ready to connect - flask.g._connection = None - - # Disconnect later - @flask.current_app.teardown_appcontext - def shutdown_session(exception=None): - if flask.g._connection: - flask.g._connection.close() - - # If no connection for context yet - if not flask.g._connection: + # Connect to database flask.g._connection = self._engine.connect() + # Disconnect from database later + if _teardown_appcontext not in flask.current_app.teardown_appcontext_funcs: + flask.current_app.teardown_appcontext(_teardown_appcontext) + # Use context's connection connection = flask.g._connection @@ -533,3 +528,11 @@ def _parse_placeholder(token): # Invalid raise RuntimeError("{}: invalid placeholder".format(token.value)) + + +def _teardown_appcontext(exception=None): + """Closes context's database connection, if any.""" + import flask + connection = flask.g.pop("_connection", None) + if connection: + connection.close() From 2b9e9a6fccd4ae1bc9c37e7d09c60ac1c184a083 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Sun, 29 Nov 2020 12:34:43 -0500 Subject: [PATCH 035/159] added support for multiple DB connections --- src/cs50/sql.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index d03faee..4b69a50 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -275,19 +275,24 @@ def execute(self, sql, *args, **kwargs): # Infer whether app is defined assert flask.current_app - # If new context + # If no connections to any databases yet + if not hasattr(flask.g, "_connections"): + setattr(flask.g, "_connections", {}) + connections = getattr(flask.g, "_connections") + + # If not yet connected to this database # https://flask.palletsprojects.com/en/1.1.x/appcontext/#storing-data - if "_connection" not in flask.g: + if id(self) not in connections: # Connect to database - flask.g._connection = self._engine.connect() + connections[id(self)] = self._engine.connect() # Disconnect from database later if _teardown_appcontext not in flask.current_app.teardown_appcontext_funcs: flask.current_app.teardown_appcontext(_teardown_appcontext) - # Use context's connection - connection = flask.g._connection + # Use this connection + connection = connections[id(self)] except (ModuleNotFoundError, AssertionError): @@ -533,6 +538,5 @@ def _parse_placeholder(token): def _teardown_appcontext(exception=None): """Closes context's database connection, if any.""" import flask - connection = flask.g.pop("_connection", None) - if connection: + for connection in flask.g.pop("_connections", {}).values(): connection.close() From 2d5b4462251e1f8a3d9c6e50a93dc9a2ea37417f Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Sun, 29 Nov 2020 12:40:53 -0500 Subject: [PATCH 036/159] removed id(), added comments --- src/cs50/sql.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 4b69a50..e405dc9 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -56,6 +56,7 @@ def __init__(self, url, **kwargs): # Create engine, disabling SQLAlchemy's own autocommit mode, raising exception if back end's module not installed self._engine = sqlalchemy.create_engine(url, **kwargs).execution_options(autocommit=False) + # Get logger self._logger = logging.getLogger("cs50") # Listener for connections @@ -77,7 +78,6 @@ def connect(dbapi_connection, connection_record): # Register listener sqlalchemy.event.listen(self._engine, "connect", connect) - # Test database disabled = self._logger.disabled self._logger.disabled = True @@ -282,17 +282,17 @@ def execute(self, sql, *args, **kwargs): # If not yet connected to this database # https://flask.palletsprojects.com/en/1.1.x/appcontext/#storing-data - if id(self) not in connections: + if self not in connections: # Connect to database - connections[id(self)] = self._engine.connect() + connections[self] = self._engine.connect() # Disconnect from database later if _teardown_appcontext not in flask.current_app.teardown_appcontext_funcs: flask.current_app.teardown_appcontext(_teardown_appcontext) # Use this connection - connection = connections[id(self)] + connection = connections[self] except (ModuleNotFoundError, AssertionError): From f50dfd7003ca7006497787d13e9f049088997180 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Sun, 29 Nov 2020 12:51:18 -0500 Subject: [PATCH 037/159] version++ --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index db44134..c8a5f5b 100644 --- a/setup.py +++ b/setup.py @@ -16,5 +16,5 @@ package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="6.0.0" + version="6.0.1" ) From 9612a28ed90f5278806c676f19f298457b5e3c6e Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Sun, 29 Nov 2020 12:53:52 -0500 Subject: [PATCH 038/159] updated Travis for main --- .travis.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index b13ffd6..0433f6a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -20,11 +20,11 @@ deploy: \"target_commitish\": \"$TRAVIS_COMMIT\", \"name\": \"v$(python setup.py --version)\" }" --user bot50:$GITHUB_TOKEN https://api.github.com/repos/$TRAVIS_REPO_SLUG/releases' on: - branch: master + branch: main - provider: pypi user: "$PYPI_USERNAME" password: "$PYPI_PASSWORD" - on: master + on: main notifications: slack: secure: lJklhcBVjDT6KzUNa3RFHXdXSeH7ytuuGrkZ5ZcR72CXMoTf2pMJTzPwRLWOp6lCSdDC9Y8MWLrcg/e33dJga4Jlp9alOmWqeqesaFjfee4st8vAsgNbv8/RajPH1gD2bnkt8oIwUzdHItdb5AucKFYjbH2g0d8ndoqYqUeBLrnsT1AP5G/Vi9OHC9OWNpR0FKaZIJE0Wt52vkPMH3sV2mFeIskByPB+56U5y547mualKxn61IVR/dhYBEtZQJuSvnwKHPOn9Pkk7cCa+SSSeTJ4w5LboY8T17otaYNauXo46i1bKIoGiBcCcrJyQHHiPQmcq/YU540MC5Wzt9YXUycmJzRi347oyQeDee27wV3XJlWMXuuhbtJiKCFny7BTQ160VATlj/dbwIzN99Ra6/BtTumv/6LyTdKIuVjdAkcN8dtdDW1nlrQ29zuPNCcXXzJ7zX7kQaOCUV1c2OrsbiH/0fE9nknUORn97txqhlYVi0QMS7764wFo6kg0vpmFQRkkQySsJl+TmgcZ01AlsJc2EMMWVuaj9Af9JU4/4yalqDiXIh1fOYYUZnLfOfWS+MsnI+/oLfqJFyMbrsQQTIjs+kTzbiEdhd2R4EZgusU/xRFWokS2NAvahexrRhRQ6tpAI+LezPrkNOR3aHiykBf+P9BkUa0wPp6V2Ayc6q0= From 01e8ce532f41a120bc91ed901dfb7720935471ff Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Mon, 14 Dec 2020 16:09:03 -0500 Subject: [PATCH 039/159] fixing support for multithreading --- .gitignore | 3 ++- README.md | 1 - setup.py | 2 +- src/cs50/sql.py | 50 +++++++++++++++++++++++++------------------------ 4 files changed, 29 insertions(+), 27 deletions(-) diff --git a/.gitignore b/.gitignore index 5a13495..65f1e1f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,8 @@ .* !.gitignore !.travis.yml -dist/ *.db *.egg-info/ *.pyc +dist/ +test.db diff --git a/README.md b/README.md index 0fb6d64..3d6eed8 100644 --- a/README.md +++ b/README.md @@ -42,7 +42,6 @@ s = cs50.get_string(); ``` 1. Run `service postgresql start`. 1. Run `psql -c 'create database test;' -U postgres`. -1. Run `touch test.db`. ### Sample Tests diff --git a/setup.py b/setup.py index c8a5f5b..d7cd3f2 100644 --- a/setup.py +++ b/setup.py @@ -16,5 +16,5 @@ package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="6.0.1" + version="6.0.2" ) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index e405dc9..74b704f 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -43,6 +43,7 @@ def __init__(self, url, **kwargs): import os import re import sqlalchemy + import sqlalchemy.orm import sqlite3 # Require that file already exist for SQLite @@ -72,12 +73,12 @@ def connect(dbapi_connection, connection_record): cursor.execute("PRAGMA foreign_keys=ON") cursor.close() - # Autocommit by default - self._autocommit = True - # Register listener sqlalchemy.event.listen(self._engine, "connect", connect) + # Autocommit by default + self._autocommit = True + # Test database disabled = self._logger.disabled self._logger.disabled = True @@ -96,9 +97,9 @@ def __del__(self): def _disconnect(self): """Close database connection.""" - if hasattr(self, "_connection"): - self._connection.close() - delattr(self, "_connection") + if hasattr(self, "_session"): + self._session.remove() + delattr(self, "_session") @_enable_logging def execute(self, sql, *args, **kwargs): @@ -275,33 +276,34 @@ def execute(self, sql, *args, **kwargs): # Infer whether app is defined assert flask.current_app - # If no connections to any databases yet - if not hasattr(flask.g, "_connections"): - setattr(flask.g, "_connections", {}) - connections = getattr(flask.g, "_connections") + # If no sessions for any databases yet + if not hasattr(flask.g, "_sessions"): + setattr(flask.g, "_sessions", {}) + sessions = getattr(flask.g, "_sessions") - # If not yet connected to this database + # If no session yet for this database # https://flask.palletsprojects.com/en/1.1.x/appcontext/#storing-data - if self not in connections: + # https://stackoverflow.com/a/34010159 + if self not in sessions: # Connect to database - connections[self] = self._engine.connect() + sessions[self] = sqlalchemy.orm.scoping.scoped_session(sqlalchemy.orm.sessionmaker(bind=self._engine)) - # Disconnect from database later + # Remove session later if _teardown_appcontext not in flask.current_app.teardown_appcontext_funcs: flask.current_app.teardown_appcontext(_teardown_appcontext) # Use this connection - connection = connections[self] + session = sessions[self] except (ModuleNotFoundError, AssertionError): # If no connection yet - if not hasattr(self, "_connection"): - self._connection = self._engine.connect() + if not hasattr(self, "_session"): + self._session = sqlalchemy.orm.scoping.scoped_session(sqlalchemy.orm.sessionmaker(bind=self._engine)) # Use this connection - connection = self._connection + session = self._session # Catch SQLAlchemy warnings with warnings.catch_warnings(): @@ -321,10 +323,10 @@ def execute(self, sql, *args, **kwargs): # Execute statement if self._autocommit: - connection.execute(sqlalchemy.text("BEGIN")) - result = connection.execute(sqlalchemy.text(statement)) + session.execute(sqlalchemy.text("BEGIN")) + result = session.execute(sqlalchemy.text(statement)) if self._autocommit: - connection.execute(sqlalchemy.text("COMMIT")) + session.execute(sqlalchemy.text("COMMIT")) # Check for end of transaction if command in ["COMMIT", "ROLLBACK"]: @@ -357,7 +359,7 @@ def execute(self, sql, *args, **kwargs): elif command == "INSERT": if self._engine.url.get_backend_name() in ["postgres", "postgresql"]: try: - result = connection.execute("SELECT LASTVAL()") + result = session.execute("SELECT LASTVAL()") ret = result.first()[0] except sqlalchemy.exc.OperationalError: # If lastval is not yet defined in this session ret = None @@ -538,5 +540,5 @@ def _parse_placeholder(token): def _teardown_appcontext(exception=None): """Closes context's database connection, if any.""" import flask - for connection in flask.g.pop("_connections", {}).values(): - connection.close() + for session in flask.g.pop("_sessions", {}).values(): + session.remove() From cb03ec103ba33cde2c82b4b62e0af3f04017bd36 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Mon, 14 Dec 2020 17:20:02 -0500 Subject: [PATCH 040/159] fixed comment --- src/cs50/sql.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 74b704f..1ced4b3 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -293,7 +293,7 @@ def execute(self, sql, *args, **kwargs): if _teardown_appcontext not in flask.current_app.teardown_appcontext_funcs: flask.current_app.teardown_appcontext(_teardown_appcontext) - # Use this connection + # Use this session session = sessions[self] except (ModuleNotFoundError, AssertionError): @@ -302,7 +302,7 @@ def execute(self, sql, *args, **kwargs): if not hasattr(self, "_session"): self._session = sqlalchemy.orm.scoping.scoped_session(sqlalchemy.orm.sessionmaker(bind=self._engine)) - # Use this connection + # Use this session session = self._session # Catch SQLAlchemy warnings From 0cdfd92ccf18139e52206a73593d8ce63658c0b9 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Tue, 9 Mar 2021 14:06:05 -0500 Subject: [PATCH 041/159] replaces sys.stdin.readline with input --- src/cs50/cs50.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/src/cs50/cs50.py b/src/cs50/cs50.py index 0fc481b..1d7b6ea 100644 --- a/src/cs50/cs50.py +++ b/src/cs50/cs50.py @@ -133,17 +133,11 @@ def get_string(prompt): Read a line of text from standard input and return it as a string, sans trailing line ending. Supports CR (\r), LF (\n), and CRLF (\r\n) as line endings. If user inputs only a line ending, returns "", not None. - Returns None upon error or no input whatsoever (i.e., just EOF). Exits - from Python altogether on SIGINT. + Returns None upon error or no input whatsoever (i.e., just EOF). """ + if type(prompt) is not str: + raise TypeError("prompt must be of type str") try: - if prompt is not None: - print(prompt, end="") - s = sys.stdin.readline() - if not s: - return None - return re.sub(r"(?:\r|\r\n|\n)$", "", s) - except KeyboardInterrupt: - sys.exit("") - except ValueError: + return input(prompt) + except EOFError: return None From aee3be01e1fce6ac73963dd7202baa42da89a650 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Tue, 9 Mar 2021 20:44:09 -0500 Subject: [PATCH 042/159] up version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index d7cd3f2..95bd013 100644 --- a/setup.py +++ b/setup.py @@ -16,5 +16,5 @@ package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="6.0.2" + version="6.0.3" ) From 79c6419c6386d87989f29b47282bf8494804e037 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Wed, 17 Mar 2021 14:34:20 -0400 Subject: [PATCH 043/159] avoid uppercasing identifiers --- setup.py | 2 +- src/cs50/sql.py | 7 ++++--- tests/sql.py | 3 +++ 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index 95bd013..550e65d 100644 --- a/setup.py +++ b/setup.py @@ -16,5 +16,5 @@ package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="6.0.3" + version="6.0.4" ) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 1ced4b3..f95b347 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -114,7 +114,7 @@ def execute(self, sql, *args, **kwargs): import warnings # Parse statement, stripping comments and then leading/trailing whitespace - statements = sqlparse.parse(sqlparse.format(sql, keyword_case="upper", strip_comments=True).strip()) + statements = sqlparse.parse(sqlparse.format(sql, strip_comments=True).strip()) # Allow only one statement at a time, since SQLite doesn't support multiple # https://docs.python.org/3/library/sqlite3.html#sqlite3.Cursor.execute @@ -130,8 +130,9 @@ def execute(self, sql, *args, **kwargs): # Infer command from (unflattened) statement for token in statements[0]: if token.ttype in [sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML]: - if token.value in ["BEGIN", "DELETE", "INSERT", "SELECT", "START", "UPDATE"]: - command = token.value + token_value = token.value.upper() + if token_value in ["BEGIN", "DELETE", "INSERT", "SELECT", "START", "UPDATE"]: + command = token_value break else: command = None diff --git a/tests/sql.py b/tests/sql.py index 8cafdde..e4757c7 100644 --- a/tests/sql.py +++ b/tests/sql.py @@ -129,6 +129,9 @@ def test_rollback(self): self.db.execute("ROLLBACK") self.assertEqual(self.db.execute("SELECT val FROM cs50"), []) + def test_identifier_case(self): + self.assertIn("count", self.db.execute("SELECT 1 AS count")[0]) + def tearDown(self): self.db.execute("DROP TABLE cs50") self.db.execute("DROP TABLE IF EXISTS foo") From 1b671e2b6b205385a0a10dc4d0b0240653e8c44e Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Fri, 9 Apr 2021 19:37:55 -0400 Subject: [PATCH 044/159] refactor, fix scoped session --- setup.py | 2 +- src/cs50/__init__.py | 20 +- src/cs50/_logger.py | 48 ++++ src/cs50/_session.py | 80 ++++++ src/cs50/_statement.py | 269 +++++++++++++++++++ src/cs50/cs50.py | 170 +++++------- src/cs50/sql.py | 582 +++++------------------------------------ tests/test_cs50.py | 151 +++++++++++ 8 files changed, 684 insertions(+), 638 deletions(-) create mode 100644 src/cs50/_logger.py create mode 100644 src/cs50/_session.py create mode 100644 src/cs50/_statement.py create mode 100644 tests/test_cs50.py diff --git a/setup.py b/setup.py index 550e65d..a5b8fb7 100644 --- a/setup.py +++ b/setup.py @@ -16,5 +16,5 @@ package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="6.0.4" + version="7.0.0" ) diff --git a/src/cs50/__init__.py b/src/cs50/__init__.py index aaec161..f04da00 100644 --- a/src/cs50/__init__.py +++ b/src/cs50/__init__.py @@ -1,20 +1,6 @@ -import logging -import os -import sys +from ._logger import _setup_logger +_setup_logger() - -# Disable cs50 logger by default -logging.getLogger("cs50").disabled = True - -# Import cs50_* -from .cs50 import get_char, get_float, get_int, get_string -try: - from .cs50 import get_long -except ImportError: - pass - -# Hook into flask importing +from .cs50 import get_float, get_int, get_string from . import flask - -# Wrap SQLAlchemy from .sql import SQL diff --git a/src/cs50/_logger.py b/src/cs50/_logger.py new file mode 100644 index 0000000..46f0821 --- /dev/null +++ b/src/cs50/_logger.py @@ -0,0 +1,48 @@ +import logging +import os.path +import re +import sys +import traceback + +import termcolor + + +def _setup_logger(): + _logger = logging.getLogger("cs50") + _logger.disabled = True + _logger.setLevel(logging.DEBUG) + + # Log messages once + _logger.propagate = False + + handler = logging.StreamHandler() + handler.setLevel(logging.DEBUG) + + formatter = logging.Formatter("%(levelname)s: %(message)s") + formatter.formatException = lambda exc_info: _formatException(*exc_info) + handler.setFormatter(formatter) + _logger.addHandler(handler) + + +def _formatException(type, value, tb): + """ + Format traceback, darkening entries from global site-packages directories + and user-specific site-packages directory. + https://stackoverflow.com/a/46071447/5156190 + """ + + # Absolute paths to site-packages + packages = tuple(os.path.join(os.path.abspath(p), "") for p in sys.path[1:]) + + # Highlight lines not referring to files in site-packages + lines = [] + for line in traceback.format_exception(type, value, tb): + matches = re.search(r"^ File \"([^\"]+)\", line \d+, in .+", line) + if matches and matches.group(1).startswith(packages): + lines += line + else: + matches = re.search(r"^(\s*)(.*?)(\s*)$", line, re.DOTALL) + lines.append(matches.group(1) + termcolor.colored(matches.group(2), "yellow") + matches.group(3)) + return "".join(lines).rstrip() + + diff --git a/src/cs50/_session.py b/src/cs50/_session.py new file mode 100644 index 0000000..4d1a2a9 --- /dev/null +++ b/src/cs50/_session.py @@ -0,0 +1,80 @@ +import os + +import sqlalchemy +import sqlalchemy.orm +import sqlite3 + +class Session: + def __init__(self, url, **engine_kwargs): + self._url = url + if _is_sqlite_url(self._url): + _assert_sqlite_file_exists(self._url) + + self._engine = _create_engine(self._url, **engine_kwargs) + self._is_postgres = self._engine.url.get_backend_name() in {"postgres", "postgresql"} + _setup_on_connect(self._engine) + self._session = _create_scoped_session(self._engine) + + + def is_postgres(self): + return self._is_postgres + + + def execute(self, statement): + return self._session.execute(sqlalchemy.text(str(statement))) + + + def __getattr__(self, attr): + return getattr(self._session, attr) + + +def _is_sqlite_url(url): + return url.startswith("sqlite:///") + + +def _assert_sqlite_file_exists(url): + path = url[len("sqlite:///"):] + if not os.path.exists(path): + raise RuntimeError(f"does not exist: {path}") + if not os.path.isfile(path): + raise RuntimeError(f"not a file: {path}") + + +def _create_engine(url, **kwargs): + try: + engine = sqlalchemy.create_engine(url, **kwargs) + except sqlalchemy.exc.ArgumentError: + raise RuntimeError(f"invalid URL: {url}") from None + + engine.execution_options(autocommit=False) + return engine + + +def _setup_on_connect(engine): + def connect(dbapi_connection, _): + _disable_auto_begin_commit(dbapi_connection) + if _is_sqlite_connection(dbapi_connection): + _enable_sqlite_foreign_key_constraints(dbapi_connection) + + sqlalchemy.event.listen(engine, "connect", connect) + + +def _create_scoped_session(engine): + session_factory = sqlalchemy.orm.sessionmaker(bind=engine) + return sqlalchemy.orm.scoping.scoped_session(session_factory) + + +def _disable_auto_begin_commit(dbapi_connection): + # Disable underlying API's own emitting of BEGIN and COMMIT so we can ourselves + # https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl + dbapi_connection.isolation_level = None + + +def _is_sqlite_connection(dbapi_connection): + return isinstance(dbapi_connection, sqlite3.Connection) + + +def _enable_sqlite_foreign_key_constraints(dbapi_connection): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py new file mode 100644 index 0000000..7519b1e --- /dev/null +++ b/src/cs50/_statement.py @@ -0,0 +1,269 @@ +import collections +import datetime +import enum +import re + +import sqlalchemy +import sqlparse + + +class Statement: + def __init__(self, dialect, sql, *args, **kwargs): + if len(args) > 0 and len(kwargs) > 0: + raise RuntimeError("cannot pass both positional and named parameters") + + self._dialect = dialect + self._sql = sql + self._args = args + self._kwargs = kwargs + + self._statement = self._parse() + self._command = self._parse_command() + self._tokens = self._bind_params() + + def _parse(self): + formatted_statements = sqlparse.format(self._sql, strip_comments=True).strip() + parsed_statements = sqlparse.parse(formatted_statements) + num_of_statements = len(parsed_statements) + if num_of_statements == 0: + raise RuntimeError("missing statement") + elif num_of_statements > 1: + raise RuntimeError("too many statements at once") + + return parsed_statements[0] + + + def _parse_command(self): + for token in self._statement: + if token.ttype in [sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML]: + token_value = token.value.upper() + if token_value in ["BEGIN", "DELETE", "INSERT", "SELECT", "START", "UPDATE"]: + command = token_value + break + else: + command = None + + return command + + + def _bind_params(self): + tokens = self._tokenize() + paramstyle, placeholders = self._parse_placeholders(tokens) + if paramstyle in [Paramstyle.FORMAT, Paramstyle.QMARK]: + tokens = self._bind_format_or_qmark(placeholders, tokens) + elif paramstyle == Paramstyle.NUMERIC: + tokens = self._bind_numeric(placeholders, tokens) + if paramstyle in [Paramstyle.NAMED, Paramstyle.PYFORMAT]: + tokens = self._bind_named_or_pyformat(placeholders, tokens) + + tokens = _escape_verbatim_colons(tokens) + return tokens + + + def _tokenize(self): + return list(self._statement.flatten()) + + + def _parse_placeholders(self, tokens): + paramstyle = None + placeholders = collections.OrderedDict() + for index, token in enumerate(tokens): + if _is_placeholder(token): + _paramstyle, name = _parse_placeholder(token) + if paramstyle is None: + paramstyle = _paramstyle + elif _paramstyle != paramstyle: + raise RuntimeError("inconsistent paramstyle") + + placeholders[index] = name + + if paramstyle is None: + paramstyle = self._default_paramstyle() + + return paramstyle, placeholders + + + def _default_paramstyle(self): + paramstyle = None + if self._args: + paramstyle = Paramstyle.QMARK + elif self._kwargs: + paramstyle = Paramstyle.NAMED + + return paramstyle + + + def _bind_format_or_qmark(self, placeholders, tokens): + if len(placeholders) != len(self._args): + _placeholders = ", ".join([str(token) for token in placeholders.values()]) + _args = ", ".join([str(self._escape(arg)) for arg in self._args]) + if len(placeholders) < len(self._args): + raise RuntimeError(f"fewer placeholders ({_placeholders}) than values ({_args})") + + raise RuntimeError(f"more placeholders ({_placeholders}) than values ({_args})") + + for arg_index, token_index in enumerate(placeholders.keys()): + tokens[token_index] = self._escape(self._args[arg_index]) + + return tokens + + + def _bind_numeric(self, placeholders, tokens): + unused_arg_indices = set(range(len(self._args))) + for token_index, num in placeholders.items(): + if num >= len(self._args): + raise RuntimeError(f"missing value for placeholder ({num + 1})") + + tokens[token_index] = self._escape(self._args[num]) + unused_arg_indices.remove(num) + + if len(unused_arg_indices) > 0: + unused_args = ", ".join([str(self._escape(self._args[i])) for i in sorted(unused_arg_indices)]) + raise RuntimeError(f"unused value{'' if len(unused_arg_indices) == 1 else 's'} ({unused_args})") + + return tokens + + + def _bind_named_or_pyformat(self, placeholders, tokens): + unused_params = set(self._kwargs.keys()) + for token_index, param_name in placeholders.items(): + if param_name not in self._kwargs: + raise RuntimeError(f"missing value for placeholder ({param_name})") + + tokens[token_index] = self._escape(self._kwargs[param_name]) + unused_params.remove(param_name) + + if len(unused_params) > 0: + raise RuntimeError("unused value{'' if len(unused_params) == 1 else 's'} ({', '.join(sorted(unused_params))})") + + return tokens + + + def _escape(self, value): + """ + Escapes value using engine's conversion function. + https://docs.sqlalchemy.org/en/latest/core/type_api.html#sqlalchemy.types.TypeEngine.literal_processor + """ + + if isinstance(value, (list, tuple)): + return self._escape_iterable(value) + + if isinstance(value, bool): + return sqlparse.sql.Token( + sqlparse.tokens.Number, + sqlalchemy.types.Boolean().literal_processor(self._dialect)(value)) + + if isinstance(value, bytes): + if self._dialect.name in ["mysql", "sqlite"]: + # https://dev.mysql.com/doc/refman/8.0/en/hexadecimal-literals.html + return sqlparse.sql.Token(sqlparse.tokens.Other, f"x'{value.hex()}'") + if self._dialect.name in ["postgres", "postgresql"]: + # https://dba.stackexchange.com/a/203359 + return sqlparse.sql.Token(sqlparse.tokens.Other, f"'\\x{value.hex()}'") + + raise RuntimeError(f"unsupported value: {value}") + + if isinstance(value, datetime.date): + return sqlparse.sql.Token( + sqlparse.tokens.String, + sqlalchemy.types.String().literal_processor(self._dialect)(value.strftime("%Y-%m-%d"))) + + if isinstance(value, datetime.datetime): + return sqlparse.sql.Token( + sqlparse.tokens.String, + sqlalchemy.types.String().literal_processor(self._dialect)(value.strftime("%Y-%m-%d %H:%M:%S"))) + + if isinstance(value, datetime.time): + return sqlparse.sql.Token( + sqlparse.tokens.String, + sqlalchemy.types.String().literal_processor(self._dialect)(value.strftime("%H:%M:%S"))) + + if isinstance(value, float): + return sqlparse.sql.Token( + sqlparse.tokens.Number, + sqlalchemy.types.Float().literal_processor(self._dialect)(value)) + + if isinstance(value, int): + return sqlparse.sql.Token( + sqlparse.tokens.Number, + sqlalchemy.types.Integer().literal_processor(self._dialect)(value)) + + if isinstance(value, str): + return sqlparse.sql.Token( + sqlparse.tokens.String, + sqlalchemy.types.String().literal_processor(self._dialect)(value)) + + if value is None: + return sqlparse.sql.Token( + sqlparse.tokens.Keyword, + sqlalchemy.types.NullType().literal_processor(self._dialect)(value)) + + raise RuntimeError(f"unsupported value: {value}") + + + def _escape_iterable(self, iterable): + return sqlparse.sql.TokenList( + sqlparse.parse(", ".join([str(self._escape(v)) for v in iterable]))) + + + def get_command(self): + return self._command + + + def __str__(self): + return "".join([str(token) for token in self._tokens]) + + +def _is_placeholder(token): + return token.ttype == sqlparse.tokens.Name.Placeholder + + +def _parse_placeholder(token): + if token.value == "?": + return Paramstyle.QMARK, None + + # E.g., :1 + matches = re.search(r"^:([1-9]\d*)$", token.value) + if matches: + return Paramstyle.NUMERIC, int(matches.group(1)) - 1 + + # E.g., :foo + matches = re.search(r"^:([a-zA-Z]\w*)$", token.value) + if matches: + return Paramstyle.NAMED, matches.group(1) + + if token.value == "%s": + return Paramstyle.FORMAT, None + + # E.g., %(foo) + matches = re.search(r"%\((\w+)\)s$", token.value) + if matches: + return Paramstyle.PYFORMAT, matches.group(1) + + raise RuntimeError(f"{token.value}: invalid placeholder") + + +def _escape_verbatim_colons(tokens): + for token in tokens: + if _is_string_literal(token): + token.value = re.sub("(^'|\s+):", r"\1\:", token.value) + elif _is_identifier(token): + token.value = re.sub("(^\"|\s+):", r"\1\:", token.value) + + return tokens + + +def _is_string_literal(token): + return token.ttype in [sqlparse.tokens.Literal.String, sqlparse.tokens.Literal.String.Single] + + +def _is_identifier(token): + return token.ttype == sqlparse.tokens.Literal.String.Symbol + + +class Paramstyle(enum.Enum): + FORMAT = enum.auto() + NAMED = enum.auto() + NUMERIC = enum.auto() + PYFORMAT = enum.auto() + QMARK = enum.auto() diff --git a/src/cs50/cs50.py b/src/cs50/cs50.py index 1d7b6ea..573d862 100644 --- a/src/cs50/cs50.py +++ b/src/cs50/cs50.py @@ -1,98 +1,6 @@ -from __future__ import print_function - -import inspect -import logging -import os import re import sys -from distutils.sysconfig import get_python_lib -from os.path import abspath, join -from termcolor import colored -from traceback import format_exception - - -# Configure default logging handler and formatter -# Prevent flask, werkzeug, etc from adding default handler -logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.DEBUG) - -try: - # Patch formatException - logging.root.handlers[0].formatter.formatException = lambda exc_info: _formatException(*exc_info) -except IndexError: - pass - -# Configure cs50 logger -_logger = logging.getLogger("cs50") -_logger.setLevel(logging.DEBUG) - -# Log messages once -_logger.propagate = False - -handler = logging.StreamHandler() -handler.setLevel(logging.DEBUG) - -formatter = logging.Formatter("%(levelname)s: %(message)s") -formatter.formatException = lambda exc_info: _formatException(*exc_info) -handler.setFormatter(formatter) -_logger.addHandler(handler) - - -class _flushfile(): - """ - Disable buffering for standard output and standard error. - - http://stackoverflow.com/a/231216 - """ - - def __init__(self, f): - self.f = f - - def __getattr__(self, name): - return object.__getattribute__(self.f, name) - - def write(self, x): - self.f.write(x) - self.f.flush() - - -sys.stderr = _flushfile(sys.stderr) -sys.stdout = _flushfile(sys.stdout) - - -def _formatException(type, value, tb): - """ - Format traceback, darkening entries from global site-packages directories - and user-specific site-packages directory. - - https://stackoverflow.com/a/46071447/5156190 - """ - - # Absolute paths to site-packages - packages = tuple(join(abspath(p), "") for p in sys.path[1:]) - - # Highlight lines not referring to files in site-packages - lines = [] - for line in format_exception(type, value, tb): - matches = re.search(r"^ File \"([^\"]+)\", line \d+, in .+", line) - if matches and matches.group(1).startswith(packages): - lines += line - else: - matches = re.search(r"^(\s*)(.*?)(\s*)$", line, re.DOTALL) - lines.append(matches.group(1) + colored(matches.group(2), "yellow") + matches.group(3)) - return "".join(lines).rstrip() - - -sys.excepthook = lambda type, value, tb: print(_formatException(type, value, tb), file=sys.stderr) - - -def eprint(*args, **kwargs): - raise RuntimeError("The CS50 Library for Python no longer supports eprint, but you can use print instead!") - - -def get_char(prompt): - raise RuntimeError("The CS50 Library for Python no longer supports get_char, but you can use get_string instead!") - def get_float(prompt): """ @@ -101,14 +9,21 @@ def get_float(prompt): prompted to retry. If line can't be read, return None. """ while True: - s = get_string(prompt) - if s is None: - return None - if len(s) > 0 and re.search(r"^[+-]?\d*(?:\.\d*)?$", s): - try: - return float(s) - except (OverflowError, ValueError): - pass + try: + return _get_float(prompt) + except (OverflowError, ValueError): + pass + + +def _get_float(prompt): + s = get_string(prompt) + if s is None: + return + + if len(s) > 0 and re.search(r"^[+-]?\d*(?:\.\d*)?$", s): + return float(s) + + raise ValueError(f"invalid float literal: {s}") def get_int(prompt): @@ -118,14 +33,21 @@ def get_int(prompt): can't be read, return None. """ while True: - s = get_string(prompt) - if s is None: - return None - if re.search(r"^[+-]?\d+$", s): - try: - return int(s, 10) - except ValueError: - pass + try: + return _get_int(prompt) + except (MemoryError, ValueError): + pass + + +def _get_int(prompt): + s = get_string(prompt) + if s is None: + return + + if re.search(r"^[+-]?\d+$", s): + return int(s, 10) + + raise ValueError(f"invalid int literal for base 10: {s}") def get_string(prompt): @@ -137,7 +59,35 @@ def get_string(prompt): """ if type(prompt) is not str: raise TypeError("prompt must be of type str") + try: - return input(prompt) + return _get_input(prompt) except EOFError: - return None + return + + +def _get_input(prompt): + return input(prompt) + + +class _flushfile(): + """ + Disable buffering for standard output and standard error. + http://stackoverflow.com/a/231216 + """ + + def __init__(self, f): + self.f = f + + def __getattr__(self, name): + return object.__getattribute__(self.f, name) + + def write(self, x): + self.f.write(x) + self.f.flush() + +def disable_buffering(): + sys.stderr = _flushfile(sys.stderr) + sys.stdout = _flushfile(sys.stdout) + +disable_buffering() diff --git a/src/cs50/sql.py b/src/cs50/sql.py index f95b347..b778601 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -1,545 +1,107 @@ -def _enable_logging(f): - """Enable logging of SQL statements when Flask is in use.""" +import decimal +import logging +import warnings - import logging - import functools +import sqlalchemy +import termcolor - @functools.wraps(f) - def decorator(*args, **kwargs): +from ._session import Session +from ._statement import Statement - # Infer whether Flask is installed - try: - import flask - except ModuleNotFoundError: - return f(*args, **kwargs) - - # Enable logging - disabled = logging.getLogger("cs50").disabled - if flask.current_app: - logging.getLogger("cs50").disabled = False - try: - return f(*args, **kwargs) - finally: - logging.getLogger("cs50").disabled = disabled - - return decorator - - -class SQL(object): - """Wrap SQLAlchemy to provide a simple SQL API.""" - - def __init__(self, url, **kwargs): - """ - Create instance of sqlalchemy.engine.Engine. +_logger = logging.getLogger("cs50") - URL should be a string that indicates database dialect and connection arguments. - http://docs.sqlalchemy.org/en/latest/core/engines.html#sqlalchemy.create_engine - http://docs.sqlalchemy.org/en/latest/dialects/index.html - """ +class SQL: + def __init__(self, url, **engine_kwargs): + self._session = Session(url, **engine_kwargs) + self._autocommit = False + self._test_database() - # Lazily import - import logging - import os - import re - import sqlalchemy - import sqlalchemy.orm - import sqlite3 - # Require that file already exist for SQLite - matches = re.search(r"^sqlite:///(.+)$", url) - if matches: - if not os.path.exists(matches.group(1)): - raise RuntimeError("does not exist: {}".format(matches.group(1))) - if not os.path.isfile(matches.group(1)): - raise RuntimeError("not a file: {}".format(matches.group(1))) + def _test_database(self): + self.execute("SELECT 1") - # Create engine, disabling SQLAlchemy's own autocommit mode, raising exception if back end's module not installed - self._engine = sqlalchemy.create_engine(url, **kwargs).execution_options(autocommit=False) - # Get logger - self._logger = logging.getLogger("cs50") - - # Listener for connections - def connect(dbapi_connection, connection_record): - - # Disable underlying API's own emitting of BEGIN and COMMIT so we can ourselves - # https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl - dbapi_connection.isolation_level = None - - # Enable foreign key constraints - if type(dbapi_connection) is sqlite3.Connection: # If back end is sqlite - cursor = dbapi_connection.cursor() - cursor.execute("PRAGMA foreign_keys=ON") - cursor.close() - - # Register listener - sqlalchemy.event.listen(self._engine, "connect", connect) - - # Autocommit by default - self._autocommit = True - - # Test database - disabled = self._logger.disabled - self._logger.disabled = True - try: - self.execute("SELECT 1") - except sqlalchemy.exc.OperationalError as e: - e = RuntimeError(_parse_exception(e)) - e.__cause__ = None - raise e - finally: - self._logger.disabled = disabled - - def __del__(self): - """Disconnect from database.""" - self._disconnect() - - def _disconnect(self): - """Close database connection.""" - if hasattr(self, "_session"): - self._session.remove() - delattr(self, "_session") - - @_enable_logging def execute(self, sql, *args, **kwargs): """Execute a SQL statement.""" + statement = Statement(self._session.get_bind().dialect, sql, *args, **kwargs) + command = statement.get_command() + if command in ["BEGIN", "START"]: + self._autocommit = False - # Lazily import - import decimal - import re - import sqlalchemy - import sqlparse - import termcolor - import warnings + if self._autocommit: + self._session.execute("BEGIN") - # Parse statement, stripping comments and then leading/trailing whitespace - statements = sqlparse.parse(sqlparse.format(sql, strip_comments=True).strip()) + result = self._execute(statement) - # Allow only one statement at a time, since SQLite doesn't support multiple - # https://docs.python.org/3/library/sqlite3.html#sqlite3.Cursor.execute - if len(statements) > 1: - raise RuntimeError("too many statements at once") - elif len(statements) == 0: - raise RuntimeError("missing statement") + if self._autocommit: + self._session.execute("COMMIT") + self._session.remove() - # Ensure named and positional parameters are mutually exclusive - if len(args) > 0 and len(kwargs) > 0: - raise RuntimeError("cannot pass both positional and named parameters") + if command in ["COMMIT", "ROLLBACK"]: + self._autocommit = True + self._session.remove() - # Infer command from (unflattened) statement - for token in statements[0]: - if token.ttype in [sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML]: - token_value = token.value.upper() - if token_value in ["BEGIN", "DELETE", "INSERT", "SELECT", "START", "UPDATE"]: - command = token_value - break + if command == "SELECT": + ret = _fetch_select_result(result) + elif command == "INSERT": + if self._session.is_postgres(): + ret = self._get_last_val() + else: + ret = result.lastrowid if result.rowcount == 1 else None + elif command in ["DELETE", "UPDATE"]: + ret = result.rowcount else: - command = None - - # Flatten statement - tokens = list(statements[0].flatten()) - - # Validate paramstyle - placeholders = {} - paramstyle = None - for index, token in enumerate(tokens): - - # If token is a placeholder - if token.ttype == sqlparse.tokens.Name.Placeholder: - - # Determine paramstyle, name - _paramstyle, name = _parse_placeholder(token) - - # Remember paramstyle - if not paramstyle: - paramstyle = _paramstyle - - # Ensure paramstyle is consistent - elif _paramstyle != paramstyle: - raise RuntimeError("inconsistent paramstyle") - - # Remember placeholder's index, name - placeholders[index] = name - - # If no placeholders - if not paramstyle: - - # Error-check like qmark if args - if args: - paramstyle = "qmark" - - # Error-check like named if kwargs - elif kwargs: - paramstyle = "named" - - # In case of errors - _placeholders = ", ".join([str(tokens[index]) for index in placeholders]) - _args = ", ".join([str(self._escape(arg)) for arg in args]) - - # qmark - if paramstyle == "qmark": - - # Validate number of placeholders - if len(placeholders) != len(args): - if len(placeholders) < len(args): - raise RuntimeError("fewer placeholders ({}) than values ({})".format(_placeholders, _args)) - else: - raise RuntimeError("more placeholders ({}) than values ({})".format(_placeholders, _args)) - - # Escape values - for i, index in enumerate(placeholders.keys()): - tokens[index] = self._escape(args[i]) - - # numeric - elif paramstyle == "numeric": - - # Escape values - for index, i in placeholders.items(): - if i >= len(args): - raise RuntimeError("missing value for placeholder (:{})".format(i + 1, len(args))) - tokens[index] = self._escape(args[i]) - - # Check if any values unused - indices = set(range(len(args))) - set(placeholders.values()) - if indices: - raise RuntimeError("unused {} ({})".format( - "value" if len(indices) == 1 else "values", - ", ".join([str(self._escape(args[index])) for index in indices]))) - - # named - elif paramstyle == "named": - - # Escape values - for index, name in placeholders.items(): - if name not in kwargs: - raise RuntimeError("missing value for placeholder (:{})".format(name)) - tokens[index] = self._escape(kwargs[name]) - - # Check if any keys unused - keys = kwargs.keys() - placeholders.values() - if keys: - raise RuntimeError("unused values ({})".format(", ".join(keys))) - - # format - elif paramstyle == "format": - - # Validate number of placeholders - if len(placeholders) != len(args): - if len(placeholders) < len(args): - raise RuntimeError("fewer placeholders ({}) than values ({})".format(_placeholders, _args)) - else: - raise RuntimeError("more placeholders ({}) than values ({})".format(_placeholders, _args)) - - # Escape values - for i, index in enumerate(placeholders.keys()): - tokens[index] = self._escape(args[i]) - - # pyformat - elif paramstyle == "pyformat": - - # Escape values - for index, name in placeholders.items(): - if name not in kwargs: - raise RuntimeError("missing value for placeholder (%{}s)".format(name)) - tokens[index] = self._escape(kwargs[name]) - - # Check if any keys unused - keys = kwargs.keys() - placeholders.values() - if keys: - raise RuntimeError("unused {} ({})".format( - "value" if len(keys) == 1 else "values", - ", ".join(keys))) - - # For SQL statements where a colon is required verbatim, as within an inline string, use a backslash to escape - # https://docs.sqlalchemy.org/en/13/core/sqlelement.html?highlight=text#sqlalchemy.sql.expression.text - for index, token in enumerate(tokens): - - # In string literal - # https://www.sqlite.org/lang_keywords.html - if token.ttype in [sqlparse.tokens.Literal.String, sqlparse.tokens.Literal.String.Single]: - token.value = re.sub("(^'|\s+):", r"\1\:", token.value) - - # In identifier - # https://www.sqlite.org/lang_keywords.html - elif token.ttype == sqlparse.tokens.Literal.String.Symbol: - token.value = re.sub("(^\"|\s+):", r"\1\:", token.value) - - # Join tokens into statement - statement = "".join([str(token) for token in tokens]) - - # Connect to database - try: - - # Infer whether Flask is installed - import flask + ret = True - # Infer whether app is defined - assert flask.current_app + return ret - # If no sessions for any databases yet - if not hasattr(flask.g, "_sessions"): - setattr(flask.g, "_sessions", {}) - sessions = getattr(flask.g, "_sessions") - - # If no session yet for this database - # https://flask.palletsprojects.com/en/1.1.x/appcontext/#storing-data - # https://stackoverflow.com/a/34010159 - if self not in sessions: - - # Connect to database - sessions[self] = sqlalchemy.orm.scoping.scoped_session(sqlalchemy.orm.sessionmaker(bind=self._engine)) - - # Remove session later - if _teardown_appcontext not in flask.current_app.teardown_appcontext_funcs: - flask.current_app.teardown_appcontext(_teardown_appcontext) - - # Use this session - session = sessions[self] - - except (ModuleNotFoundError, AssertionError): - - # If no connection yet - if not hasattr(self, "_session"): - self._session = sqlalchemy.orm.scoping.scoped_session(sqlalchemy.orm.sessionmaker(bind=self._engine)) - - # Use this session - session = self._session + def _execute(self, statement): # Catch SQLAlchemy warnings with warnings.catch_warnings(): - # Raise exceptions for warnings warnings.simplefilter("error") - - # Prepare, execute statement try: + return self._session.execute(statement) + except sqlalchemy.exc.IntegrityError as exc: + _logger.debug(termcolor.colored(str(statement), "yellow")) + raise ValueError(exc.orig) from None + except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as exc: + self._session.remove() + _logger.debug(termcolor.colored(statement, "red")) + raise RuntimeError(exc.orig) from None - # Join tokens into statement, abbreviating binary data as <class 'bytes'> - _statement = "".join([str(bytes) if token.ttype == sqlparse.tokens.Other else str(token) for token in tokens]) - - # Check for start of transaction - if command in ["BEGIN", "START"]: - self._autocommit = False - - # Execute statement - if self._autocommit: - session.execute(sqlalchemy.text("BEGIN")) - result = session.execute(sqlalchemy.text(statement)) - if self._autocommit: - session.execute(sqlalchemy.text("COMMIT")) - - # Check for end of transaction - if command in ["COMMIT", "ROLLBACK"]: - self._autocommit = True - - # Return value - ret = True - - # If SELECT, return result set as list of dict objects - if command == "SELECT": - - # Coerce types - rows = [dict(row) for row in result.fetchall()] - for row in rows: - for column in row: - - # Coerce decimal.Decimal objects to float objects - # https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ - if type(row[column]) is decimal.Decimal: - row[column] = float(row[column]) - - # Coerce memoryview objects (as from PostgreSQL's bytea columns) to bytes - elif type(row[column]) is memoryview: - row[column] = bytes(row[column]) - - # Rows to be returned - ret = rows - - # If INSERT, return primary key value for a newly inserted row (or None if none) - elif command == "INSERT": - if self._engine.url.get_backend_name() in ["postgres", "postgresql"]: - try: - result = session.execute("SELECT LASTVAL()") - ret = result.first()[0] - except sqlalchemy.exc.OperationalError: # If lastval is not yet defined in this session - ret = None - else: - ret = result.lastrowid if result.rowcount == 1 else None - - # If DELETE or UPDATE, return number of rows matched - elif command in ["DELETE", "UPDATE"]: - ret = result.rowcount - - # If constraint violated, return None - except sqlalchemy.exc.IntegrityError as e: - self._logger.debug(termcolor.colored(statement, "yellow")) - e = ValueError(e.orig) - e.__cause__ = None - raise e - - # If user error - except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as e: - self._disconnect() - self._logger.debug(termcolor.colored(statement, "red")) - e = RuntimeError(e.orig) - e.__cause__ = None - raise e - - # Return value - else: - self._logger.debug(termcolor.colored(_statement, "green")) - return ret - - def _escape(self, value): - """ - Escapes value using engine's conversion function. - - https://docs.sqlalchemy.org/en/latest/core/type_api.html#sqlalchemy.types.TypeEngine.literal_processor - """ - - # Lazily import - import sqlparse + _logger.debug(termcolor.colored(str(statement), "green")) - def __escape(value): - # Lazily import - import datetime - import sqlalchemy - - # bool - if type(value) is bool: - return sqlparse.sql.Token( - sqlparse.tokens.Number, - sqlalchemy.types.Boolean().literal_processor(self._engine.dialect)(value)) - - # bytes - elif type(value) is bytes: - if self._engine.url.get_backend_name() in ["mysql", "sqlite"]: - return sqlparse.sql.Token(sqlparse.tokens.Other, f"x'{value.hex()}'") # https://dev.mysql.com/doc/refman/8.0/en/hexadecimal-literals.html - elif self._engine.url.get_backend_name() == "postgresql": - return sqlparse.sql.Token(sqlparse.tokens.Other, f"'\\x{value.hex()}'") # https://dba.stackexchange.com/a/203359 - else: - raise RuntimeError("unsupported value: {}".format(value)) - - # datetime.date - elif type(value) is datetime.date: - return sqlparse.sql.Token( - sqlparse.tokens.String, - sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%Y-%m-%d"))) - - # datetime.datetime - elif type(value) is datetime.datetime: - return sqlparse.sql.Token( - sqlparse.tokens.String, - sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%Y-%m-%d %H:%M:%S"))) - - # datetime.time - elif type(value) is datetime.time: - return sqlparse.sql.Token( - sqlparse.tokens.String, - sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%H:%M:%S"))) - - # float - elif type(value) is float: - return sqlparse.sql.Token( - sqlparse.tokens.Number, - sqlalchemy.types.Float().literal_processor(self._engine.dialect)(value)) - - # int - elif type(value) is int: - return sqlparse.sql.Token( - sqlparse.tokens.Number, - sqlalchemy.types.Integer().literal_processor(self._engine.dialect)(value)) - - # str - elif type(value) is str: - return sqlparse.sql.Token( - sqlparse.tokens.String, - sqlalchemy.types.String().literal_processor(self._engine.dialect)(value)) - - # None - elif value is None: - return sqlparse.sql.Token( - sqlparse.tokens.Keyword, - sqlalchemy.types.NullType().literal_processor(self._engine.dialect)(value)) - - # Unsupported value - else: - raise RuntimeError("unsupported value: {}".format(value)) - - # Escape value(s), separating with commas as needed - if type(value) in [list, tuple]: - return sqlparse.sql.TokenList(sqlparse.parse(", ".join([str(__escape(v)) for v in value]))) - else: - return __escape(value) - - -def _parse_exception(e): - """Parses an exception, returns its message.""" - - # Lazily import - import re - - # MySQL - matches = re.search(r"^\(_mysql_exceptions\.OperationalError\) \(\d+, \"(.+)\"\)$", str(e)) - if matches: - return matches.group(1) - - # PostgreSQL - matches = re.search(r"^\(psycopg2\.OperationalError\) (.+)$", str(e)) - if matches: - return matches.group(1) - - # SQLite - matches = re.search(r"^\(sqlite3\.OperationalError\) (.+)$", str(e)) - if matches: - return matches.group(1) - - # Default - return str(e) - - -def _parse_placeholder(token): - """Infers paramstyle, name from sqlparse.tokens.Name.Placeholder.""" - - # Lazily load - import re - import sqlparse - - # Validate token - if not isinstance(token, sqlparse.sql.Token) or token.ttype != sqlparse.tokens.Name.Placeholder: - raise TypeError() - - # qmark - if token.value == "?": - return "qmark", None + def _get_last_val(self): + try: + return self._session.execute("SELECT LASTVAL()").first()[0] + except sqlalchemy.exc.OperationalError: # If lastval is not yet defined in this session + return None - # numeric - matches = re.search(r"^:([1-9]\d*)$", token.value) - if matches: - return "numeric", int(matches.group(1)) - 1 - # named - matches = re.search(r"^:([a-zA-Z]\w*)$", token.value) - if matches: - return "named", matches.group(1) + def init_app(self, app): + @app.teardown_appcontext + def shutdown_session(res_or_exc): + self._session.remove() + return res_or_exc - # format - if token.value == "%s": - return "format", None + logging.getLogger("cs50").disabled = False - # pyformat - matches = re.search(r"%\((\w+)\)s$", token.value) - if matches: - return "pyformat", matches.group(1) - # Invalid - raise RuntimeError("{}: invalid placeholder".format(token.value)) +def _fetch_select_result(result): + rows = [dict(row) for row in result.fetchall()] + for row in rows: + for column in row: + # Coerce decimal.Decimal objects to float objects + # https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ + if isinstance(row[column], decimal.Decimal): + row[column] = float(row[column]) + # Coerce memoryview objects (as from PostgreSQL's bytea columns) to bytes + elif isinstance(row[column], memoryview): + row[column] = bytes(row[column]) -def _teardown_appcontext(exception=None): - """Closes context's database connection, if any.""" - import flask - for session in flask.g.pop("_sessions", {}).values(): - session.remove() + return rows diff --git a/tests/test_cs50.py b/tests/test_cs50.py new file mode 100644 index 0000000..a58424d --- /dev/null +++ b/tests/test_cs50.py @@ -0,0 +1,151 @@ +import math +import sys +import unittest + +from unittest.mock import patch + +from cs50.cs50 import get_string, _get_int, _get_float + + +class TestCS50(unittest.TestCase): + @patch("cs50.cs50._get_input", return_value="") + def test_get_string_empty_input(self, mock_get_input): + """Returns empty string when input is empty""" + self.assertEqual(get_string("Answer: "), "") + mock_get_input.assert_called_with("Answer: ") + + + @patch("cs50.cs50._get_input", return_value="test") + def test_get_string_nonempty_input(self, mock_get_input): + """Returns the provided non-empty input""" + self.assertEqual(get_string("Answer: "), "test") + mock_get_input.assert_called_with("Answer: ") + + + @patch("cs50.cs50._get_input", side_effect=EOFError) + def test_get_string_eof(self, mock_get_input): + """Returns None on EOF""" + self.assertIs(get_string("Answer: "), None) + mock_get_input.assert_called_with("Answer: ") + + + def test_get_string_invalid_prompt(self): + """Raises TypeError when prompt is not str""" + with self.assertRaises(TypeError): + get_string(1) + + + @patch("cs50.cs50.get_string", return_value=None) + def test_get_int_eof(self, mock_get_string): + """Returns None on EOF""" + self.assertIs(_get_int("Answer: "), None) + mock_get_string.assert_called_with("Answer: ") + + + def test_get_int_valid_input(self): + """Returns the provided integer input""" + + def assert_equal(return_value, expected_value): + with patch("cs50.cs50.get_string", return_value=return_value) as mock_get_string: + self.assertEqual(_get_int("Answer: "), expected_value) + mock_get_string.assert_called_with("Answer: ") + + values = [ + ("0", 0), + ("50", 50), + ("+50", 50), + ("+42", 42), + ("-42", -42), + ("42", 42), + ] + + for return_value, expected_value in values: + assert_equal(return_value, expected_value) + + + def test_get_int_invalid_input(self): + """Raises ValueError when input is invalid base-10 int""" + + def assert_raises_valueerror(return_value): + with patch("cs50.cs50.get_string", return_value=return_value) as mock_get_string: + with self.assertRaises(ValueError): + _get_int("Answer: ") + + mock_get_string.assert_called_with("Answer: ") + + return_values = [ + "++50", + "--50", + "50+", + "50-", + " 50", + " +50", + " -50", + "50 ", + "ab50", + "50ab", + "ab50ab", + ] + + for return_value in return_values: + assert_raises_valueerror(return_value) + + + @patch("cs50.cs50.get_string", return_value=None) + def test_get_float_eof(self, mock_get_string): + """Returns None on EOF""" + self.assertIs(_get_float("Answer: "), None) + mock_get_string.assert_called_with("Answer: ") + + + def test_get_float_valid_input(self): + """Returns the provided integer input""" + def assert_equal(return_value, expected_value): + with patch("cs50.cs50.get_string", return_value=return_value) as mock_get_string: + f = _get_float("Answer: ") + self.assertTrue(math.isclose(f, expected_value)) + mock_get_string.assert_called_with("Answer: ") + + values = [ + (".0", 0.0), + ("0.", 0.0), + (".42", 0.42), + ("42.", 42.0), + ("50", 50.0), + ("+50", 50.0), + ("-50", -50.0), + ("+3.14", 3.14), + ("-3.14", -3.14), + ] + + for return_value, expected_value in values: + assert_equal(return_value, expected_value) + + + def test_get_float_invalid_input(self): + """Raises ValueError when input is invalid float""" + + def assert_raises_valueerror(return_value): + with patch("cs50.cs50.get_string", return_value=return_value) as mock_get_string: + with self.assertRaises(ValueError): + _get_float("Answer: ") + + mock_get_string.assert_called_with("Answer: ") + + return_values = [ + ".", + "..5", + "a.5", + ".5a" + "0.5a", + "a0.42", + " .42", + "3.14 ", + "++3.14", + "3.14+", + "--3.14", + "3.14--", + ] + + for return_value in return_values: + assert_raises_valueerror(return_value) From d23ed8a9bdd2bbf529021904aa6c98b640781033 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Fri, 9 Apr 2021 19:44:21 -0400 Subject: [PATCH 045/159] remove unused import --- src/cs50/flask.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/cs50/flask.py b/src/cs50/flask.py index 324ec30..a0e077a 100644 --- a/src/cs50/flask.py +++ b/src/cs50/flask.py @@ -2,18 +2,17 @@ import pkgutil import sys +from distutils.version import StrictVersion +from werkzeug.middleware.proxy_fix import ProxyFix + def _wrap_flask(f): if f is None: return - from distutils.version import StrictVersion - from .cs50 import _formatException - if f.__version__ < StrictVersion("1.0"): return if os.getenv("CS50_IDE_TYPE") == "online": - from werkzeug.middleware.proxy_fix import ProxyFix _flask_init_before = f.Flask.__init__ def _flask_init_after(self, *args, **kwargs): _flask_init_before(self, *args, **kwargs) From b5f030083e30b31acd6056aee2097462c430a731 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Fri, 9 Apr 2021 19:49:44 -0400 Subject: [PATCH 046/159] fix logger --- src/cs50/sql.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index b778601..d5c8d49 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -63,7 +63,7 @@ def _execute(self, statement): # Raise exceptions for warnings warnings.simplefilter("error") try: - return self._session.execute(statement) + result = self._session.execute(statement) except sqlalchemy.exc.IntegrityError as exc: _logger.debug(termcolor.colored(str(statement), "yellow")) raise ValueError(exc.orig) from None @@ -73,6 +73,7 @@ def _execute(self, statement): raise RuntimeError(exc.orig) from None _logger.debug(termcolor.colored(str(statement), "green")) + return result def _get_last_val(self): From 022a3da3151c7c82315e12bcb2d87fabe61e4600 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Fri, 9 Apr 2021 20:01:11 -0400 Subject: [PATCH 047/159] remove test_database, rename param --- src/cs50/sql.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index d5c8d49..a1f7dbd 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -15,11 +15,6 @@ class SQL: def __init__(self, url, **engine_kwargs): self._session = Session(url, **engine_kwargs) self._autocommit = False - self._test_database() - - - def _test_database(self): - self.execute("SELECT 1") def execute(self, sql, *args, **kwargs): @@ -85,9 +80,8 @@ def _get_last_val(self): def init_app(self, app): @app.teardown_appcontext - def shutdown_session(res_or_exc): + def shutdown_session(_): self._session.remove() - return res_or_exc logging.getLogger("cs50").disabled = False From a3b32c45bb64308eed38b53680e5997e00cbeac8 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Fri, 9 Apr 2021 20:12:13 -0400 Subject: [PATCH 048/159] fix exception formatting --- src/cs50/_logger.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/cs50/_logger.py b/src/cs50/_logger.py index 46f0821..c489111 100644 --- a/src/cs50/_logger.py +++ b/src/cs50/_logger.py @@ -8,6 +8,16 @@ def _setup_logger(): + # Configure default logging handler and formatter + # Prevent flask, werkzeug, etc from adding default handler + logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.DEBUG) + + try: + # Patch formatException + logging.root.handlers[0].formatter.formatException = lambda exc_info: _formatException(*exc_info) + except IndexError: + pass + _logger = logging.getLogger("cs50") _logger.disabled = True _logger.setLevel(logging.DEBUG) @@ -23,6 +33,8 @@ def _setup_logger(): handler.setFormatter(formatter) _logger.addHandler(handler) + sys.excepthook = lambda type, value, tb: print(_formatException(type, value, tb), file=sys.stderr) + def _formatException(type, value, tb): """ @@ -44,5 +56,3 @@ def _formatException(type, value, tb): matches = re.search(r"^(\s*)(.*?)(\s*)$", line, re.DOTALL) lines.append(matches.group(1) + termcolor.colored(matches.group(2), "yellow") + matches.group(3)) return "".join(lines).rstrip() - - From 663e6bdf919853d8518d140e3b9c5da51edca0ed Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Sat, 10 Apr 2021 00:06:25 -0400 Subject: [PATCH 049/159] simplify _session, execute --- src/cs50/_session.py | 16 +++++----------- src/cs50/_statement.py | 24 ++++++++++++++---------- src/cs50/sql.py | 25 +++++++++++++++++-------- 3 files changed, 36 insertions(+), 29 deletions(-) diff --git a/src/cs50/_session.py b/src/cs50/_session.py index 4d1a2a9..441371a 100644 --- a/src/cs50/_session.py +++ b/src/cs50/_session.py @@ -6,18 +6,12 @@ class Session: def __init__(self, url, **engine_kwargs): - self._url = url - if _is_sqlite_url(self._url): - _assert_sqlite_file_exists(self._url) + if _is_sqlite_url(url): + _assert_sqlite_file_exists(url) - self._engine = _create_engine(self._url, **engine_kwargs) - self._is_postgres = self._engine.url.get_backend_name() in {"postgres", "postgresql"} - _setup_on_connect(self._engine) - self._session = _create_scoped_session(self._engine) - - - def is_postgres(self): - return self._is_postgres + engine = _create_engine(url, **engine_kwargs) + _setup_on_connect(engine) + self._session = _create_scoped_session(engine) def execute(self, statement): diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index 7519b1e..d6ba10d 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -24,10 +24,10 @@ def __init__(self, dialect, sql, *args, **kwargs): def _parse(self): formatted_statements = sqlparse.format(self._sql, strip_comments=True).strip() parsed_statements = sqlparse.parse(formatted_statements) - num_of_statements = len(parsed_statements) - if num_of_statements == 0: + statement_count = len(parsed_statements) + if statement_count == 0: raise RuntimeError("missing statement") - elif num_of_statements > 1: + elif statement_count > 1: raise RuntimeError("too many statements at once") return parsed_statements[0] @@ -35,9 +35,9 @@ def _parse(self): def _parse_command(self): for token in self._statement: - if token.ttype in [sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML]: + if _is_command_token(token): token_value = token.value.upper() - if token_value in ["BEGIN", "DELETE", "INSERT", "SELECT", "START", "UPDATE"]: + if token_value in {"BEGIN", "DELETE", "INSERT", "SELECT", "START", "UPDATE"}: command = token_value break else: @@ -49,11 +49,11 @@ def _parse_command(self): def _bind_params(self): tokens = self._tokenize() paramstyle, placeholders = self._parse_placeholders(tokens) - if paramstyle in [Paramstyle.FORMAT, Paramstyle.QMARK]: + if paramstyle in {Paramstyle.FORMAT, Paramstyle.QMARK}: tokens = self._bind_format_or_qmark(placeholders, tokens) elif paramstyle == Paramstyle.NUMERIC: tokens = self._bind_numeric(placeholders, tokens) - if paramstyle in [Paramstyle.NAMED, Paramstyle.PYFORMAT]: + if paramstyle in {Paramstyle.NAMED, Paramstyle.PYFORMAT}: tokens = self._bind_named_or_pyformat(placeholders, tokens) tokens = _escape_verbatim_colons(tokens) @@ -154,10 +154,10 @@ def _escape(self, value): sqlalchemy.types.Boolean().literal_processor(self._dialect)(value)) if isinstance(value, bytes): - if self._dialect.name in ["mysql", "sqlite"]: + if self._dialect.name in {"mysql", "sqlite"}: # https://dev.mysql.com/doc/refman/8.0/en/hexadecimal-literals.html return sqlparse.sql.Token(sqlparse.tokens.Other, f"x'{value.hex()}'") - if self._dialect.name in ["postgres", "postgresql"]: + if self._dialect.name in {"postgres", "postgresql"}: # https://dba.stackexchange.com/a/203359 return sqlparse.sql.Token(sqlparse.tokens.Other, f"'\\x{value.hex()}'") @@ -235,7 +235,7 @@ def _parse_placeholder(token): if token.value == "%s": return Paramstyle.FORMAT, None - # E.g., %(foo) + # E.g., %(foo)s matches = re.search(r"%\((\w+)\)s$", token.value) if matches: return Paramstyle.PYFORMAT, matches.group(1) @@ -253,6 +253,10 @@ def _escape_verbatim_colons(tokens): return tokens +def _is_command_token(token): + return token.ttype in {sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML} + + def _is_string_literal(token): return token.ttype in [sqlparse.tokens.Literal.String, sqlparse.tokens.Literal.String.Single] diff --git a/src/cs50/sql.py b/src/cs50/sql.py index a1f7dbd..64aa83d 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -14,14 +14,16 @@ class SQL: def __init__(self, url, **engine_kwargs): self._session = Session(url, **engine_kwargs) + self._dialect = self._session.get_bind().dialect + self._is_postgres = self._dialect in {"postgres", "postgresql"} self._autocommit = False def execute(self, sql, *args, **kwargs): """Execute a SQL statement.""" - statement = Statement(self._session.get_bind().dialect, sql, *args, **kwargs) + statement = Statement(self._dialect, sql, *args, **kwargs) command = statement.get_command() - if command in ["BEGIN", "START"]: + if command in {"BEGIN", "START"}: self._autocommit = False if self._autocommit: @@ -33,18 +35,15 @@ def execute(self, sql, *args, **kwargs): self._session.execute("COMMIT") self._session.remove() - if command in ["COMMIT", "ROLLBACK"]: + if command in {"COMMIT", "ROLLBACK"}: self._autocommit = True self._session.remove() if command == "SELECT": ret = _fetch_select_result(result) elif command == "INSERT": - if self._session.is_postgres(): - ret = self._get_last_val() - else: - ret = result.lastrowid if result.rowcount == 1 else None - elif command in ["DELETE", "UPDATE"]: + ret = self._last_row_id_or_none(result) + elif command in {"DELETE", "UPDATE"}: ret = result.rowcount else: ret = True @@ -71,6 +70,16 @@ def _execute(self, statement): return result + def _last_row_id_or_none(self, result): + if self.is_postgres(): + return self._get_last_val() + return result.lastrowid if result.rowcount == 1 else None + + + def is_postgres(self): + return self._is_postgres + + def _get_last_val(self): try: return self._session.execute("SELECT LASTVAL()").first()[0] From f8afbccdb15f860dbcefb69febd4458ad3dc8673 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Sat, 10 Apr 2021 07:39:19 -0400 Subject: [PATCH 050/159] fix is_postgres --- src/cs50/sql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 64aa83d..74ec9b2 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -15,7 +15,7 @@ class SQL: def __init__(self, url, **engine_kwargs): self._session = Session(url, **engine_kwargs) self._dialect = self._session.get_bind().dialect - self._is_postgres = self._dialect in {"postgres", "postgresql"} + self._is_postgres = self._dialect.name in {"postgres", "postgresql"} self._autocommit = False From 62b998265a5a8928b4d03962e7e9df65d4fc2ea1 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Sat, 10 Apr 2021 07:47:29 -0400 Subject: [PATCH 051/159] abstract away engine creation --- src/cs50/_session.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/cs50/_session.py b/src/cs50/_session.py index 441371a..3aff4f7 100644 --- a/src/cs50/_session.py +++ b/src/cs50/_session.py @@ -5,13 +5,12 @@ import sqlite3 class Session: + """Wraps a SQLAlchemy scoped session""" def __init__(self, url, **engine_kwargs): if _is_sqlite_url(url): _assert_sqlite_file_exists(url) - engine = _create_engine(url, **engine_kwargs) - _setup_on_connect(engine) - self._session = _create_scoped_session(engine) + self._session = _create_session(url, **engine_kwargs) def execute(self, statement): @@ -34,6 +33,12 @@ def _assert_sqlite_file_exists(url): raise RuntimeError(f"not a file: {path}") +def _create_session(url, **engine_kwargs): + engine = _create_engine(url, **engine_kwargs) + _setup_on_connect(engine) + return _create_scoped_session(engine) + + def _create_engine(url, **kwargs): try: engine = sqlalchemy.create_engine(url, **kwargs) From da613bec033538cf4df818479a92b12ef47867f6 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Sat, 10 Apr 2021 09:01:58 -0400 Subject: [PATCH 052/159] fix pylint errors --- src/cs50/__init__.py | 9 ++++-- src/cs50/_flask.py | 38 ++++++++++++++++++++++++++ src/cs50/_logger.py | 17 ++++++++---- src/cs50/_session.py | 6 +++- src/cs50/_statement.py | 62 +++++++++++++++++++++++------------------- src/cs50/cs50.py | 48 +++++++++++++++++--------------- src/cs50/flask.py | 37 ------------------------- src/cs50/sql.py | 12 ++++---- 8 files changed, 126 insertions(+), 103 deletions(-) create mode 100644 src/cs50/_flask.py delete mode 100644 src/cs50/flask.py diff --git a/src/cs50/__init__.py b/src/cs50/__init__.py index f04da00..b75f415 100644 --- a/src/cs50/__init__.py +++ b/src/cs50/__init__.py @@ -1,6 +1,9 @@ -from ._logger import _setup_logger -_setup_logger() +"""Exposes API, wraps flask, and sets up logging""" from .cs50 import get_float, get_int, get_string -from . import flask from .sql import SQL +from ._logger import _setup_logger +from ._flask import _wrap_flask + +_setup_logger() +_wrap_flask() diff --git a/src/cs50/_flask.py b/src/cs50/_flask.py new file mode 100644 index 0000000..d65a8a5 --- /dev/null +++ b/src/cs50/_flask.py @@ -0,0 +1,38 @@ +"""Hooks into flask importing to support X-Forwarded-Proto header in online IDE""" + +import os +import pkgutil +import sys + +from distutils.version import StrictVersion +from werkzeug.middleware.proxy_fix import ProxyFix + + +def _wrap_flask(): + if "flask" in sys.modules: + _support_x_forwarded_proto(sys.modules["flask"]) + else: + flask_loader = pkgutil.get_loader('flask') + if flask_loader: + _exec_module_before = flask_loader.exec_module + + def _exec_module_after(*args, **kwargs): + _exec_module_before(*args, **kwargs) + _support_x_forwarded_proto(sys.modules["flask"]) + + flask_loader.exec_module = _exec_module_after + + +def _support_x_forwarded_proto(flask_module): + if flask_module is None: + return + + if flask_module.__version__ < StrictVersion("1.0"): + return + + if os.getenv("CS50_IDE_TYPE") == "online": + _flask_init_before = flask_module.Flask.__init__ + def _flask_init_after(self, *args, **kwargs): + _flask_init_before(self, *args, **kwargs) + self.wsgi_app = ProxyFix(self.wsgi_app, x_proto=1) # For HTTPS-to-HTTP proxy + flask_module.Flask.__init__ = _flask_init_after diff --git a/src/cs50/_logger.py b/src/cs50/_logger.py index c489111..df021a3 100644 --- a/src/cs50/_logger.py +++ b/src/cs50/_logger.py @@ -1,3 +1,5 @@ +"""Sets up logging for cs50 library""" + import logging import os.path import re @@ -14,7 +16,8 @@ def _setup_logger(): try: # Patch formatException - logging.root.handlers[0].formatter.formatException = lambda exc_info: _formatException(*exc_info) + formatter = logging.root.handlers[0].formatter + formatter.formatException = lambda exc_info: _format_exception(*exc_info) except IndexError: pass @@ -29,14 +32,15 @@ def _setup_logger(): handler.setLevel(logging.DEBUG) formatter = logging.Formatter("%(levelname)s: %(message)s") - formatter.formatException = lambda exc_info: _formatException(*exc_info) + formatter.formatException = lambda exc_info: _format_exception(*exc_info) handler.setFormatter(formatter) _logger.addHandler(handler) - sys.excepthook = lambda type, value, tb: print(_formatException(type, value, tb), file=sys.stderr) + sys.excepthook = lambda type_, value, exc_tb: print( + _format_exception(type_, value, exc_tb), file=sys.stderr) -def _formatException(type, value, tb): +def _format_exception(type_, value, exc_tb): """ Format traceback, darkening entries from global site-packages directories and user-specific site-packages directory. @@ -48,11 +52,12 @@ def _formatException(type, value, tb): # Highlight lines not referring to files in site-packages lines = [] - for line in traceback.format_exception(type, value, tb): + for line in traceback.format_exception(type_, value, exc_tb): matches = re.search(r"^ File \"([^\"]+)\", line \d+, in .+", line) if matches and matches.group(1).startswith(packages): lines += line else: matches = re.search(r"^(\s*)(.*?)(\s*)$", line, re.DOTALL) - lines.append(matches.group(1) + termcolor.colored(matches.group(2), "yellow") + matches.group(3)) + lines.append( + matches.group(1) + termcolor.colored(matches.group(2), "yellow") + matches.group(3)) return "".join(lines).rstrip() diff --git a/src/cs50/_session.py b/src/cs50/_session.py index 3aff4f7..cd23453 100644 --- a/src/cs50/_session.py +++ b/src/cs50/_session.py @@ -1,8 +1,10 @@ +"""Wraps a SQLAlchemy scoped session""" + import os +import sqlite3 import sqlalchemy import sqlalchemy.orm -import sqlite3 class Session: """Wraps a SQLAlchemy scoped session""" @@ -14,6 +16,8 @@ def __init__(self, url, **engine_kwargs): def execute(self, statement): + """Converts statement to str and executes it""" + # pylint: disable=no-member return self._session.execute(sqlalchemy.text(str(statement))) diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index d6ba10d..7a38c90 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -1,3 +1,5 @@ +"""Parses a SQL statement and binds its parameters""" + import collections import datetime import enum @@ -8,6 +10,7 @@ class Statement: + """Parses and binds a SQL statement""" def __init__(self, dialect, sql, *args, **kwargs): if len(args) > 0 and len(kwargs) > 0: raise RuntimeError("cannot pass both positional and named parameters") @@ -21,13 +24,14 @@ def __init__(self, dialect, sql, *args, **kwargs): self._command = self._parse_command() self._tokens = self._bind_params() + def _parse(self): formatted_statements = sqlparse.format(self._sql, strip_comments=True).strip() parsed_statements = sqlparse.parse(formatted_statements) statement_count = len(parsed_statements) if statement_count == 0: raise RuntimeError("missing statement") - elif statement_count > 1: + if statement_count > 1: raise RuntimeError("too many statements at once") return parsed_statements[0] @@ -49,11 +53,11 @@ def _parse_command(self): def _bind_params(self): tokens = self._tokenize() paramstyle, placeholders = self._parse_placeholders(tokens) - if paramstyle in {Paramstyle.FORMAT, Paramstyle.QMARK}: + if paramstyle in {_Paramstyle.FORMAT, _Paramstyle.QMARK}: tokens = self._bind_format_or_qmark(placeholders, tokens) - elif paramstyle == Paramstyle.NUMERIC: + elif paramstyle == _Paramstyle.NUMERIC: tokens = self._bind_numeric(placeholders, tokens) - if paramstyle in {Paramstyle.NAMED, Paramstyle.PYFORMAT}: + if paramstyle in {_Paramstyle.NAMED, _Paramstyle.PYFORMAT}: tokens = self._bind_named_or_pyformat(placeholders, tokens) tokens = _escape_verbatim_colons(tokens) @@ -86,9 +90,9 @@ def _parse_placeholders(self, tokens): def _default_paramstyle(self): paramstyle = None if self._args: - paramstyle = Paramstyle.QMARK + paramstyle = _Paramstyle.QMARK elif self._kwargs: - paramstyle = Paramstyle.NAMED + paramstyle = _Paramstyle.NAMED return paramstyle @@ -118,8 +122,10 @@ def _bind_numeric(self, placeholders, tokens): unused_arg_indices.remove(num) if len(unused_arg_indices) > 0: - unused_args = ", ".join([str(self._escape(self._args[i])) for i in sorted(unused_arg_indices)]) - raise RuntimeError(f"unused value{'' if len(unused_arg_indices) == 1 else 's'} ({unused_args})") + unused_args = ", ".join( + [str(self._escape(self._args[i])) for i in sorted(unused_arg_indices)]) + raise RuntimeError( + f"unused value{'' if len(unused_arg_indices) == 1 else 's'} ({unused_args})") return tokens @@ -134,7 +140,9 @@ def _bind_named_or_pyformat(self, placeholders, tokens): unused_params.remove(param_name) if len(unused_params) > 0: - raise RuntimeError("unused value{'' if len(unused_params) == 1 else 's'} ({', '.join(sorted(unused_params))})") + joined_unused_params = ", ".join(sorted(unused_params)) + raise RuntimeError( + f"unused value{'' if len(unused_params) == 1 else 's'} ({joined_unused_params})") return tokens @@ -144,7 +152,7 @@ def _escape(self, value): Escapes value using engine's conversion function. https://docs.sqlalchemy.org/en/latest/core/type_api.html#sqlalchemy.types.TypeEngine.literal_processor """ - + # pylint: disable=too-many-return-statements if isinstance(value, (list, tuple)): return self._escape_iterable(value) @@ -163,20 +171,18 @@ def _escape(self, value): raise RuntimeError(f"unsupported value: {value}") + string_processor = sqlalchemy.types.String().literal_processor(self._dialect) if isinstance(value, datetime.date): return sqlparse.sql.Token( - sqlparse.tokens.String, - sqlalchemy.types.String().literal_processor(self._dialect)(value.strftime("%Y-%m-%d"))) + sqlparse.tokens.String, string_processor(value.strftime("%Y-%m-%d"))) if isinstance(value, datetime.datetime): return sqlparse.sql.Token( - sqlparse.tokens.String, - sqlalchemy.types.String().literal_processor(self._dialect)(value.strftime("%Y-%m-%d %H:%M:%S"))) + sqlparse.tokens.String, string_processor(value.strftime("%Y-%m-%d %H:%M:%S"))) if isinstance(value, datetime.time): return sqlparse.sql.Token( - sqlparse.tokens.String, - sqlalchemy.types.String().literal_processor(self._dialect)(value.strftime("%H:%M:%S"))) + sqlparse.tokens.String, string_processor(value.strftime("%H:%M:%S"))) if isinstance(value, float): return sqlparse.sql.Token( @@ -189,9 +195,7 @@ def _escape(self, value): sqlalchemy.types.Integer().literal_processor(self._dialect)(value)) if isinstance(value, str): - return sqlparse.sql.Token( - sqlparse.tokens.String, - sqlalchemy.types.String().literal_processor(self._dialect)(value)) + return sqlparse.sql.Token(sqlparse.tokens.String, string_processor(value)) if value is None: return sqlparse.sql.Token( @@ -207,6 +211,7 @@ def _escape_iterable(self, iterable): def get_command(self): + """Returns statement command (e.g., SELECT) or None""" return self._command @@ -220,25 +225,25 @@ def _is_placeholder(token): def _parse_placeholder(token): if token.value == "?": - return Paramstyle.QMARK, None + return _Paramstyle.QMARK, None # E.g., :1 matches = re.search(r"^:([1-9]\d*)$", token.value) if matches: - return Paramstyle.NUMERIC, int(matches.group(1)) - 1 + return _Paramstyle.NUMERIC, int(matches.group(1)) - 1 # E.g., :foo matches = re.search(r"^:([a-zA-Z]\w*)$", token.value) if matches: - return Paramstyle.NAMED, matches.group(1) + return _Paramstyle.NAMED, matches.group(1) if token.value == "%s": - return Paramstyle.FORMAT, None + return _Paramstyle.FORMAT, None # E.g., %(foo)s matches = re.search(r"%\((\w+)\)s$", token.value) if matches: - return Paramstyle.PYFORMAT, matches.group(1) + return _Paramstyle.PYFORMAT, matches.group(1) raise RuntimeError(f"{token.value}: invalid placeholder") @@ -246,15 +251,16 @@ def _parse_placeholder(token): def _escape_verbatim_colons(tokens): for token in tokens: if _is_string_literal(token): - token.value = re.sub("(^'|\s+):", r"\1\:", token.value) + token.value = re.sub(r"(^'|\s+):", r"\1\:", token.value) elif _is_identifier(token): - token.value = re.sub("(^\"|\s+):", r"\1\:", token.value) + token.value = re.sub(r"(^\"|\s+):", r"\1\:", token.value) return tokens def _is_command_token(token): - return token.ttype in {sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML} + return token.ttype in { + sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML} def _is_string_literal(token): @@ -265,7 +271,7 @@ def _is_identifier(token): return token.ttype == sqlparse.tokens.Literal.String.Symbol -class Paramstyle(enum.Enum): +class _Paramstyle(enum.Enum): FORMAT = enum.auto() NAMED = enum.auto() NUMERIC = enum.auto() diff --git a/src/cs50/cs50.py b/src/cs50/cs50.py index 573d862..24c748b 100644 --- a/src/cs50/cs50.py +++ b/src/cs50/cs50.py @@ -1,3 +1,5 @@ +"""Exposes simple API for getting and validating user input""" + import re import sys @@ -16,14 +18,14 @@ def get_float(prompt): def _get_float(prompt): - s = get_string(prompt) - if s is None: - return + user_input = get_string(prompt) + if user_input is None: + return None - if len(s) > 0 and re.search(r"^[+-]?\d*(?:\.\d*)?$", s): - return float(s) + if len(user_input) > 0 and re.search(r"^[+-]?\d*(?:\.\d*)?$", user_input): + return float(user_input) - raise ValueError(f"invalid float literal: {s}") + raise ValueError(f"invalid float literal: {user_input}") def get_int(prompt): @@ -40,14 +42,14 @@ def get_int(prompt): def _get_int(prompt): - s = get_string(prompt) - if s is None: - return + user_input = get_string(prompt) + if user_input is None: + return None - if re.search(r"^[+-]?\d+$", s): - return int(s, 10) + if re.search(r"^[+-]?\d+$", user_input): + return int(user_input, 10) - raise ValueError(f"invalid int literal for base 10: {s}") + raise ValueError(f"invalid int literal for base 10: {user_input}") def get_string(prompt): @@ -57,13 +59,13 @@ def get_string(prompt): as line endings. If user inputs only a line ending, returns "", not None. Returns None upon error or no input whatsoever (i.e., just EOF). """ - if type(prompt) is not str: + if not isinstance(prompt, str): raise TypeError("prompt must be of type str") try: return _get_input(prompt) except EOFError: - return + return None def _get_input(prompt): @@ -76,18 +78,20 @@ class _flushfile(): http://stackoverflow.com/a/231216 """ - def __init__(self, f): - self.f = f + def __init__(self, stream): + self.stream = stream def __getattr__(self, name): - return object.__getattribute__(self.f, name) + return object.__getattribute__(self.stream, name) - def write(self, x): - self.f.write(x) - self.f.flush() + def write(self, data): + """Writes data to stream""" + self.stream.write(data) + self.stream.flush() -def disable_buffering(): +def disable_output_buffering(): + """Disables output buffering to prevent prompts from being buffered""" sys.stderr = _flushfile(sys.stderr) sys.stdout = _flushfile(sys.stdout) -disable_buffering() +disable_output_buffering() diff --git a/src/cs50/flask.py b/src/cs50/flask.py deleted file mode 100644 index a0e077a..0000000 --- a/src/cs50/flask.py +++ /dev/null @@ -1,37 +0,0 @@ -import os -import pkgutil -import sys - -from distutils.version import StrictVersion -from werkzeug.middleware.proxy_fix import ProxyFix - -def _wrap_flask(f): - if f is None: - return - - if f.__version__ < StrictVersion("1.0"): - return - - if os.getenv("CS50_IDE_TYPE") == "online": - _flask_init_before = f.Flask.__init__ - def _flask_init_after(self, *args, **kwargs): - _flask_init_before(self, *args, **kwargs) - self.wsgi_app = ProxyFix(self.wsgi_app, x_proto=1) # For HTTPS-to-HTTP proxy - f.Flask.__init__ = _flask_init_after - - -# If Flask was imported before cs50 -if "flask" in sys.modules: - _wrap_flask(sys.modules["flask"]) - -# If Flask wasn't imported -else: - flask_loader = pkgutil.get_loader('flask') - if flask_loader: - _exec_module_before = flask_loader.exec_module - - def _exec_module_after(*args, **kwargs): - _exec_module_before(*args, **kwargs) - _wrap_flask(sys.modules["flask"]) - - flask_loader.exec_module = _exec_module_after diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 74ec9b2..0510f17 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -1,3 +1,5 @@ +"""Wraps SQLAlchemy""" + import decimal import logging import warnings @@ -12,6 +14,7 @@ class SQL: + """Wraps SQLAlchemy""" def __init__(self, url, **engine_kwargs): self._session = Session(url, **engine_kwargs) self._dialect = self._session.get_bind().dialect @@ -71,15 +74,11 @@ def _execute(self, statement): def _last_row_id_or_none(self, result): - if self.is_postgres(): + if self._is_postgres: return self._get_last_val() return result.lastrowid if result.rowcount == 1 else None - def is_postgres(self): - return self._is_postgres - - def _get_last_val(self): try: return self._session.execute("SELECT LASTVAL()").first()[0] @@ -88,8 +87,9 @@ def _get_last_val(self): def init_app(self, app): + """Registers a teardown_appcontext listener to remove session and enables logging""" @app.teardown_appcontext - def shutdown_session(_): + def _(_): self._session.remove() logging.getLogger("cs50").disabled = False From 4a593dd4ab27ade5978906b5d9be40ac3af78ed2 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Mon, 12 Apr 2021 15:58:51 -0400 Subject: [PATCH 053/159] rename _parse_command --- src/cs50/_statement.py | 20 ++++++++++---------- src/cs50/sql.py | 12 ++++++------ 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index 7a38c90..598b131 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -21,7 +21,7 @@ def __init__(self, dialect, sql, *args, **kwargs): self._kwargs = kwargs self._statement = self._parse() - self._command = self._parse_command() + self._operation_keyword = self._get_operation_keyword() self._tokens = self._bind_params() @@ -37,17 +37,17 @@ def _parse(self): return parsed_statements[0] - def _parse_command(self): + def _get_operation_keyword(self): for token in self._statement: - if _is_command_token(token): + if _is_operation_token(token): token_value = token.value.upper() if token_value in {"BEGIN", "DELETE", "INSERT", "SELECT", "START", "UPDATE"}: - command = token_value + operation_keyword = token_value break else: - command = None + operation_keyword = None - return command + return operation_keyword def _bind_params(self): @@ -210,9 +210,9 @@ def _escape_iterable(self, iterable): sqlparse.parse(", ".join([str(self._escape(v)) for v in iterable]))) - def get_command(self): - """Returns statement command (e.g., SELECT) or None""" - return self._command + def get_operation_keyword(self): + """Returns the operation keyword of the statement (e.g., SELECT) if found, or None""" + return self._operation_keyword def __str__(self): @@ -258,7 +258,7 @@ def _escape_verbatim_colons(tokens): return tokens -def _is_command_token(token): +def _is_operation_token(token): return token.ttype in { sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML} diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 0510f17..fca57d2 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -25,8 +25,8 @@ def __init__(self, url, **engine_kwargs): def execute(self, sql, *args, **kwargs): """Execute a SQL statement.""" statement = Statement(self._dialect, sql, *args, **kwargs) - command = statement.get_command() - if command in {"BEGIN", "START"}: + operation_keyword = statement.get_operation_keyword() + if operation_keyword in {"BEGIN", "START"}: self._autocommit = False if self._autocommit: @@ -38,15 +38,15 @@ def execute(self, sql, *args, **kwargs): self._session.execute("COMMIT") self._session.remove() - if command in {"COMMIT", "ROLLBACK"}: + if operation_keyword in {"COMMIT", "ROLLBACK"}: self._autocommit = True self._session.remove() - if command == "SELECT": + if operation_keyword == "SELECT": ret = _fetch_select_result(result) - elif command == "INSERT": + elif operation_keyword == "INSERT": ret = self._last_row_id_or_none(result) - elif command in {"DELETE", "UPDATE"}: + elif operation_keyword in {"DELETE", "UPDATE"}: ret = result.rowcount else: ret = True From 6fcf7ed469ad8cdff1628e591eff3d7eb6767f34 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Mon, 12 Apr 2021 16:01:36 -0400 Subject: [PATCH 054/159] rename _bind_params --- src/cs50/_statement.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index 598b131..789acdf 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -1,4 +1,4 @@ -"""Parses a SQL statement and binds its parameters""" +"""Parses a SQL statement and replaces placeholders with parameters""" import collections import datetime @@ -10,7 +10,7 @@ class Statement: - """Parses and binds a SQL statement""" + """Parses a SQL statement and replaces placeholders with parameters""" def __init__(self, dialect, sql, *args, **kwargs): if len(args) > 0 and len(kwargs) > 0: raise RuntimeError("cannot pass both positional and named parameters") @@ -22,7 +22,7 @@ def __init__(self, dialect, sql, *args, **kwargs): self._statement = self._parse() self._operation_keyword = self._get_operation_keyword() - self._tokens = self._bind_params() + self._tokens = self._replace_placeholders_with_params() def _parse(self): @@ -50,15 +50,15 @@ def _get_operation_keyword(self): return operation_keyword - def _bind_params(self): + def _replace_placeholders_with_params(self): tokens = self._tokenize() paramstyle, placeholders = self._parse_placeholders(tokens) if paramstyle in {_Paramstyle.FORMAT, _Paramstyle.QMARK}: - tokens = self._bind_format_or_qmark(placeholders, tokens) + tokens = self._replace_format_or_qmark_placeholders(placeholders, tokens) elif paramstyle == _Paramstyle.NUMERIC: - tokens = self._bind_numeric(placeholders, tokens) + tokens = self._replace_numeric_placeholders(placeholders, tokens) if paramstyle in {_Paramstyle.NAMED, _Paramstyle.PYFORMAT}: - tokens = self._bind_named_or_pyformat(placeholders, tokens) + tokens = self._replace_named_or_pyformat_placeholders(placeholders, tokens) tokens = _escape_verbatim_colons(tokens) return tokens @@ -97,7 +97,7 @@ def _default_paramstyle(self): return paramstyle - def _bind_format_or_qmark(self, placeholders, tokens): + def _replace_format_or_qmark_placeholders(self, placeholders, tokens): if len(placeholders) != len(self._args): _placeholders = ", ".join([str(token) for token in placeholders.values()]) _args = ", ".join([str(self._escape(arg)) for arg in self._args]) @@ -112,7 +112,7 @@ def _bind_format_or_qmark(self, placeholders, tokens): return tokens - def _bind_numeric(self, placeholders, tokens): + def _replace_numeric_placeholders(self, placeholders, tokens): unused_arg_indices = set(range(len(self._args))) for token_index, num in placeholders.items(): if num >= len(self._args): @@ -130,7 +130,7 @@ def _bind_numeric(self, placeholders, tokens): return tokens - def _bind_named_or_pyformat(self, placeholders, tokens): + def _replace_named_or_pyformat_placeholders(self, placeholders, tokens): unused_params = set(self._kwargs.keys()) for token_index, param_name in placeholders.items(): if param_name not in self._kwargs: From a4989eb05e078f78ac5dc488300b804ae3ecb996 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Mon, 12 Apr 2021 19:37:49 -0400 Subject: [PATCH 055/159] wrap flask wrapper the IDE proxy now handles forcing https --- src/cs50/__init__.py | 4 +--- src/cs50/_flask.py | 38 ------------------------------- tests/redirect/application.py | 12 ---------- tests/redirect/templates/foo.html | 1 - 4 files changed, 1 insertion(+), 54 deletions(-) delete mode 100644 src/cs50/_flask.py delete mode 100644 tests/redirect/application.py delete mode 100644 tests/redirect/templates/foo.html diff --git a/src/cs50/__init__.py b/src/cs50/__init__.py index b75f415..fa07171 100644 --- a/src/cs50/__init__.py +++ b/src/cs50/__init__.py @@ -1,9 +1,7 @@ -"""Exposes API, wraps flask, and sets up logging""" +"""Exposes API and sets up logging""" from .cs50 import get_float, get_int, get_string from .sql import SQL from ._logger import _setup_logger -from ._flask import _wrap_flask _setup_logger() -_wrap_flask() diff --git a/src/cs50/_flask.py b/src/cs50/_flask.py deleted file mode 100644 index d65a8a5..0000000 --- a/src/cs50/_flask.py +++ /dev/null @@ -1,38 +0,0 @@ -"""Hooks into flask importing to support X-Forwarded-Proto header in online IDE""" - -import os -import pkgutil -import sys - -from distutils.version import StrictVersion -from werkzeug.middleware.proxy_fix import ProxyFix - - -def _wrap_flask(): - if "flask" in sys.modules: - _support_x_forwarded_proto(sys.modules["flask"]) - else: - flask_loader = pkgutil.get_loader('flask') - if flask_loader: - _exec_module_before = flask_loader.exec_module - - def _exec_module_after(*args, **kwargs): - _exec_module_before(*args, **kwargs) - _support_x_forwarded_proto(sys.modules["flask"]) - - flask_loader.exec_module = _exec_module_after - - -def _support_x_forwarded_proto(flask_module): - if flask_module is None: - return - - if flask_module.__version__ < StrictVersion("1.0"): - return - - if os.getenv("CS50_IDE_TYPE") == "online": - _flask_init_before = flask_module.Flask.__init__ - def _flask_init_after(self, *args, **kwargs): - _flask_init_before(self, *args, **kwargs) - self.wsgi_app = ProxyFix(self.wsgi_app, x_proto=1) # For HTTPS-to-HTTP proxy - flask_module.Flask.__init__ = _flask_init_after diff --git a/tests/redirect/application.py b/tests/redirect/application.py deleted file mode 100644 index 6aff187..0000000 --- a/tests/redirect/application.py +++ /dev/null @@ -1,12 +0,0 @@ -import cs50 -from flask import Flask, redirect, render_template - -app = Flask(__name__) - -@app.route("/") -def index(): - return redirect("/foo") - -@app.route("/foo") -def foo(): - return render_template("foo.html") diff --git a/tests/redirect/templates/foo.html b/tests/redirect/templates/foo.html deleted file mode 100644 index 257cc56..0000000 --- a/tests/redirect/templates/foo.html +++ /dev/null @@ -1 +0,0 @@ -foo From e8827bfa0b68c06b48822d03d30cf84c882b45bf Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Mon, 12 Apr 2021 20:38:01 -0400 Subject: [PATCH 056/159] factor out sanitizer --- src/cs50/_sql_sanitizer.py | 86 +++++++++++++++++++ src/cs50/_statement.py | 167 ++++++++++--------------------------- 2 files changed, 132 insertions(+), 121 deletions(-) create mode 100644 src/cs50/_sql_sanitizer.py diff --git a/src/cs50/_sql_sanitizer.py b/src/cs50/_sql_sanitizer.py new file mode 100644 index 0000000..c2f35c4 --- /dev/null +++ b/src/cs50/_sql_sanitizer.py @@ -0,0 +1,86 @@ +"""Escapes SQL values""" + +import datetime +import re + +import sqlalchemy +import sqlparse + + +class SQLSanitizer: + """Escapes SQL values""" + + def __init__(self, dialect): + self._dialect = dialect + + + def escape(self, value): + """ + Escapes value using engine's conversion function. + https://docs.sqlalchemy.org/en/latest/core/type_api.html#sqlalchemy.types.TypeEngine.literal_processor + """ + # pylint: disable=too-many-return-statements + if isinstance(value, (list, tuple)): + return self.escape_iterable(value) + + if isinstance(value, bool): + return sqlparse.sql.Token( + sqlparse.tokens.Number, + sqlalchemy.types.Boolean().literal_processor(self._dialect)(value)) + + if isinstance(value, bytes): + if self._dialect.name in {"mysql", "sqlite"}: + # https://dev.mysql.com/doc/refman/8.0/en/hexadecimal-literals.html + return sqlparse.sql.Token(sqlparse.tokens.Other, f"x'{value.hex()}'") + if self._dialect.name in {"postgres", "postgresql"}: + # https://dba.stackexchange.com/a/203359 + return sqlparse.sql.Token(sqlparse.tokens.Other, f"'\\x{value.hex()}'") + + raise RuntimeError(f"unsupported value: {value}") + + string_processor = sqlalchemy.types.String().literal_processor(self._dialect) + if isinstance(value, datetime.date): + return sqlparse.sql.Token( + sqlparse.tokens.String, string_processor(value.strftime("%Y-%m-%d"))) + + if isinstance(value, datetime.datetime): + return sqlparse.sql.Token( + sqlparse.tokens.String, string_processor(value.strftime("%Y-%m-%d %H:%M:%S"))) + + if isinstance(value, datetime.time): + return sqlparse.sql.Token( + sqlparse.tokens.String, string_processor(value.strftime("%H:%M:%S"))) + + if isinstance(value, float): + return sqlparse.sql.Token( + sqlparse.tokens.Number, + sqlalchemy.types.Float().literal_processor(self._dialect)(value)) + + if isinstance(value, int): + return sqlparse.sql.Token( + sqlparse.tokens.Number, + sqlalchemy.types.Integer().literal_processor(self._dialect)(value)) + + if isinstance(value, str): + return sqlparse.sql.Token(sqlparse.tokens.String, string_processor(value)) + + if value is None: + return sqlparse.sql.Token( + sqlparse.tokens.Keyword, + sqlalchemy.types.NullType().literal_processor(self._dialect)(value)) + + raise RuntimeError(f"unsupported value: {value}") + + + def escape_iterable(self, iterable): + """Escapes a collection of values (e.g., list, tuple)""" + return sqlparse.sql.TokenList( + sqlparse.parse(", ".join([str(self.escape(v)) for v in iterable]))) + + +def escape_verbatim_colon(value): + """Escapes verbatim colon from a value so as it is not confused with a placeholder""" + + # E.g., ':foo, ":foo, :foo will be replaced with + # '\:foo, "\:foo, \:foo respectively + return re.sub(r"(^(?:'|\")|\s+):", r"\1\:", value) diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index 789acdf..7222f0e 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -1,40 +1,28 @@ """Parses a SQL statement and replaces placeholders with parameters""" import collections -import datetime import enum import re -import sqlalchemy import sqlparse +from ._sql_sanitizer import SQLSanitizer, escape_verbatim_colon + + class Statement: """Parses a SQL statement and replaces placeholders with parameters""" def __init__(self, dialect, sql, *args, **kwargs): if len(args) > 0 and len(kwargs) > 0: raise RuntimeError("cannot pass both positional and named parameters") - self._dialect = dialect - self._sql = sql + self._sql_sanitizer = SQLSanitizer(dialect) self._args = args self._kwargs = kwargs - - self._statement = self._parse() + self._statement = _parse(sql) self._operation_keyword = self._get_operation_keyword() - self._tokens = self._replace_placeholders_with_params() - - - def _parse(self): - formatted_statements = sqlparse.format(self._sql, strip_comments=True).strip() - parsed_statements = sqlparse.parse(formatted_statements) - statement_count = len(parsed_statements) - if statement_count == 0: - raise RuntimeError("missing statement") - if statement_count > 1: - raise RuntimeError("too many statements at once") - - return parsed_statements[0] + self._tokens = self._tokenize() + self._replace_placeholders_with_params() def _get_operation_keyword(self): @@ -50,28 +38,26 @@ def _get_operation_keyword(self): return operation_keyword + def _tokenize(self): + return list(self._statement.flatten()) + + def _replace_placeholders_with_params(self): - tokens = self._tokenize() - paramstyle, placeholders = self._parse_placeholders(tokens) + paramstyle, placeholders = self._parse_placeholders() if paramstyle in {_Paramstyle.FORMAT, _Paramstyle.QMARK}: - tokens = self._replace_format_or_qmark_placeholders(placeholders, tokens) + self._replace_format_or_qmark_placeholders(placeholders) elif paramstyle == _Paramstyle.NUMERIC: - tokens = self._replace_numeric_placeholders(placeholders, tokens) + self._replace_numeric_placeholders(placeholders) if paramstyle in {_Paramstyle.NAMED, _Paramstyle.PYFORMAT}: - tokens = self._replace_named_or_pyformat_placeholders(placeholders, tokens) - - tokens = _escape_verbatim_colons(tokens) - return tokens - + self._replace_named_or_pyformat_placeholders(placeholders) - def _tokenize(self): - return list(self._statement.flatten()) + self._escape_verbatim_colons() - def _parse_placeholders(self, tokens): + def _parse_placeholders(self): paramstyle = None placeholders = collections.OrderedDict() - for index, token in enumerate(tokens): + for index, token in enumerate(self._tokens): if _is_placeholder(token): _paramstyle, name = _parse_placeholder(token) if paramstyle is None: @@ -97,46 +83,42 @@ def _default_paramstyle(self): return paramstyle - def _replace_format_or_qmark_placeholders(self, placeholders, tokens): + def _replace_format_or_qmark_placeholders(self, placeholders): if len(placeholders) != len(self._args): _placeholders = ", ".join([str(token) for token in placeholders.values()]) - _args = ", ".join([str(self._escape(arg)) for arg in self._args]) + _args = ", ".join([str(self._sql_sanitizer.escape(arg)) for arg in self._args]) if len(placeholders) < len(self._args): raise RuntimeError(f"fewer placeholders ({_placeholders}) than values ({_args})") raise RuntimeError(f"more placeholders ({_placeholders}) than values ({_args})") for arg_index, token_index in enumerate(placeholders.keys()): - tokens[token_index] = self._escape(self._args[arg_index]) + self._tokens[token_index] = self._sql_sanitizer.escape(self._args[arg_index]) - return tokens - - def _replace_numeric_placeholders(self, placeholders, tokens): - unused_arg_indices = set(range(len(self._args))) + def _replace_numeric_placeholders(self, placeholders): + unused_arg_idxs = set(range(len(self._args))) for token_index, num in placeholders.items(): if num >= len(self._args): raise RuntimeError(f"missing value for placeholder ({num + 1})") - tokens[token_index] = self._escape(self._args[num]) - unused_arg_indices.remove(num) + self._tokens[token_index] = self._sql_sanitizer.escape(self._args[num]) + unused_arg_idxs.remove(num) - if len(unused_arg_indices) > 0: + if len(unused_arg_idxs) > 0: unused_args = ", ".join( - [str(self._escape(self._args[i])) for i in sorted(unused_arg_indices)]) + [str(self._sql_sanitizer.escape(self._args[i])) for i in sorted(unused_arg_idxs)]) raise RuntimeError( - f"unused value{'' if len(unused_arg_indices) == 1 else 's'} ({unused_args})") - - return tokens + f"unused value{'' if len(unused_arg_idxs) == 1 else 's'} ({unused_args})") - def _replace_named_or_pyformat_placeholders(self, placeholders, tokens): + def _replace_named_or_pyformat_placeholders(self, placeholders): unused_params = set(self._kwargs.keys()) for token_index, param_name in placeholders.items(): if param_name not in self._kwargs: raise RuntimeError(f"missing value for placeholder ({param_name})") - tokens[token_index] = self._escape(self._kwargs[param_name]) + self._tokens[token_index] = self._sql_sanitizer.escape(self._kwargs[param_name]) unused_params.remove(param_name) if len(unused_params) > 0: @@ -144,70 +126,11 @@ def _replace_named_or_pyformat_placeholders(self, placeholders, tokens): raise RuntimeError( f"unused value{'' if len(unused_params) == 1 else 's'} ({joined_unused_params})") - return tokens - - - def _escape(self, value): - """ - Escapes value using engine's conversion function. - https://docs.sqlalchemy.org/en/latest/core/type_api.html#sqlalchemy.types.TypeEngine.literal_processor - """ - # pylint: disable=too-many-return-statements - if isinstance(value, (list, tuple)): - return self._escape_iterable(value) - if isinstance(value, bool): - return sqlparse.sql.Token( - sqlparse.tokens.Number, - sqlalchemy.types.Boolean().literal_processor(self._dialect)(value)) - - if isinstance(value, bytes): - if self._dialect.name in {"mysql", "sqlite"}: - # https://dev.mysql.com/doc/refman/8.0/en/hexadecimal-literals.html - return sqlparse.sql.Token(sqlparse.tokens.Other, f"x'{value.hex()}'") - if self._dialect.name in {"postgres", "postgresql"}: - # https://dba.stackexchange.com/a/203359 - return sqlparse.sql.Token(sqlparse.tokens.Other, f"'\\x{value.hex()}'") - - raise RuntimeError(f"unsupported value: {value}") - - string_processor = sqlalchemy.types.String().literal_processor(self._dialect) - if isinstance(value, datetime.date): - return sqlparse.sql.Token( - sqlparse.tokens.String, string_processor(value.strftime("%Y-%m-%d"))) - - if isinstance(value, datetime.datetime): - return sqlparse.sql.Token( - sqlparse.tokens.String, string_processor(value.strftime("%Y-%m-%d %H:%M:%S"))) - - if isinstance(value, datetime.time): - return sqlparse.sql.Token( - sqlparse.tokens.String, string_processor(value.strftime("%H:%M:%S"))) - - if isinstance(value, float): - return sqlparse.sql.Token( - sqlparse.tokens.Number, - sqlalchemy.types.Float().literal_processor(self._dialect)(value)) - - if isinstance(value, int): - return sqlparse.sql.Token( - sqlparse.tokens.Number, - sqlalchemy.types.Integer().literal_processor(self._dialect)(value)) - - if isinstance(value, str): - return sqlparse.sql.Token(sqlparse.tokens.String, string_processor(value)) - - if value is None: - return sqlparse.sql.Token( - sqlparse.tokens.Keyword, - sqlalchemy.types.NullType().literal_processor(self._dialect)(value)) - - raise RuntimeError(f"unsupported value: {value}") - - - def _escape_iterable(self, iterable): - return sqlparse.sql.TokenList( - sqlparse.parse(", ".join([str(self._escape(v)) for v in iterable]))) + def _escape_verbatim_colons(self): + for token in self._tokens: + if _is_string_literal(token) or _is_identifier(token): + token.value = escape_verbatim_colon(token.value) def get_operation_keyword(self): @@ -219,6 +142,18 @@ def __str__(self): return "".join([str(token) for token in self._tokens]) +def _parse(sql): + formatted_statements = sqlparse.format(sql, strip_comments=True).strip() + parsed_statements = sqlparse.parse(formatted_statements) + statement_count = len(parsed_statements) + if statement_count == 0: + raise RuntimeError("missing statement") + if statement_count > 1: + raise RuntimeError("too many statements at once") + + return parsed_statements[0] + + def _is_placeholder(token): return token.ttype == sqlparse.tokens.Name.Placeholder @@ -248,16 +183,6 @@ def _parse_placeholder(token): raise RuntimeError(f"{token.value}: invalid placeholder") -def _escape_verbatim_colons(tokens): - for token in tokens: - if _is_string_literal(token): - token.value = re.sub(r"(^'|\s+):", r"\1\:", token.value) - elif _is_identifier(token): - token.value = re.sub(r"(^\"|\s+):", r"\1\:", token.value) - - return tokens - - def _is_operation_token(token): return token.ttype in { sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML} From bd330a703b6a29bf23d6f58749d5b0c857bbd32d Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Mon, 12 Apr 2021 20:58:54 -0400 Subject: [PATCH 057/159] promote paramstyle and placeholders to instance variables --- src/cs50/_statement.py | 97 ++++++++++++++++++++++-------------------- 1 file changed, 52 insertions(+), 45 deletions(-) diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index 7222f0e..34f9247 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -6,12 +6,11 @@ import sqlparse - from ._sql_sanitizer import SQLSanitizer, escape_verbatim_colon class Statement: - """Parses a SQL statement and replaces placeholders with parameters""" + """Parses a SQL statement and replaces the placeholders with the corresponding parameters""" def __init__(self, dialect, sql, *args, **kwargs): if len(args) > 0 and len(kwargs) > 0: raise RuntimeError("cannot pass both positional and named parameters") @@ -19,23 +18,12 @@ def __init__(self, dialect, sql, *args, **kwargs): self._sql_sanitizer = SQLSanitizer(dialect) self._args = args self._kwargs = kwargs - self._statement = _parse(sql) - self._operation_keyword = self._get_operation_keyword() + self._statement = _format_and_parse(sql) self._tokens = self._tokenize() + self._paramstyle = self._get_paramstyle() + self._placeholders = self._get_placeholders() self._replace_placeholders_with_params() - - - def _get_operation_keyword(self): - for token in self._statement: - if _is_operation_token(token): - token_value = token.value.upper() - if token_value in {"BEGIN", "DELETE", "INSERT", "SELECT", "START", "UPDATE"}: - operation_keyword = token_value - break - else: - operation_keyword = None - - return operation_keyword + self._operation_keyword = self._get_operation_keyword() def _tokenize(self): @@ -43,34 +31,40 @@ def _tokenize(self): def _replace_placeholders_with_params(self): - paramstyle, placeholders = self._parse_placeholders() - if paramstyle in {_Paramstyle.FORMAT, _Paramstyle.QMARK}: - self._replace_format_or_qmark_placeholders(placeholders) - elif paramstyle == _Paramstyle.NUMERIC: - self._replace_numeric_placeholders(placeholders) - if paramstyle in {_Paramstyle.NAMED, _Paramstyle.PYFORMAT}: - self._replace_named_or_pyformat_placeholders(placeholders) + if self._paramstyle in {_Paramstyle.FORMAT, _Paramstyle.QMARK}: + self._replace_format_or_qmark_placeholders() + elif self._paramstyle == _Paramstyle.NUMERIC: + self._replace_numeric_placeholders() + if self._paramstyle in {_Paramstyle.NAMED, _Paramstyle.PYFORMAT}: + self._replace_named_or_pyformat_placeholders() self._escape_verbatim_colons() - def _parse_placeholders(self): + def _get_paramstyle(self): paramstyle = None + for token in self._tokens: + if _is_placeholder(token): + paramstyle, _ = _parse_placeholder(token) + break + + if paramstyle is None: + paramstyle = self._default_paramstyle() + + return paramstyle + + + def _get_placeholders(self): placeholders = collections.OrderedDict() for index, token in enumerate(self._tokens): if _is_placeholder(token): - _paramstyle, name = _parse_placeholder(token) - if paramstyle is None: - paramstyle = _paramstyle - elif _paramstyle != paramstyle: + paramstyle, name = _parse_placeholder(token) + if paramstyle != self._paramstyle: raise RuntimeError("inconsistent paramstyle") placeholders[index] = name - if paramstyle is None: - paramstyle = self._default_paramstyle() - - return paramstyle, placeholders + return placeholders def _default_paramstyle(self): @@ -83,22 +77,22 @@ def _default_paramstyle(self): return paramstyle - def _replace_format_or_qmark_placeholders(self, placeholders): - if len(placeholders) != len(self._args): - _placeholders = ", ".join([str(token) for token in placeholders.values()]) + def _replace_format_or_qmark_placeholders(self): + if len(self._placeholders) != len(self._args): + placeholders = ", ".join([str(token) for token in self._placeholders.values()]) _args = ", ".join([str(self._sql_sanitizer.escape(arg)) for arg in self._args]) - if len(placeholders) < len(self._args): - raise RuntimeError(f"fewer placeholders ({_placeholders}) than values ({_args})") + if len(self._placeholders) < len(self._args): + raise RuntimeError(f"fewer placeholders ({placeholders}) than values ({_args})") - raise RuntimeError(f"more placeholders ({_placeholders}) than values ({_args})") + raise RuntimeError(f"more placeholders ({placeholders}) than values ({_args})") - for arg_index, token_index in enumerate(placeholders.keys()): + for arg_index, token_index in enumerate(self._placeholders.keys()): self._tokens[token_index] = self._sql_sanitizer.escape(self._args[arg_index]) - def _replace_numeric_placeholders(self, placeholders): + def _replace_numeric_placeholders(self): unused_arg_idxs = set(range(len(self._args))) - for token_index, num in placeholders.items(): + for token_index, num in self._placeholders.items(): if num >= len(self._args): raise RuntimeError(f"missing value for placeholder ({num + 1})") @@ -112,9 +106,9 @@ def _replace_numeric_placeholders(self, placeholders): f"unused value{'' if len(unused_arg_idxs) == 1 else 's'} ({unused_args})") - def _replace_named_or_pyformat_placeholders(self, placeholders): + def _replace_named_or_pyformat_placeholders(self): unused_params = set(self._kwargs.keys()) - for token_index, param_name in placeholders.items(): + for token_index, param_name in self._placeholders.items(): if param_name not in self._kwargs: raise RuntimeError(f"missing value for placeholder ({param_name})") @@ -133,6 +127,19 @@ def _escape_verbatim_colons(self): token.value = escape_verbatim_colon(token.value) + def _get_operation_keyword(self): + for token in self._statement: + if _is_operation_token(token): + token_value = token.value.upper() + if token_value in {"BEGIN", "DELETE", "INSERT", "SELECT", "START", "UPDATE"}: + operation_keyword = token_value + break + else: + operation_keyword = None + + return operation_keyword + + def get_operation_keyword(self): """Returns the operation keyword of the statement (e.g., SELECT) if found, or None""" return self._operation_keyword @@ -142,7 +149,7 @@ def __str__(self): return "".join([str(token) for token in self._tokens]) -def _parse(sql): +def _format_and_parse(sql): formatted_statements = sqlparse.format(sql, strip_comments=True).strip() parsed_statements = sqlparse.parse(formatted_statements) statement_count = len(parsed_statements) From db4cce1c25f9b9a3349c696b13078073e1ed8cd6 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Mon, 12 Apr 2021 21:09:22 -0400 Subject: [PATCH 058/159] pass token type/value around --- src/cs50/_statement.py | 42 +++++++++++++++++++++--------------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index 34f9247..0d62266 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -44,8 +44,8 @@ def _replace_placeholders_with_params(self): def _get_paramstyle(self): paramstyle = None for token in self._tokens: - if _is_placeholder(token): - paramstyle, _ = _parse_placeholder(token) + if _is_placeholder(token.ttype): + paramstyle, _ = _parse_placeholder(token.value) break if paramstyle is None: @@ -57,8 +57,8 @@ def _get_paramstyle(self): def _get_placeholders(self): placeholders = collections.OrderedDict() for index, token in enumerate(self._tokens): - if _is_placeholder(token): - paramstyle, name = _parse_placeholder(token) + if _is_placeholder(token.ttype): + paramstyle, name = _parse_placeholder(token.value) if paramstyle != self._paramstyle: raise RuntimeError("inconsistent paramstyle") @@ -123,13 +123,13 @@ def _replace_named_or_pyformat_placeholders(self): def _escape_verbatim_colons(self): for token in self._tokens: - if _is_string_literal(token) or _is_identifier(token): + if _is_string_literal(token.ttype) or _is_identifier(token.ttype): token.value = escape_verbatim_colon(token.value) def _get_operation_keyword(self): for token in self._statement: - if _is_operation_token(token): + if _is_operation_token(token.ttype): token_value = token.value.upper() if token_value in {"BEGIN", "DELETE", "INSERT", "SELECT", "START", "UPDATE"}: operation_keyword = token_value @@ -161,46 +161,46 @@ def _format_and_parse(sql): return parsed_statements[0] -def _is_placeholder(token): - return token.ttype == sqlparse.tokens.Name.Placeholder +def _is_placeholder(ttype): + return ttype == sqlparse.tokens.Name.Placeholder -def _parse_placeholder(token): - if token.value == "?": +def _parse_placeholder(value): + if value == "?": return _Paramstyle.QMARK, None # E.g., :1 - matches = re.search(r"^:([1-9]\d*)$", token.value) + matches = re.search(r"^:([1-9]\d*)$", value) if matches: return _Paramstyle.NUMERIC, int(matches.group(1)) - 1 # E.g., :foo - matches = re.search(r"^:([a-zA-Z]\w*)$", token.value) + matches = re.search(r"^:([a-zA-Z]\w*)$", value) if matches: return _Paramstyle.NAMED, matches.group(1) - if token.value == "%s": + if value == "%s": return _Paramstyle.FORMAT, None # E.g., %(foo)s - matches = re.search(r"%\((\w+)\)s$", token.value) + matches = re.search(r"%\((\w+)\)s$", value) if matches: return _Paramstyle.PYFORMAT, matches.group(1) - raise RuntimeError(f"{token.value}: invalid placeholder") + raise RuntimeError(f"{value}: invalid placeholder") -def _is_operation_token(token): - return token.ttype in { +def _is_operation_token(ttype): + return ttype in { sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML} -def _is_string_literal(token): - return token.ttype in [sqlparse.tokens.Literal.String, sqlparse.tokens.Literal.String.Single] +def _is_string_literal(ttype): + return ttype in [sqlparse.tokens.Literal.String, sqlparse.tokens.Literal.String.Single] -def _is_identifier(token): - return token.ttype == sqlparse.tokens.Literal.String.Symbol +def _is_identifier(ttype): + return ttype == sqlparse.tokens.Literal.String.Symbol class _Paramstyle(enum.Enum): From 88adfb95eb760e1d7be01b3c55566560d81125a8 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Mon, 12 Apr 2021 21:16:04 -0400 Subject: [PATCH 059/159] rename methods --- src/cs50/_statement.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index 0d62266..931504d 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -1,4 +1,4 @@ -"""Parses a SQL statement and replaces placeholders with parameters""" +"""Parses a SQL statement and replaces the placeholders with the corresponding parameters""" import collections import enum @@ -22,7 +22,7 @@ def __init__(self, dialect, sql, *args, **kwargs): self._tokens = self._tokenize() self._paramstyle = self._get_paramstyle() self._placeholders = self._get_placeholders() - self._replace_placeholders_with_params() + self._plugin_escaped_params() self._operation_keyword = self._get_operation_keyword() @@ -30,13 +30,13 @@ def _tokenize(self): return list(self._statement.flatten()) - def _replace_placeholders_with_params(self): + def _plugin_escaped_params(self): if self._paramstyle in {_Paramstyle.FORMAT, _Paramstyle.QMARK}: - self._replace_format_or_qmark_placeholders() + self._plugin_format_or_qmark_params() elif self._paramstyle == _Paramstyle.NUMERIC: - self._replace_numeric_placeholders() + self._plugin_numeric_params() if self._paramstyle in {_Paramstyle.NAMED, _Paramstyle.PYFORMAT}: - self._replace_named_or_pyformat_placeholders() + self._plugin_named_or_pyformat_params() self._escape_verbatim_colons() @@ -77,7 +77,7 @@ def _default_paramstyle(self): return paramstyle - def _replace_format_or_qmark_placeholders(self): + def _plugin_format_or_qmark_params(self): if len(self._placeholders) != len(self._args): placeholders = ", ".join([str(token) for token in self._placeholders.values()]) _args = ", ".join([str(self._sql_sanitizer.escape(arg)) for arg in self._args]) @@ -90,7 +90,7 @@ def _replace_format_or_qmark_placeholders(self): self._tokens[token_index] = self._sql_sanitizer.escape(self._args[arg_index]) - def _replace_numeric_placeholders(self): + def _plugin_numeric_params(self): unused_arg_idxs = set(range(len(self._args))) for token_index, num in self._placeholders.items(): if num >= len(self._args): @@ -106,7 +106,7 @@ def _replace_numeric_placeholders(self): f"unused value{'' if len(unused_arg_idxs) == 1 else 's'} ({unused_args})") - def _replace_named_or_pyformat_placeholders(self): + def _plugin_named_or_pyformat_params(self): unused_params = set(self._kwargs.keys()) for token_index, param_name in self._placeholders.items(): if param_name not in self._kwargs: From a4a88108f652ca7fb253d056ac449adacc02b5a4 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Mon, 12 Apr 2021 21:17:33 -0400 Subject: [PATCH 060/159] move escape_verbatim_colons to constructor --- src/cs50/_statement.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index 931504d..5fc41e7 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -23,6 +23,7 @@ def __init__(self, dialect, sql, *args, **kwargs): self._paramstyle = self._get_paramstyle() self._placeholders = self._get_placeholders() self._plugin_escaped_params() + self._escape_verbatim_colons() self._operation_keyword = self._get_operation_keyword() @@ -38,8 +39,6 @@ def _plugin_escaped_params(self): if self._paramstyle in {_Paramstyle.NAMED, _Paramstyle.PYFORMAT}: self._plugin_named_or_pyformat_params() - self._escape_verbatim_colons() - def _get_paramstyle(self): paramstyle = None From f618840dbf0918a060aa8bcbc712a2b4a1c665d2 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Mon, 12 Apr 2021 21:24:39 -0400 Subject: [PATCH 061/159] reorder methods --- src/cs50/_statement.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index 5fc41e7..dc4013b 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -31,15 +31,6 @@ def _tokenize(self): return list(self._statement.flatten()) - def _plugin_escaped_params(self): - if self._paramstyle in {_Paramstyle.FORMAT, _Paramstyle.QMARK}: - self._plugin_format_or_qmark_params() - elif self._paramstyle == _Paramstyle.NUMERIC: - self._plugin_numeric_params() - if self._paramstyle in {_Paramstyle.NAMED, _Paramstyle.PYFORMAT}: - self._plugin_named_or_pyformat_params() - - def _get_paramstyle(self): paramstyle = None for token in self._tokens: @@ -53,6 +44,16 @@ def _get_paramstyle(self): return paramstyle + def _default_paramstyle(self): + paramstyle = None + if self._args: + paramstyle = _Paramstyle.QMARK + elif self._kwargs: + paramstyle = _Paramstyle.NAMED + + return paramstyle + + def _get_placeholders(self): placeholders = collections.OrderedDict() for index, token in enumerate(self._tokens): @@ -66,14 +67,13 @@ def _get_placeholders(self): return placeholders - def _default_paramstyle(self): - paramstyle = None - if self._args: - paramstyle = _Paramstyle.QMARK - elif self._kwargs: - paramstyle = _Paramstyle.NAMED - - return paramstyle + def _plugin_escaped_params(self): + if self._paramstyle in {_Paramstyle.FORMAT, _Paramstyle.QMARK}: + self._plugin_format_or_qmark_params() + elif self._paramstyle == _Paramstyle.NUMERIC: + self._plugin_numeric_params() + if self._paramstyle in {_Paramstyle.NAMED, _Paramstyle.PYFORMAT}: + self._plugin_named_or_pyformat_params() def _plugin_format_or_qmark_params(self): From 9a153fe9dbca0157958bb577458ec53505c03421 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Mon, 12 Apr 2021 22:17:36 -0400 Subject: [PATCH 062/159] use else --- src/cs50/_statement.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index dc4013b..f0ed325 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -37,8 +37,7 @@ def _get_paramstyle(self): if _is_placeholder(token.ttype): paramstyle, _ = _parse_placeholder(token.value) break - - if paramstyle is None: + else: paramstyle = self._default_paramstyle() return paramstyle From 456cea56083d586edae6ccedfe525bf0f64c4f77 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Mon, 12 Apr 2021 22:56:48 -0400 Subject: [PATCH 063/159] escape args and kwargs in constructor --- src/cs50/_statement.py | 54 +++++++++++++++++++++++++++--------------- 1 file changed, 35 insertions(+), 19 deletions(-) diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index f0ed325..d02f844 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -16,8 +16,8 @@ def __init__(self, dialect, sql, *args, **kwargs): raise RuntimeError("cannot pass both positional and named parameters") self._sql_sanitizer = SQLSanitizer(dialect) - self._args = args - self._kwargs = kwargs + self._args = self._get_escaped_args(args) + self._kwargs = self._get_escaped_kwargs(kwargs) self._statement = _format_and_parse(sql) self._tokens = self._tokenize() self._paramstyle = self._get_paramstyle() @@ -27,6 +27,14 @@ def __init__(self, dialect, sql, *args, **kwargs): self._operation_keyword = self._get_operation_keyword() + def _get_escaped_args(self, args): + return [self._sql_sanitizer.escape(arg) for arg in args] + + + def _get_escaped_kwargs(self, kwargs): + return {k: self._sql_sanitizer.escape(v) for k, v in kwargs.items()} + + def _tokenize(self): return list(self._statement.flatten()) @@ -76,32 +84,33 @@ def _plugin_escaped_params(self): def _plugin_format_or_qmark_params(self): + self._assert_valid_arg_count() + for arg_index, token_index in enumerate(self._placeholders.keys()): + self._tokens[token_index] = self._args[arg_index] + + + def _assert_valid_arg_count(self): if len(self._placeholders) != len(self._args): - placeholders = ", ".join([str(token) for token in self._placeholders.values()]) - _args = ", ".join([str(self._sql_sanitizer.escape(arg)) for arg in self._args]) + placeholders = _get_human_readable_list(self._placeholders.values()) + args = _get_human_readable_list(self._args) if len(self._placeholders) < len(self._args): - raise RuntimeError(f"fewer placeholders ({placeholders}) than values ({_args})") + raise RuntimeError(f"fewer placeholders ({placeholders}) than values ({args})") - raise RuntimeError(f"more placeholders ({placeholders}) than values ({_args})") - - for arg_index, token_index in enumerate(self._placeholders.keys()): - self._tokens[token_index] = self._sql_sanitizer.escape(self._args[arg_index]) + raise RuntimeError(f"more placeholders ({placeholders}) than values ({args})") def _plugin_numeric_params(self): - unused_arg_idxs = set(range(len(self._args))) + unused_arg_indices = set(range(len(self._args))) for token_index, num in self._placeholders.items(): if num >= len(self._args): raise RuntimeError(f"missing value for placeholder ({num + 1})") - self._tokens[token_index] = self._sql_sanitizer.escape(self._args[num]) - unused_arg_idxs.remove(num) + self._tokens[token_index] = self._args[num] + unused_arg_indices.remove(num) - if len(unused_arg_idxs) > 0: - unused_args = ", ".join( - [str(self._sql_sanitizer.escape(self._args[i])) for i in sorted(unused_arg_idxs)]) - raise RuntimeError( - f"unused value{'' if len(unused_arg_idxs) == 1 else 's'} ({unused_args})") + if len(unused_arg_indices) > 0: + unused_args = _get_human_readable_list([self._args[i] for i in sorted(unused_arg_indices)]) + raise RuntimeError(f"unused value{'' if len(unused_args) == 1 else 's'} ({unused_args})") def _plugin_named_or_pyformat_params(self): @@ -110,11 +119,11 @@ def _plugin_named_or_pyformat_params(self): if param_name not in self._kwargs: raise RuntimeError(f"missing value for placeholder ({param_name})") - self._tokens[token_index] = self._sql_sanitizer.escape(self._kwargs[param_name]) + self._tokens[token_index] = self._kwargs[param_name] unused_params.remove(param_name) if len(unused_params) > 0: - joined_unused_params = ", ".join(sorted(unused_params)) + joined_unused_params = _get_human_readable_list(sorted(unused_params)) raise RuntimeError( f"unused value{'' if len(unused_params) == 1 else 's'} ({joined_unused_params})") @@ -147,6 +156,9 @@ def __str__(self): return "".join([str(token) for token in self._tokens]) + + + def _format_and_parse(sql): formatted_statements = sqlparse.format(sql, strip_comments=True).strip() parsed_statements = sqlparse.parse(formatted_statements) @@ -201,6 +213,10 @@ def _is_identifier(ttype): return ttype == sqlparse.tokens.Literal.String.Symbol +def _get_human_readable_list(iterable): + return ", ".join(str(v) for v in iterable) + + class _Paramstyle(enum.Enum): FORMAT = enum.auto() NAMED = enum.auto() From 713368f0d5b44c9ce22739410a96b77ae08c237b Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Mon, 12 Apr 2021 22:59:37 -0400 Subject: [PATCH 064/159] reorder methods --- src/cs50/_statement.py | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index d02f844..50e6673 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -16,15 +16,19 @@ def __init__(self, dialect, sql, *args, **kwargs): raise RuntimeError("cannot pass both positional and named parameters") self._sql_sanitizer = SQLSanitizer(dialect) + self._args = self._get_escaped_args(args) self._kwargs = self._get_escaped_kwargs(kwargs) + self._statement = _format_and_parse(sql) self._tokens = self._tokenize() + + self._operation_keyword = self._get_operation_keyword() + self._paramstyle = self._get_paramstyle() self._placeholders = self._get_placeholders() self._plugin_escaped_params() self._escape_verbatim_colons() - self._operation_keyword = self._get_operation_keyword() def _get_escaped_args(self, args): @@ -39,6 +43,19 @@ def _tokenize(self): return list(self._statement.flatten()) + def _get_operation_keyword(self): + for token in self._statement: + if _is_operation_token(token.ttype): + token_value = token.value.upper() + if token_value in {"BEGIN", "DELETE", "INSERT", "SELECT", "START", "UPDATE"}: + operation_keyword = token_value + break + else: + operation_keyword = None + + return operation_keyword + + def _get_paramstyle(self): paramstyle = None for token in self._tokens: @@ -134,19 +151,6 @@ def _escape_verbatim_colons(self): token.value = escape_verbatim_colon(token.value) - def _get_operation_keyword(self): - for token in self._statement: - if _is_operation_token(token.ttype): - token_value = token.value.upper() - if token_value in {"BEGIN", "DELETE", "INSERT", "SELECT", "START", "UPDATE"}: - operation_keyword = token_value - break - else: - operation_keyword = None - - return operation_keyword - - def get_operation_keyword(self): """Returns the operation keyword of the statement (e.g., SELECT) if found, or None""" return self._operation_keyword From 2187d95b51f9020c98cb6feb5102db375633196d Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Mon, 12 Apr 2021 23:15:01 -0400 Subject: [PATCH 065/159] refactor logger --- src/cs50/_logger.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/cs50/_logger.py b/src/cs50/_logger.py index df021a3..1307e19 100644 --- a/src/cs50/_logger.py +++ b/src/cs50/_logger.py @@ -10,17 +10,26 @@ def _setup_logger(): - # Configure default logging handler and formatter - # Prevent flask, werkzeug, etc from adding default handler + _configure_default_logger() + _patch_root_handler_format_exception() + _configure_cs50_logger() + _patch_excepthook() + + +def _configure_default_logger(): + """Configure default handler and formatter to prevent flask and werkzeug from adding theirs""" logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.DEBUG) + +def _patch_root_handler_format_exception(): try: - # Patch formatException formatter = logging.root.handlers[0].formatter formatter.formatException = lambda exc_info: _format_exception(*exc_info) except IndexError: pass + +def _configure_cs50_logger(): _logger = logging.getLogger("cs50") _logger.disabled = True _logger.setLevel(logging.DEBUG) @@ -36,6 +45,8 @@ def _setup_logger(): handler.setFormatter(formatter) _logger.addHandler(handler) + +def _patch_excepthook(): sys.excepthook = lambda type_, value, exc_tb: print( _format_exception(type_, value, exc_tb), file=sys.stderr) From 758910e7374ae6d8b7a35636d6c122318648d9c7 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Mon, 12 Apr 2021 23:30:30 -0400 Subject: [PATCH 066/159] fix style --- src/cs50/_session.py | 4 ++-- src/cs50/_sql_sanitizer.py | 2 -- src/cs50/_statement.py | 25 +++++-------------------- src/cs50/cs50.py | 4 +++- src/cs50/sql.py | 6 +----- 5 files changed, 11 insertions(+), 30 deletions(-) diff --git a/src/cs50/_session.py b/src/cs50/_session.py index cd23453..4c63b39 100644 --- a/src/cs50/_session.py +++ b/src/cs50/_session.py @@ -6,21 +6,21 @@ import sqlalchemy import sqlalchemy.orm + class Session: """Wraps a SQLAlchemy scoped session""" + def __init__(self, url, **engine_kwargs): if _is_sqlite_url(url): _assert_sqlite_file_exists(url) self._session = _create_session(url, **engine_kwargs) - def execute(self, statement): """Converts statement to str and executes it""" # pylint: disable=no-member return self._session.execute(sqlalchemy.text(str(statement))) - def __getattr__(self, attr): return getattr(self._session, attr) diff --git a/src/cs50/_sql_sanitizer.py b/src/cs50/_sql_sanitizer.py index c2f35c4..f4ff3e0 100644 --- a/src/cs50/_sql_sanitizer.py +++ b/src/cs50/_sql_sanitizer.py @@ -13,7 +13,6 @@ class SQLSanitizer: def __init__(self, dialect): self._dialect = dialect - def escape(self, value): """ Escapes value using engine's conversion function. @@ -71,7 +70,6 @@ def escape(self, value): raise RuntimeError(f"unsupported value: {value}") - def escape_iterable(self, iterable): """Escapes a collection of values (e.g., list, tuple)""" return sqlparse.sql.TokenList( diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index 50e6673..9f9fae8 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -11,6 +11,7 @@ class Statement: """Parses a SQL statement and replaces the placeholders with the corresponding parameters""" + def __init__(self, dialect, sql, *args, **kwargs): if len(args) > 0 and len(kwargs) > 0: raise RuntimeError("cannot pass both positional and named parameters") @@ -30,19 +31,15 @@ def __init__(self, dialect, sql, *args, **kwargs): self._plugin_escaped_params() self._escape_verbatim_colons() - def _get_escaped_args(self, args): return [self._sql_sanitizer.escape(arg) for arg in args] - def _get_escaped_kwargs(self, kwargs): return {k: self._sql_sanitizer.escape(v) for k, v in kwargs.items()} - def _tokenize(self): return list(self._statement.flatten()) - def _get_operation_keyword(self): for token in self._statement: if _is_operation_token(token.ttype): @@ -55,7 +52,6 @@ def _get_operation_keyword(self): return operation_keyword - def _get_paramstyle(self): paramstyle = None for token in self._tokens: @@ -67,7 +63,6 @@ def _get_paramstyle(self): return paramstyle - def _default_paramstyle(self): paramstyle = None if self._args: @@ -77,7 +72,6 @@ def _default_paramstyle(self): return paramstyle - def _get_placeholders(self): placeholders = collections.OrderedDict() for index, token in enumerate(self._tokens): @@ -90,7 +84,6 @@ def _get_placeholders(self): return placeholders - def _plugin_escaped_params(self): if self._paramstyle in {_Paramstyle.FORMAT, _Paramstyle.QMARK}: self._plugin_format_or_qmark_params() @@ -99,13 +92,11 @@ def _plugin_escaped_params(self): if self._paramstyle in {_Paramstyle.NAMED, _Paramstyle.PYFORMAT}: self._plugin_named_or_pyformat_params() - def _plugin_format_or_qmark_params(self): self._assert_valid_arg_count() for arg_index, token_index in enumerate(self._placeholders.keys()): self._tokens[token_index] = self._args[arg_index] - def _assert_valid_arg_count(self): if len(self._placeholders) != len(self._args): placeholders = _get_human_readable_list(self._placeholders.values()) @@ -115,7 +106,6 @@ def _assert_valid_arg_count(self): raise RuntimeError(f"more placeholders ({placeholders}) than values ({args})") - def _plugin_numeric_params(self): unused_arg_indices = set(range(len(self._args))) for token_index, num in self._placeholders.items(): @@ -126,9 +116,10 @@ def _plugin_numeric_params(self): unused_arg_indices.remove(num) if len(unused_arg_indices) > 0: - unused_args = _get_human_readable_list([self._args[i] for i in sorted(unused_arg_indices)]) - raise RuntimeError(f"unused value{'' if len(unused_args) == 1 else 's'} ({unused_args})") - + unused_args = _get_human_readable_list( + [self._args[i] for i in sorted(unused_arg_indices)]) + raise RuntimeError( + f"unused value{'' if len(unused_args) == 1 else 's'} ({unused_args})") def _plugin_named_or_pyformat_params(self): unused_params = set(self._kwargs.keys()) @@ -144,25 +135,19 @@ def _plugin_named_or_pyformat_params(self): raise RuntimeError( f"unused value{'' if len(unused_params) == 1 else 's'} ({joined_unused_params})") - def _escape_verbatim_colons(self): for token in self._tokens: if _is_string_literal(token.ttype) or _is_identifier(token.ttype): token.value = escape_verbatim_colon(token.value) - def get_operation_keyword(self): """Returns the operation keyword of the statement (e.g., SELECT) if found, or None""" return self._operation_keyword - def __str__(self): return "".join([str(token) for token in self._tokens]) - - - def _format_and_parse(sql): formatted_statements = sqlparse.format(sql, strip_comments=True).strip() parsed_statements = sqlparse.parse(formatted_statements) diff --git a/src/cs50/cs50.py b/src/cs50/cs50.py index 24c748b..30d3515 100644 --- a/src/cs50/cs50.py +++ b/src/cs50/cs50.py @@ -36,7 +36,7 @@ def get_int(prompt): """ while True: try: - return _get_int(prompt) + return _get_int(prompt) except (MemoryError, ValueError): pass @@ -89,9 +89,11 @@ def write(self, data): self.stream.write(data) self.stream.flush() + def disable_output_buffering(): """Disables output buffering to prevent prompts from being buffered""" sys.stderr = _flushfile(sys.stderr) sys.stdout = _flushfile(sys.stdout) + disable_output_buffering() diff --git a/src/cs50/sql.py b/src/cs50/sql.py index fca57d2..8547aca 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -15,13 +15,13 @@ class SQL: """Wraps SQLAlchemy""" + def __init__(self, url, **engine_kwargs): self._session = Session(url, **engine_kwargs) self._dialect = self._session.get_bind().dialect self._is_postgres = self._dialect.name in {"postgres", "postgresql"} self._autocommit = False - def execute(self, sql, *args, **kwargs): """Execute a SQL statement.""" statement = Statement(self._dialect, sql, *args, **kwargs) @@ -53,7 +53,6 @@ def execute(self, sql, *args, **kwargs): return ret - def _execute(self, statement): # Catch SQLAlchemy warnings with warnings.catch_warnings(): @@ -72,20 +71,17 @@ def _execute(self, statement): _logger.debug(termcolor.colored(str(statement), "green")) return result - def _last_row_id_or_none(self, result): if self._is_postgres: return self._get_last_val() return result.lastrowid if result.rowcount == 1 else None - def _get_last_val(self): try: return self._session.execute("SELECT LASTVAL()").first()[0] except sqlalchemy.exc.OperationalError: # If lastval is not yet defined in this session return None - def init_app(self, app): """Registers a teardown_appcontext listener to remove session and enables logging""" @app.teardown_appcontext From 32db777581af3929a69ab3aaded835cf7823e9a9 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Tue, 13 Apr 2021 07:04:58 -0400 Subject: [PATCH 067/159] factor out utility functions --- src/cs50/_statement.py | 80 +++++-------------------------------- src/cs50/_statement_util.py | 72 +++++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 70 deletions(-) create mode 100644 src/cs50/_statement_util.py diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index 9f9fae8..c673719 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -1,12 +1,18 @@ """Parses a SQL statement and replaces the placeholders with the corresponding parameters""" import collections -import enum -import re - -import sqlparse from ._sql_sanitizer import SQLSanitizer, escape_verbatim_colon +from ._statement_util import ( + _format_and_parse, + _get_human_readable_list, + _is_identifier, + _is_operation_token, + _is_placeholder, + _is_string_literal, + _Paramstyle, + _parse_placeholder, +) class Statement: @@ -146,69 +152,3 @@ def get_operation_keyword(self): def __str__(self): return "".join([str(token) for token in self._tokens]) - - -def _format_and_parse(sql): - formatted_statements = sqlparse.format(sql, strip_comments=True).strip() - parsed_statements = sqlparse.parse(formatted_statements) - statement_count = len(parsed_statements) - if statement_count == 0: - raise RuntimeError("missing statement") - if statement_count > 1: - raise RuntimeError("too many statements at once") - - return parsed_statements[0] - - -def _is_placeholder(ttype): - return ttype == sqlparse.tokens.Name.Placeholder - - -def _parse_placeholder(value): - if value == "?": - return _Paramstyle.QMARK, None - - # E.g., :1 - matches = re.search(r"^:([1-9]\d*)$", value) - if matches: - return _Paramstyle.NUMERIC, int(matches.group(1)) - 1 - - # E.g., :foo - matches = re.search(r"^:([a-zA-Z]\w*)$", value) - if matches: - return _Paramstyle.NAMED, matches.group(1) - - if value == "%s": - return _Paramstyle.FORMAT, None - - # E.g., %(foo)s - matches = re.search(r"%\((\w+)\)s$", value) - if matches: - return _Paramstyle.PYFORMAT, matches.group(1) - - raise RuntimeError(f"{value}: invalid placeholder") - - -def _is_operation_token(ttype): - return ttype in { - sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML} - - -def _is_string_literal(ttype): - return ttype in [sqlparse.tokens.Literal.String, sqlparse.tokens.Literal.String.Single] - - -def _is_identifier(ttype): - return ttype == sqlparse.tokens.Literal.String.Symbol - - -def _get_human_readable_list(iterable): - return ", ".join(str(v) for v in iterable) - - -class _Paramstyle(enum.Enum): - FORMAT = enum.auto() - NAMED = enum.auto() - NUMERIC = enum.auto() - PYFORMAT = enum.auto() - QMARK = enum.auto() diff --git a/src/cs50/_statement_util.py b/src/cs50/_statement_util.py new file mode 100644 index 0000000..f299767 --- /dev/null +++ b/src/cs50/_statement_util.py @@ -0,0 +1,72 @@ +"""Utility functions used by _statement.py""" + +import enum +import re + +import sqlparse + + +class _Paramstyle(enum.Enum): + FORMAT = enum.auto() + NAMED = enum.auto() + NUMERIC = enum.auto() + PYFORMAT = enum.auto() + QMARK = enum.auto() + + +def _format_and_parse(sql): + formatted_statements = sqlparse.format(sql, strip_comments=True).strip() + parsed_statements = sqlparse.parse(formatted_statements) + statement_count = len(parsed_statements) + if statement_count == 0: + raise RuntimeError("missing statement") + if statement_count > 1: + raise RuntimeError("too many statements at once") + + return parsed_statements[0] + + +def _is_placeholder(ttype): + return ttype == sqlparse.tokens.Name.Placeholder + + +def _parse_placeholder(value): + if value == "?": + return _Paramstyle.QMARK, None + + # E.g., :1 + matches = re.search(r"^:([1-9]\d*)$", value) + if matches: + return _Paramstyle.NUMERIC, int(matches.group(1)) - 1 + + # E.g., :foo + matches = re.search(r"^:([a-zA-Z]\w*)$", value) + if matches: + return _Paramstyle.NAMED, matches.group(1) + + if value == "%s": + return _Paramstyle.FORMAT, None + + # E.g., %(foo)s + matches = re.search(r"%\((\w+)\)s$", value) + if matches: + return _Paramstyle.PYFORMAT, matches.group(1) + + raise RuntimeError(f"{value}: invalid placeholder") + + +def _is_operation_token(ttype): + return ttype in { + sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML} + + +def _is_string_literal(ttype): + return ttype in [sqlparse.tokens.Literal.String, sqlparse.tokens.Literal.String.Single] + + +def _is_identifier(ttype): + return ttype == sqlparse.tokens.Literal.String.Symbol + + +def _get_human_readable_list(iterable): + return ", ".join(str(v) for v in iterable) From 1f622834ddecd571a0c49190a190035f4bbb95e3 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Tue, 13 Apr 2021 07:08:13 -0400 Subject: [PATCH 068/159] factor out session utility functions --- src/cs50/_session.py | 67 ++++----------------------------------- src/cs50/_session_util.py | 63 ++++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 61 deletions(-) create mode 100644 src/cs50/_session_util.py diff --git a/src/cs50/_session.py b/src/cs50/_session.py index 4c63b39..a67f02d 100644 --- a/src/cs50/_session.py +++ b/src/cs50/_session.py @@ -1,11 +1,14 @@ """Wraps a SQLAlchemy scoped session""" -import os -import sqlite3 - import sqlalchemy import sqlalchemy.orm +from ._session_util import ( + _is_sqlite_url, + _assert_sqlite_file_exists, + _create_session, +) + class Session: """Wraps a SQLAlchemy scoped session""" @@ -23,61 +26,3 @@ def execute(self, statement): def __getattr__(self, attr): return getattr(self._session, attr) - - -def _is_sqlite_url(url): - return url.startswith("sqlite:///") - - -def _assert_sqlite_file_exists(url): - path = url[len("sqlite:///"):] - if not os.path.exists(path): - raise RuntimeError(f"does not exist: {path}") - if not os.path.isfile(path): - raise RuntimeError(f"not a file: {path}") - - -def _create_session(url, **engine_kwargs): - engine = _create_engine(url, **engine_kwargs) - _setup_on_connect(engine) - return _create_scoped_session(engine) - - -def _create_engine(url, **kwargs): - try: - engine = sqlalchemy.create_engine(url, **kwargs) - except sqlalchemy.exc.ArgumentError: - raise RuntimeError(f"invalid URL: {url}") from None - - engine.execution_options(autocommit=False) - return engine - - -def _setup_on_connect(engine): - def connect(dbapi_connection, _): - _disable_auto_begin_commit(dbapi_connection) - if _is_sqlite_connection(dbapi_connection): - _enable_sqlite_foreign_key_constraints(dbapi_connection) - - sqlalchemy.event.listen(engine, "connect", connect) - - -def _create_scoped_session(engine): - session_factory = sqlalchemy.orm.sessionmaker(bind=engine) - return sqlalchemy.orm.scoping.scoped_session(session_factory) - - -def _disable_auto_begin_commit(dbapi_connection): - # Disable underlying API's own emitting of BEGIN and COMMIT so we can ourselves - # https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl - dbapi_connection.isolation_level = None - - -def _is_sqlite_connection(dbapi_connection): - return isinstance(dbapi_connection, sqlite3.Connection) - - -def _enable_sqlite_foreign_key_constraints(dbapi_connection): - cursor = dbapi_connection.cursor() - cursor.execute("PRAGMA foreign_keys=ON") - cursor.close() diff --git a/src/cs50/_session_util.py b/src/cs50/_session_util.py new file mode 100644 index 0000000..c0cf33a --- /dev/null +++ b/src/cs50/_session_util.py @@ -0,0 +1,63 @@ +"""Utility functions used by _session.py""" + +import os +import sqlite3 + +import sqlalchemy + +def _is_sqlite_url(url): + return url.startswith("sqlite:///") + + +def _assert_sqlite_file_exists(url): + path = url[len("sqlite:///"):] + if not os.path.exists(path): + raise RuntimeError(f"does not exist: {path}") + if not os.path.isfile(path): + raise RuntimeError(f"not a file: {path}") + + +def _create_session(url, **engine_kwargs): + engine = _create_engine(url, **engine_kwargs) + _setup_on_connect(engine) + return _create_scoped_session(engine) + + +def _create_engine(url, **kwargs): + try: + engine = sqlalchemy.create_engine(url, **kwargs) + except sqlalchemy.exc.ArgumentError: + raise RuntimeError(f"invalid URL: {url}") from None + + engine.execution_options(autocommit=False) + return engine + + +def _setup_on_connect(engine): + def connect(dbapi_connection, _): + _disable_auto_begin_commit(dbapi_connection) + if _is_sqlite_connection(dbapi_connection): + _enable_sqlite_foreign_key_constraints(dbapi_connection) + + sqlalchemy.event.listen(engine, "connect", connect) + + +def _create_scoped_session(engine): + session_factory = sqlalchemy.orm.sessionmaker(bind=engine) + return sqlalchemy.orm.scoping.scoped_session(session_factory) + + +def _disable_auto_begin_commit(dbapi_connection): + # Disable underlying API's own emitting of BEGIN and COMMIT so we can ourselves + # https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl + dbapi_connection.isolation_level = None + + +def _is_sqlite_connection(dbapi_connection): + return isinstance(dbapi_connection, sqlite3.Connection) + + +def _enable_sqlite_foreign_key_constraints(dbapi_connection): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() From c0534e27dc005dab42062c39f03751e075aa9b6a Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Tue, 13 Apr 2021 07:11:26 -0400 Subject: [PATCH 069/159] factor out sql utility functions --- src/cs50/_sql_util.py | 18 ++++++++++++++++++ src/cs50/sql.py | 20 ++------------------ 2 files changed, 20 insertions(+), 18 deletions(-) create mode 100644 src/cs50/_sql_util.py diff --git a/src/cs50/_sql_util.py b/src/cs50/_sql_util.py new file mode 100644 index 0000000..ea3edad --- /dev/null +++ b/src/cs50/_sql_util.py @@ -0,0 +1,18 @@ +"""Utility functions used by sql.py""" + +import decimal + +def fetch_select_result(result): + rows = [dict(row) for row in result.fetchall()] + for row in rows: + for column in row: + # Coerce decimal.Decimal objects to float objects + # https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ + if isinstance(row[column], decimal.Decimal): + row[column] = float(row[column]) + + # Coerce memoryview objects (as from PostgreSQL's bytea columns) to bytes + elif isinstance(row[column], memoryview): + row[column] = bytes(row[column]) + + return rows diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 8547aca..d823c8b 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -1,6 +1,5 @@ """Wraps SQLAlchemy""" -import decimal import logging import warnings @@ -9,6 +8,7 @@ from ._session import Session from ._statement import Statement +from ._sql_util import fetch_select_result _logger = logging.getLogger("cs50") @@ -43,7 +43,7 @@ def execute(self, sql, *args, **kwargs): self._session.remove() if operation_keyword == "SELECT": - ret = _fetch_select_result(result) + ret = fetch_select_result(result) elif operation_keyword == "INSERT": ret = self._last_row_id_or_none(result) elif operation_keyword in {"DELETE", "UPDATE"}: @@ -89,19 +89,3 @@ def _(_): self._session.remove() logging.getLogger("cs50").disabled = False - - -def _fetch_select_result(result): - rows = [dict(row) for row in result.fetchall()] - for row in rows: - for column in row: - # Coerce decimal.Decimal objects to float objects - # https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ - if isinstance(row[column], decimal.Decimal): - row[column] = float(row[column]) - - # Coerce memoryview objects (as from PostgreSQL's bytea columns) to bytes - elif isinstance(row[column], memoryview): - row[column] = bytes(row[column]) - - return rows From 37fba4f5707aa5e57b41f893b2a6d5f209238a4d Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Tue, 13 Apr 2021 07:16:38 -0400 Subject: [PATCH 070/159] remove underscore from util functions --- src/cs50/_session.py | 12 +++++----- src/cs50/_session_util.py | 6 ++--- src/cs50/_statement.py | 48 ++++++++++++++++++------------------- src/cs50/_statement_util.py | 26 ++++++++++---------- 4 files changed, 46 insertions(+), 46 deletions(-) diff --git a/src/cs50/_session.py b/src/cs50/_session.py index a67f02d..0a30c36 100644 --- a/src/cs50/_session.py +++ b/src/cs50/_session.py @@ -4,9 +4,9 @@ import sqlalchemy.orm from ._session_util import ( - _is_sqlite_url, - _assert_sqlite_file_exists, - _create_session, + is_sqlite_url, + assert_sqlite_file_exists, + create_session, ) @@ -14,10 +14,10 @@ class Session: """Wraps a SQLAlchemy scoped session""" def __init__(self, url, **engine_kwargs): - if _is_sqlite_url(url): - _assert_sqlite_file_exists(url) + if is_sqlite_url(url): + assert_sqlite_file_exists(url) - self._session = _create_session(url, **engine_kwargs) + self._session = create_session(url, **engine_kwargs) def execute(self, statement): """Converts statement to str and executes it""" diff --git a/src/cs50/_session_util.py b/src/cs50/_session_util.py index c0cf33a..3433fa9 100644 --- a/src/cs50/_session_util.py +++ b/src/cs50/_session_util.py @@ -5,11 +5,11 @@ import sqlalchemy -def _is_sqlite_url(url): +def is_sqlite_url(url): return url.startswith("sqlite:///") -def _assert_sqlite_file_exists(url): +def assert_sqlite_file_exists(url): path = url[len("sqlite:///"):] if not os.path.exists(path): raise RuntimeError(f"does not exist: {path}") @@ -17,7 +17,7 @@ def _assert_sqlite_file_exists(url): raise RuntimeError(f"not a file: {path}") -def _create_session(url, **engine_kwargs): +def create_session(url, **engine_kwargs): engine = _create_engine(url, **engine_kwargs) _setup_on_connect(engine) return _create_scoped_session(engine) diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index c673719..ac83758 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -4,14 +4,14 @@ from ._sql_sanitizer import SQLSanitizer, escape_verbatim_colon from ._statement_util import ( - _format_and_parse, - _get_human_readable_list, - _is_identifier, - _is_operation_token, - _is_placeholder, - _is_string_literal, - _Paramstyle, - _parse_placeholder, + format_and_parse, + get_human_readable_list, + is_identifier, + is_operation_token, + is_placeholder, + is_string_literal, + Paramstyle, + parse_placeholder, ) @@ -27,7 +27,7 @@ def __init__(self, dialect, sql, *args, **kwargs): self._args = self._get_escaped_args(args) self._kwargs = self._get_escaped_kwargs(kwargs) - self._statement = _format_and_parse(sql) + self._statement = format_and_parse(sql) self._tokens = self._tokenize() self._operation_keyword = self._get_operation_keyword() @@ -48,7 +48,7 @@ def _tokenize(self): def _get_operation_keyword(self): for token in self._statement: - if _is_operation_token(token.ttype): + if is_operation_token(token.ttype): token_value = token.value.upper() if token_value in {"BEGIN", "DELETE", "INSERT", "SELECT", "START", "UPDATE"}: operation_keyword = token_value @@ -61,8 +61,8 @@ def _get_operation_keyword(self): def _get_paramstyle(self): paramstyle = None for token in self._tokens: - if _is_placeholder(token.ttype): - paramstyle, _ = _parse_placeholder(token.value) + if is_placeholder(token.ttype): + paramstyle, _ = parse_placeholder(token.value) break else: paramstyle = self._default_paramstyle() @@ -72,17 +72,17 @@ def _get_paramstyle(self): def _default_paramstyle(self): paramstyle = None if self._args: - paramstyle = _Paramstyle.QMARK + paramstyle = Paramstyle.QMARK elif self._kwargs: - paramstyle = _Paramstyle.NAMED + paramstyle = Paramstyle.NAMED return paramstyle def _get_placeholders(self): placeholders = collections.OrderedDict() for index, token in enumerate(self._tokens): - if _is_placeholder(token.ttype): - paramstyle, name = _parse_placeholder(token.value) + if is_placeholder(token.ttype): + paramstyle, name = parse_placeholder(token.value) if paramstyle != self._paramstyle: raise RuntimeError("inconsistent paramstyle") @@ -91,11 +91,11 @@ def _get_placeholders(self): return placeholders def _plugin_escaped_params(self): - if self._paramstyle in {_Paramstyle.FORMAT, _Paramstyle.QMARK}: + if self._paramstyle in {Paramstyle.FORMAT, Paramstyle.QMARK}: self._plugin_format_or_qmark_params() - elif self._paramstyle == _Paramstyle.NUMERIC: + elif self._paramstyle == Paramstyle.NUMERIC: self._plugin_numeric_params() - if self._paramstyle in {_Paramstyle.NAMED, _Paramstyle.PYFORMAT}: + if self._paramstyle in {Paramstyle.NAMED, Paramstyle.PYFORMAT}: self._plugin_named_or_pyformat_params() def _plugin_format_or_qmark_params(self): @@ -105,8 +105,8 @@ def _plugin_format_or_qmark_params(self): def _assert_valid_arg_count(self): if len(self._placeholders) != len(self._args): - placeholders = _get_human_readable_list(self._placeholders.values()) - args = _get_human_readable_list(self._args) + placeholders = get_human_readable_list(self._placeholders.values()) + args = get_human_readable_list(self._args) if len(self._placeholders) < len(self._args): raise RuntimeError(f"fewer placeholders ({placeholders}) than values ({args})") @@ -122,7 +122,7 @@ def _plugin_numeric_params(self): unused_arg_indices.remove(num) if len(unused_arg_indices) > 0: - unused_args = _get_human_readable_list( + unused_args = get_human_readable_list( [self._args[i] for i in sorted(unused_arg_indices)]) raise RuntimeError( f"unused value{'' if len(unused_args) == 1 else 's'} ({unused_args})") @@ -137,13 +137,13 @@ def _plugin_named_or_pyformat_params(self): unused_params.remove(param_name) if len(unused_params) > 0: - joined_unused_params = _get_human_readable_list(sorted(unused_params)) + joined_unused_params = get_human_readable_list(sorted(unused_params)) raise RuntimeError( f"unused value{'' if len(unused_params) == 1 else 's'} ({joined_unused_params})") def _escape_verbatim_colons(self): for token in self._tokens: - if _is_string_literal(token.ttype) or _is_identifier(token.ttype): + if is_string_literal(token.ttype) or is_identifier(token.ttype): token.value = escape_verbatim_colon(token.value) def get_operation_keyword(self): diff --git a/src/cs50/_statement_util.py b/src/cs50/_statement_util.py index f299767..81b79e1 100644 --- a/src/cs50/_statement_util.py +++ b/src/cs50/_statement_util.py @@ -6,7 +6,7 @@ import sqlparse -class _Paramstyle(enum.Enum): +class Paramstyle(enum.Enum): FORMAT = enum.auto() NAMED = enum.auto() NUMERIC = enum.auto() @@ -14,7 +14,7 @@ class _Paramstyle(enum.Enum): QMARK = enum.auto() -def _format_and_parse(sql): +def format_and_parse(sql): formatted_statements = sqlparse.format(sql, strip_comments=True).strip() parsed_statements = sqlparse.parse(formatted_statements) statement_count = len(parsed_statements) @@ -26,47 +26,47 @@ def _format_and_parse(sql): return parsed_statements[0] -def _is_placeholder(ttype): +def is_placeholder(ttype): return ttype == sqlparse.tokens.Name.Placeholder -def _parse_placeholder(value): +def parse_placeholder(value): if value == "?": - return _Paramstyle.QMARK, None + return Paramstyle.QMARK, None # E.g., :1 matches = re.search(r"^:([1-9]\d*)$", value) if matches: - return _Paramstyle.NUMERIC, int(matches.group(1)) - 1 + return Paramstyle.NUMERIC, int(matches.group(1)) - 1 # E.g., :foo matches = re.search(r"^:([a-zA-Z]\w*)$", value) if matches: - return _Paramstyle.NAMED, matches.group(1) + return Paramstyle.NAMED, matches.group(1) if value == "%s": - return _Paramstyle.FORMAT, None + return Paramstyle.FORMAT, None # E.g., %(foo)s matches = re.search(r"%\((\w+)\)s$", value) if matches: - return _Paramstyle.PYFORMAT, matches.group(1) + return Paramstyle.PYFORMAT, matches.group(1) raise RuntimeError(f"{value}: invalid placeholder") -def _is_operation_token(ttype): +def is_operation_token(ttype): return ttype in { sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML} -def _is_string_literal(ttype): +def is_string_literal(ttype): return ttype in [sqlparse.tokens.Literal.String, sqlparse.tokens.Literal.String.Single] -def _is_identifier(ttype): +def is_identifier(ttype): return ttype == sqlparse.tokens.Literal.String.Symbol -def _get_human_readable_list(iterable): +def get_human_readable_list(iterable): return ", ".join(str(v) for v in iterable) From e06131c42a3e300cd1a8496dc95a9464381b8961 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Tue, 13 Apr 2021 07:20:01 -0400 Subject: [PATCH 071/159] fix style --- src/cs50/_session_util.py | 1 + src/cs50/_sql_util.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/cs50/_session_util.py b/src/cs50/_session_util.py index 3433fa9..ed44eaa 100644 --- a/src/cs50/_session_util.py +++ b/src/cs50/_session_util.py @@ -5,6 +5,7 @@ import sqlalchemy + def is_sqlite_url(url): return url.startswith("sqlite:///") diff --git a/src/cs50/_sql_util.py b/src/cs50/_sql_util.py index ea3edad..238d979 100644 --- a/src/cs50/_sql_util.py +++ b/src/cs50/_sql_util.py @@ -2,6 +2,7 @@ import decimal + def fetch_select_result(result): rows = [dict(row) for row in result.fetchall()] for row in rows: From 86f981f04377c2f81be1fc012c9d59d91c85d5ca Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Tue, 13 Apr 2021 07:21:03 -0400 Subject: [PATCH 072/159] reorder imports --- src/cs50/_session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cs50/_session.py b/src/cs50/_session.py index 0a30c36..c1ea426 100644 --- a/src/cs50/_session.py +++ b/src/cs50/_session.py @@ -4,9 +4,9 @@ import sqlalchemy.orm from ._session_util import ( - is_sqlite_url, assert_sqlite_file_exists, create_session, + is_sqlite_url, ) From 67a7f0c8559ecd719cab86012f443e4d36cd3f8a Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Tue, 13 Apr 2021 11:57:03 -0400 Subject: [PATCH 073/159] remove manual tests --- tests/flask/application.py | 22 --------------- tests/flask/requirements.txt | 2 -- tests/flask/templates/error.html | 10 ------- tests/flask/templates/index.html | 10 ------- tests/foo.py | 48 -------------------------------- tests/mysql.py | 8 ------ tests/python.py | 8 ------ tests/sqlite.py | 44 ----------------------------- tests/tb.py | 10 ------- 9 files changed, 162 deletions(-) delete mode 100644 tests/flask/application.py delete mode 100644 tests/flask/requirements.txt delete mode 100644 tests/flask/templates/error.html delete mode 100644 tests/flask/templates/index.html delete mode 100644 tests/foo.py delete mode 100644 tests/mysql.py delete mode 100644 tests/python.py delete mode 100644 tests/sqlite.py delete mode 100644 tests/tb.py diff --git a/tests/flask/application.py b/tests/flask/application.py deleted file mode 100644 index 939a8f9..0000000 --- a/tests/flask/application.py +++ /dev/null @@ -1,22 +0,0 @@ -import requests -import sys -from flask import Flask, render_template - -sys.path.insert(0, "../../src") - -import cs50 -import cs50.flask - -app = Flask(__name__) - -db = cs50.SQL("sqlite:///../sqlite.db") - -@app.route("/") -def index(): - db.execute("SELECT 1") - """ - def f(): - res = requests.get("cs50.harvard.edu") - f() - """ - return render_template("index.html") diff --git a/tests/flask/requirements.txt b/tests/flask/requirements.txt deleted file mode 100644 index 7d0c101..0000000 --- a/tests/flask/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -cs50 -Flask diff --git a/tests/flask/templates/error.html b/tests/flask/templates/error.html deleted file mode 100644 index 3302040..0000000 --- a/tests/flask/templates/error.html +++ /dev/null @@ -1,10 +0,0 @@ -<!DOCTYPE html> - -<html> - <head> - <title>error</title> - </head> - <body> - error - </body> -</html> diff --git a/tests/flask/templates/index.html b/tests/flask/templates/index.html deleted file mode 100644 index 2f6a145..0000000 --- a/tests/flask/templates/index.html +++ /dev/null @@ -1,10 +0,0 @@ -<!DOCTYPE html> - -<html> - <head> - <title>flask</title> - </head> - <body> - flask - </body> -</html> diff --git a/tests/foo.py b/tests/foo.py deleted file mode 100644 index 7f32a00..0000000 --- a/tests/foo.py +++ /dev/null @@ -1,48 +0,0 @@ -import logging -import sys - -sys.path.insert(0, "../src") - -import cs50 - -""" -db = cs50.SQL("sqlite:///foo.db") - -logging.getLogger("cs50").disabled = False - -#db.execute("SELECT ? FROM ? ORDER BY ?", "a", "tbl", "c") -db.execute("CREATE TABLE IF NOT EXISTS bar (firstname STRING)") - -db.execute("INSERT INTO bar VALUES (?)", "baz") -db.execute("INSERT INTO bar VALUES (?)", "qux") -db.execute("SELECT * FROM bar WHERE firstname IN (?)", ("baz", "qux")) -db.execute("DELETE FROM bar") -""" - -db = cs50.SQL("postgresql://postgres@localhost/test") - -""" -print(db.execute("DROP TABLE IF EXISTS cs50")) -print(db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)")) -print(db.execute("INSERT INTO cs50 (val) VALUES('foo')")) -print(db.execute("SELECT * FROM cs50")) - -print(db.execute("DROP TABLE IF EXISTS cs50")) -print(db.execute("CREATE TABLE cs50 (val VARCHAR(16), bin BYTEA)")) -print(db.execute("INSERT INTO cs50 (val) VALUES('foo')")) -print(db.execute("SELECT * FROM cs50")) -""" - -print(db.execute("DROP TABLE IF EXISTS cs50")) -print(db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)")) -print(db.execute("INSERT INTO cs50 (val) VALUES('foo')")) -print(db.execute("INSERT INTO cs50 (val) VALUES('bar')")) -print(db.execute("INSERT INTO cs50 (val) VALUES('baz')")) -print(db.execute("SELECT * FROM cs50")) -try: - print(db.execute("INSERT INTO cs50 (id, val) VALUES(1, 'bar')")) -except Exception as e: - print(e) - pass -print(db.execute("INSERT INTO cs50 (val) VALUES('qux')")) -#print(db.execute("DELETE FROM cs50")) diff --git a/tests/mysql.py b/tests/mysql.py deleted file mode 100644 index 2a431c3..0000000 --- a/tests/mysql.py +++ /dev/null @@ -1,8 +0,0 @@ -import sys - -sys.path.insert(0, "../src") - -from cs50 import SQL - -db = SQL("mysql://root@localhost/test") -db.execute("SELECT 1") diff --git a/tests/python.py b/tests/python.py deleted file mode 100644 index 6a265cb..0000000 --- a/tests/python.py +++ /dev/null @@ -1,8 +0,0 @@ -import sys - -sys.path.insert(0, "../src") - -import cs50 - -i = cs50.get_int("Input: ") -print(f"Output: {i}") diff --git a/tests/sqlite.py b/tests/sqlite.py deleted file mode 100644 index 05c2cea..0000000 --- a/tests/sqlite.py +++ /dev/null @@ -1,44 +0,0 @@ -import logging -import sys - -sys.path.insert(0, "../src") - -from cs50 import SQL - -logging.getLogger("cs50").disabled = False - -db = SQL("sqlite:///sqlite.db") -db.execute("SELECT 1") - -# TODO -#db.execute("SELECT * FROM Employee WHERE FirstName = ?", b'\x00') - -db.execute("SELECT * FROM Employee WHERE FirstName = ?", "' OR 1 = 1") - -db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", "Andrew") -db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", ["Andrew"]) -db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", ("Andrew",)) -db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", ["Andrew", "Nancy"]) -db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", ("Andrew", "Nancy")) -db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", []) -db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", ()) - -db.execute("SELECT * FROM Employee WHERE FirstName = ? AND LastName = ?", "Andrew", "Adams") -db.execute("SELECT * FROM Employee WHERE FirstName = ? AND LastName = ?", ["Andrew", "Adams"]) -db.execute("SELECT * FROM Employee WHERE FirstName = ? AND LastName = ?", ("Andrew", "Adams")) - -db.execute("SELECT * FROM Employee WHERE FirstName = :1 AND LastName = :2", "Andrew", "Adams") -db.execute("SELECT * FROM Employee WHERE FirstName = :1 AND LastName = :2", ["Andrew", "Adams"]) -db.execute("SELECT * FROM Employee WHERE FirstName = :1 AND LastName = :2", ("Andrew", "Adams")) - -db.execute("SELECT * FROM Employee WHERE FirstName = ':Andrew :Adams'") - -db.execute("SELECT * FROM Employee WHERE FirstName = :first AND LastName = :last", first="Andrew", last="Adams") -db.execute("SELECT * FROM Employee WHERE FirstName = :first AND LastName = :last", {"first": "Andrew", "last": "Adams"}) - -db.execute("SELECT * FROM Employee WHERE FirstName = %s AND LastName = %s", "Andrew", "Adams") -db.execute("SELECT * FROM Employee WHERE FirstName = %s AND LastName = %s", ["Andrew", "Adams"]) -db.execute("SELECT * FROM Employee WHERE FirstName = %s AND LastName = %s", ("Andrew", "Adams")) - -db.execute("SELECT * FROM Employee WHERE FirstName = %(first)s AND LastName = %(last)s", first="Andrew", last="Adams") -db.execute("SELECT * FROM Employee WHERE FirstName = %(first)s AND LastName = %(last)s", {"first": "Andrew", "last": "Adams"}) diff --git a/tests/tb.py b/tests/tb.py deleted file mode 100644 index 3ad8175..0000000 --- a/tests/tb.py +++ /dev/null @@ -1,10 +0,0 @@ -import sys - -sys.path.insert(0, "../src") - -import cs50 -import requests - -def f(): - res = requests.get("cs50.harvard.edu") -f() From 839b1f1baca132bc4edf83489f3856f7438bf6de Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Tue, 13 Apr 2021 12:40:34 -0400 Subject: [PATCH 074/159] add statement tests, rollback on error in autocommit --- src/cs50/_sql_util.py | 8 ++ src/cs50/_statement.py | 3 +- src/cs50/_statement_util.py | 12 ++ src/cs50/sql.py | 14 ++- tests/test_statement.py | 213 ++++++++++++++++++++++++++++++++++++ 5 files changed, 244 insertions(+), 6 deletions(-) create mode 100644 tests/test_statement.py diff --git a/src/cs50/_sql_util.py b/src/cs50/_sql_util.py index 238d979..dbaff2e 100644 --- a/src/cs50/_sql_util.py +++ b/src/cs50/_sql_util.py @@ -3,6 +3,14 @@ import decimal +def is_transaction_start(keyword): + return keyword in {"BEGIN", "START"} + + +def is_transaction_end(keyword): + return keyword in {"COMMIT", "ROLLBACK"} + + def fetch_select_result(result): rows = [dict(row) for row in result.fetchall()] for row in rows: diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index ac83758..3347f61 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -10,6 +10,7 @@ is_operation_token, is_placeholder, is_string_literal, + operation_keywords, Paramstyle, parse_placeholder, ) @@ -50,7 +51,7 @@ def _get_operation_keyword(self): for token in self._statement: if is_operation_token(token.ttype): token_value = token.value.upper() - if token_value in {"BEGIN", "DELETE", "INSERT", "SELECT", "START", "UPDATE"}: + if token_value in operation_keywords: operation_keyword = token_value break else: diff --git a/src/cs50/_statement_util.py b/src/cs50/_statement_util.py index 81b79e1..4ef092a 100644 --- a/src/cs50/_statement_util.py +++ b/src/cs50/_statement_util.py @@ -6,6 +6,18 @@ import sqlparse +operation_keywords = { + "BEGIN", + "COMMIT", + "DELETE", + "INSERT", + "ROLLBACK", + "SELECT", + "START", + "UPDATE" +} + + class Paramstyle(enum.Enum): FORMAT = enum.auto() NAMED = enum.auto() diff --git a/src/cs50/sql.py b/src/cs50/sql.py index d823c8b..ae9f97e 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -8,7 +8,7 @@ from ._session import Session from ._statement import Statement -from ._sql_util import fetch_select_result +from ._sql_util import fetch_select_result, is_transaction_start, is_transaction_end _logger = logging.getLogger("cs50") @@ -26,7 +26,7 @@ def execute(self, sql, *args, **kwargs): """Execute a SQL statement.""" statement = Statement(self._dialect, sql, *args, **kwargs) operation_keyword = statement.get_operation_keyword() - if operation_keyword in {"BEGIN", "START"}: + if is_transaction_start(operation_keyword): self._autocommit = False if self._autocommit: @@ -36,11 +36,9 @@ def execute(self, sql, *args, **kwargs): if self._autocommit: self._session.execute("COMMIT") - self._session.remove() - if operation_keyword in {"COMMIT", "ROLLBACK"}: + if is_transaction_end(operation_keyword): self._autocommit = True - self._session.remove() if operation_keyword == "SELECT": ret = fetch_select_result(result) @@ -51,8 +49,12 @@ def execute(self, sql, *args, **kwargs): else: ret = True + if self._autocommit: + self._session.remove() + return ret + def _execute(self, statement): # Catch SQLAlchemy warnings with warnings.catch_warnings(): @@ -62,6 +64,8 @@ def _execute(self, statement): result = self._session.execute(statement) except sqlalchemy.exc.IntegrityError as exc: _logger.debug(termcolor.colored(str(statement), "yellow")) + if self._autocommit: + self._session.execute("ROLLBACK") raise ValueError(exc.orig) from None except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as exc: self._session.remove() diff --git a/tests/test_statement.py b/tests/test_statement.py new file mode 100644 index 0000000..cbbafe8 --- /dev/null +++ b/tests/test_statement.py @@ -0,0 +1,213 @@ +import unittest + +from unittest.mock import patch + +from cs50._statement import Statement +from cs50._sql_sanitizer import SQLSanitizer + +class TestStatement(unittest.TestCase): + # TODO assert correct exception messages + def test_mutex_args_and_kwargs(self): + with self.assertRaises(RuntimeError): + Statement("", "", "test", foo="foo") + + with self.assertRaises(RuntimeError): + Statement("", "", "test", 1, 2, foo="foo", bar="bar") + + @patch.object(SQLSanitizer, "escape", return_value="test") + @patch.object(Statement, "_escape_verbatim_colons") + def test_valid_qmark_count(self, *_): + Statement("", "SELECT * FROM test WHERE id = ?", 1) + Statement("", "SELECT * FROM test WHERE id = ? and val = ?", 1, 'test') + Statement("", "INSERT INTO test (id, val, is_valid) VALUES (?, ?, ?)", 1, 'test', True) + + @patch.object(SQLSanitizer, "escape", return_value="test") + @patch.object(Statement, "_escape_verbatim_colons") + def test_invalid_qmark_count(self, *_): + def assert_invalid_count(sql, *args): + with self.assertRaises(RuntimeError, msg=f"{sql} {str(args)}"): + Statement("", sql, *args) + + statements = [ + ("SELECT * FROM test WHERE id = ?", ()), + ("SELECT * FROM test WHERE id = ?", (1, "test")), + ("SELECT * FROM test WHERE id = ? AND val = ?", (1,)), + ("SELECT * FROM test WHERE id = ? AND val = ?", ()), + ("SELECT * FROM test WHERE id = ? AND val = ?", (1, "test", True)), + ] + + for sql, args in statements: + assert_invalid_count(sql, *args) + + + @patch.object(SQLSanitizer, "escape", return_value="test") + @patch.object(Statement, "_escape_verbatim_colons") + def test_valid_format_count(self, *_): + Statement("", "SELECT * FROM test WHERE id = %s", 1) + Statement("", "SELECT * FROM test WHERE id = %s and val = %s", 1, 'test') + Statement("", "INSERT INTO test (id, val, is_valid) VALUES (%s, %s, %s)", 1, 'test', True) + + @patch.object(SQLSanitizer, "escape", return_value="test") + @patch.object(Statement, "_escape_verbatim_colons") + def test_invalid_format_count(self, *_): + def assert_invalid_count(sql, *args): + with self.assertRaises(RuntimeError, msg=f"{sql} {str(args)}"): + Statement("", sql, *args) + + statements = [ + ("SELECT * FROM test WHERE id = %s", ()), + ("SELECT * FROM test WHERE id = %s", (1, "test")), + ("SELECT * FROM test WHERE id = %s AND val = ?", (1,)), + ("SELECT * FROM test WHERE id = %s AND val = ?", ()), + ("SELECT * FROM test WHERE id = %s AND val = ?", (1, "test", True)), + ] + + for sql, args in statements: + assert_invalid_count(sql, *args) + + @patch.object(SQLSanitizer, "escape", return_value="test") + @patch.object(Statement, "_escape_verbatim_colons") + def test_missing_numeric(self, *_): + def assert_missing_numeric(sql, *args): + with self.assertRaises(RuntimeError, msg=f"{sql} {str(args)}"): + Statement("", sql, *args) + + statements = [ + ("SELECT * FROM test WHERE id = :1", ()), + ("SELECT * FROM test WHERE id = :1 AND val = :2", ()), + ("SELECT * FROM test WHERE id = :1 AND val = :2", (1,)), + ("SELECT * FROM test WHERE id = :1 AND val = :2 AND is_valid = :3", ()), + ("SELECT * FROM test WHERE id = :1 AND val = :2 AND is_valid = :3", (1,)), + ("SELECT * FROM test WHERE id = :1 AND val = :2 AND is_valid = :3", (1, "test")), + ] + + for sql, args in statements: + assert_missing_numeric(sql, *args) + + @patch.object(SQLSanitizer, "escape", return_value="test") + @patch.object(Statement, "_escape_verbatim_colons") + def test_unused_numeric(self, *_): + def assert_unused_numeric(sql, *args): + with self.assertRaises(RuntimeError, msg=f"{sql} {str(args)}"): + Statement("", sql, *args) + + statements = [ + ("SELECT * FROM test WHERE id = :1", (1, "test")), + ("SELECT * FROM test WHERE id = :1", (1, "test", True)), + ("SELECT * FROM test WHERE id = :1 AND val = :2", (1, "test", True)), + ] + + for sql, args in statements: + assert_unused_numeric(sql, *args) + + @patch.object(SQLSanitizer, "escape", return_value="test") + @patch.object(Statement, "_escape_verbatim_colons") + def test_missing_named(self, *_): + def assert_missing_named(sql, **kwargs): + with self.assertRaises(RuntimeError, msg=f"{sql} {str(kwargs)}"): + Statement("", sql, **kwargs) + + statements = [ + ("SELECT * FROM test WHERE id = :id", {}), + ("SELECT * FROM test WHERE id = :id AND val = :val", {}), + ("SELECT * FROM test WHERE id = :id AND val = :val", {"id": 1}), + ("SELECT * FROM test WHERE id = :id AND val = :val AND is_valid = :is_valid", {}), + ("SELECT * FROM test WHERE id = :id AND val = :val AND is_valid = :is_valid", {"id": 1}), + ("SELECT * FROM test WHERE id = :id AND val = :val AND is_valid = :is_valid", {"id": 1, "val": "test"}), + ] + + for sql, kwargs in statements: + assert_missing_named(sql, **kwargs) + + @patch.object(SQLSanitizer, "escape", return_value="test") + @patch.object(Statement, "_escape_verbatim_colons") + def test_unused_named(self, *_): + def assert_unused_named(sql, **kwargs): + with self.assertRaises(RuntimeError, msg=f"{sql} {str(kwargs)}"): + Statement("", sql, **kwargs) + + statements = [ + ("SELECT * FROM test WHERE id = :id", {"id": 1, "val": "test"}), + ("SELECT * FROM test WHERE id = :id", {"id": 1, "val": "test", "is_valid": True}), + ("SELECT * FROM test WHERE id = :id AND val = :val", {"id": 1, "val": "test", "is_valid": True}), + ] + + for sql, kwargs in statements: + assert_unused_named(sql, **kwargs) + + @patch.object(SQLSanitizer, "escape", return_value="test") + @patch.object(Statement, "_escape_verbatim_colons") + def test_missing_pyformat(self, *_): + def assert_missing_pyformat(sql, **kwargs): + with self.assertRaises(RuntimeError, msg=f"{sql} {str(kwargs)}"): + Statement("", sql, **kwargs) + + statements = [ + ("SELECT * FROM test WHERE id = %(id)s", {}), + ("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s", {}), + ("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s", {"id": 1}), + ("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s AND is_valid = %(is_valid)s", {}), + ("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s AND is_valid = %(is_valid)s", {"id": 1}), + ("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s AND is_valid = %(is_valid)s", {"id": 1, "val": "test"}), + ] + + for sql, kwargs in statements: + assert_missing_pyformat(sql, **kwargs) + + @patch.object(SQLSanitizer, "escape", return_value="test") + @patch.object(Statement, "_escape_verbatim_colons") + def test_unused_pyformat(self, *_): + def assert_unused_pyformat(sql, **kwargs): + with self.assertRaises(RuntimeError, msg=f"{sql} {str(kwargs)}"): + Statement("", sql, **kwargs) + + statements = [ + ("SELECT * FROM test WHERE id = %(id)s", {"id": 1, "val": "test"}), + ("SELECT * FROM test WHERE id = %(id)s", {"id": 1, "val": "test", "is_valid": True}), + ("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s", {"id": 1, "val": "test", "is_valid": True}), + ] + + for sql, kwargs in statements: + assert_unused_pyformat(sql, **kwargs) + + def test_multiple_statements(self): + def assert_raises_runtimeerror(sql): + with self.assertRaises(RuntimeError): + Statement("", sql) + + statements = [ + "SELECT 1; SELECT 2;", + "SELECT 1; SELECT 2", + "SELECT 1; SELECT 2; SELECT 3", + "SELECT 1; SELECT 2; SELECT 3;", + "SELECT 1;SELECT 2", + "select 1; select 2", + "select 1;select 2", + "DELETE FROM test; SELECT * FROM test", + ] + + for sql in statements: + assert_raises_runtimeerror(sql) + + def test_get_operation_keyword(self): + def test_raw_and_lowercase(sql, keyword): + statement = Statement("", sql) + self.assertEqual(statement.get_operation_keyword(), keyword) + + statement = Statement("", sql.lower()) + self.assertEqual(statement.get_operation_keyword(), keyword) + + + statements = [ + ("SELECT * FROM test", "SELECT"), + ("INSERT INTO test (id, val) VALUES (1, 'test')", "INSERT"), + ("DELETE FROM test", "DELETE"), + ("UPDATE test SET id = 2", "UPDATE"), + ("START TRANSACTION", "START"), + ("BEGIN", "BEGIN"), + ("COMMIT", "COMMIT"), + ("ROLLBACK", "ROLLBACK"), + ] + + for sql, keyword in statements: + test_raw_and_lowercase(sql, keyword) From 9302a1e8b8a82ac2a206cee0fb340f0dd82cf153 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Tue, 13 Apr 2021 21:19:13 -0400 Subject: [PATCH 075/159] move operation check to Statement --- src/cs50/_sql_util.py | 8 ---- src/cs50/_statement.py | 20 +++++++-- src/cs50/sql.py | 13 +++--- tests/test_statement.py | 95 ++++++++++++++++++++++++----------------- 4 files changed, 79 insertions(+), 57 deletions(-) diff --git a/src/cs50/_sql_util.py b/src/cs50/_sql_util.py index dbaff2e..238d979 100644 --- a/src/cs50/_sql_util.py +++ b/src/cs50/_sql_util.py @@ -3,14 +3,6 @@ import decimal -def is_transaction_start(keyword): - return keyword in {"BEGIN", "START"} - - -def is_transaction_end(keyword): - return keyword in {"COMMIT", "ROLLBACK"} - - def fetch_select_result(result): rows = [dict(row) for row in result.fetchall()] for row in rows: diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index 3347f61..2502284 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -147,9 +147,23 @@ def _escape_verbatim_colons(self): if is_string_literal(token.ttype) or is_identifier(token.ttype): token.value = escape_verbatim_colon(token.value) - def get_operation_keyword(self): - """Returns the operation keyword of the statement (e.g., SELECT) if found, or None""" - return self._operation_keyword + def is_transaction_start(self): + return self._operation_keyword in {"BEGIN", "START"} + + def is_transaction_end(self): + return self._operation_keyword in {"COMMIT", "ROLLBACK"} + + def is_delete(self): + return self._operation_keyword == "DELETE" + + def is_insert(self): + return self._operation_keyword == "INSERT" + + def is_select(self): + return self._operation_keyword == "SELECT" + + def is_update(self): + return self._operation_keyword == "UPDATE" def __str__(self): return "".join([str(token) for token in self._tokens]) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index ae9f97e..c0e41fd 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -8,7 +8,7 @@ from ._session import Session from ._statement import Statement -from ._sql_util import fetch_select_result, is_transaction_start, is_transaction_end +from ._sql_util import fetch_select_result _logger = logging.getLogger("cs50") @@ -25,8 +25,7 @@ def __init__(self, url, **engine_kwargs): def execute(self, sql, *args, **kwargs): """Execute a SQL statement.""" statement = Statement(self._dialect, sql, *args, **kwargs) - operation_keyword = statement.get_operation_keyword() - if is_transaction_start(operation_keyword): + if statement.is_transaction_start(): self._autocommit = False if self._autocommit: @@ -37,14 +36,14 @@ def execute(self, sql, *args, **kwargs): if self._autocommit: self._session.execute("COMMIT") - if is_transaction_end(operation_keyword): + if statement.is_transaction_end(): self._autocommit = True - if operation_keyword == "SELECT": + if statement.is_select(): ret = fetch_select_result(result) - elif operation_keyword == "INSERT": + elif statement.is_insert(): ret = self._last_row_id_or_none(result) - elif operation_keyword in {"DELETE", "UPDATE"}: + elif statement.is_delete() or statement.is_update(): ret = result.rowcount else: ret = True diff --git a/tests/test_statement.py b/tests/test_statement.py index cbbafe8..fcee3b9 100644 --- a/tests/test_statement.py +++ b/tests/test_statement.py @@ -9,24 +9,24 @@ class TestStatement(unittest.TestCase): # TODO assert correct exception messages def test_mutex_args_and_kwargs(self): with self.assertRaises(RuntimeError): - Statement("", "", "test", foo="foo") + Statement(None, None, "test", foo="foo") with self.assertRaises(RuntimeError): - Statement("", "", "test", 1, 2, foo="foo", bar="bar") + Statement(None, None, "test", 1, 2, foo="foo", bar="bar") @patch.object(SQLSanitizer, "escape", return_value="test") @patch.object(Statement, "_escape_verbatim_colons") def test_valid_qmark_count(self, *_): - Statement("", "SELECT * FROM test WHERE id = ?", 1) - Statement("", "SELECT * FROM test WHERE id = ? and val = ?", 1, 'test') - Statement("", "INSERT INTO test (id, val, is_valid) VALUES (?, ?, ?)", 1, 'test', True) + Statement(None, "SELECT * FROM test WHERE id = ?", 1) + Statement(None, "SELECT * FROM test WHERE id = ? and val = ?", 1, 'test') + Statement(None, "INSERT INTO test (id, val, is_valid) VALUES (?, ?, ?)", 1, 'test', True) @patch.object(SQLSanitizer, "escape", return_value="test") @patch.object(Statement, "_escape_verbatim_colons") def test_invalid_qmark_count(self, *_): def assert_invalid_count(sql, *args): with self.assertRaises(RuntimeError, msg=f"{sql} {str(args)}"): - Statement("", sql, *args) + Statement(None, sql, *args) statements = [ ("SELECT * FROM test WHERE id = ?", ()), @@ -43,16 +43,16 @@ def assert_invalid_count(sql, *args): @patch.object(SQLSanitizer, "escape", return_value="test") @patch.object(Statement, "_escape_verbatim_colons") def test_valid_format_count(self, *_): - Statement("", "SELECT * FROM test WHERE id = %s", 1) - Statement("", "SELECT * FROM test WHERE id = %s and val = %s", 1, 'test') - Statement("", "INSERT INTO test (id, val, is_valid) VALUES (%s, %s, %s)", 1, 'test', True) + Statement(None, "SELECT * FROM test WHERE id = %s", 1) + Statement(None, "SELECT * FROM test WHERE id = %s and val = %s", 1, 'test') + Statement(None, "INSERT INTO test (id, val, is_valid) VALUES (%s, %s, %s)", 1, 'test', True) @patch.object(SQLSanitizer, "escape", return_value="test") @patch.object(Statement, "_escape_verbatim_colons") def test_invalid_format_count(self, *_): def assert_invalid_count(sql, *args): with self.assertRaises(RuntimeError, msg=f"{sql} {str(args)}"): - Statement("", sql, *args) + Statement(None, sql, *args) statements = [ ("SELECT * FROM test WHERE id = %s", ()), @@ -70,7 +70,7 @@ def assert_invalid_count(sql, *args): def test_missing_numeric(self, *_): def assert_missing_numeric(sql, *args): with self.assertRaises(RuntimeError, msg=f"{sql} {str(args)}"): - Statement("", sql, *args) + Statement(None, sql, *args) statements = [ ("SELECT * FROM test WHERE id = :1", ()), @@ -89,7 +89,7 @@ def assert_missing_numeric(sql, *args): def test_unused_numeric(self, *_): def assert_unused_numeric(sql, *args): with self.assertRaises(RuntimeError, msg=f"{sql} {str(args)}"): - Statement("", sql, *args) + Statement(None, sql, *args) statements = [ ("SELECT * FROM test WHERE id = :1", (1, "test")), @@ -105,7 +105,7 @@ def assert_unused_numeric(sql, *args): def test_missing_named(self, *_): def assert_missing_named(sql, **kwargs): with self.assertRaises(RuntimeError, msg=f"{sql} {str(kwargs)}"): - Statement("", sql, **kwargs) + Statement(None, sql, **kwargs) statements = [ ("SELECT * FROM test WHERE id = :id", {}), @@ -124,7 +124,7 @@ def assert_missing_named(sql, **kwargs): def test_unused_named(self, *_): def assert_unused_named(sql, **kwargs): with self.assertRaises(RuntimeError, msg=f"{sql} {str(kwargs)}"): - Statement("", sql, **kwargs) + Statement(None, sql, **kwargs) statements = [ ("SELECT * FROM test WHERE id = :id", {"id": 1, "val": "test"}), @@ -140,7 +140,7 @@ def assert_unused_named(sql, **kwargs): def test_missing_pyformat(self, *_): def assert_missing_pyformat(sql, **kwargs): with self.assertRaises(RuntimeError, msg=f"{sql} {str(kwargs)}"): - Statement("", sql, **kwargs) + Statement(None, sql, **kwargs) statements = [ ("SELECT * FROM test WHERE id = %(id)s", {}), @@ -159,7 +159,7 @@ def assert_missing_pyformat(sql, **kwargs): def test_unused_pyformat(self, *_): def assert_unused_pyformat(sql, **kwargs): with self.assertRaises(RuntimeError, msg=f"{sql} {str(kwargs)}"): - Statement("", sql, **kwargs) + Statement(None, sql, **kwargs) statements = [ ("SELECT * FROM test WHERE id = %(id)s", {"id": 1, "val": "test"}), @@ -173,7 +173,7 @@ def assert_unused_pyformat(sql, **kwargs): def test_multiple_statements(self): def assert_raises_runtimeerror(sql): with self.assertRaises(RuntimeError): - Statement("", sql) + Statement(None, sql) statements = [ "SELECT 1; SELECT 2;", @@ -189,25 +189,42 @@ def assert_raises_runtimeerror(sql): for sql in statements: assert_raises_runtimeerror(sql) - def test_get_operation_keyword(self): - def test_raw_and_lowercase(sql, keyword): - statement = Statement("", sql) - self.assertEqual(statement.get_operation_keyword(), keyword) - - statement = Statement("", sql.lower()) - self.assertEqual(statement.get_operation_keyword(), keyword) - - - statements = [ - ("SELECT * FROM test", "SELECT"), - ("INSERT INTO test (id, val) VALUES (1, 'test')", "INSERT"), - ("DELETE FROM test", "DELETE"), - ("UPDATE test SET id = 2", "UPDATE"), - ("START TRANSACTION", "START"), - ("BEGIN", "BEGIN"), - ("COMMIT", "COMMIT"), - ("ROLLBACK", "ROLLBACK"), - ] - - for sql, keyword in statements: - test_raw_and_lowercase(sql, keyword) + def test_is_delete(self): + self.assertTrue(Statement(None, "DELETE FROM test").is_delete()) + self.assertTrue(Statement(None, "delete FROM test").is_delete()) + self.assertFalse(Statement(None, "SELECT * FROM test").is_delete()) + self.assertFalse(Statement(None, "INSERT INTO test (id, val) VALUES (1, 'test')").is_delete()) + + def test_is_insert(self): + self.assertTrue(Statement(None, "INSERT INTO test (id, val) VALUES (1, 'test')").is_insert()) + self.assertTrue(Statement(None, "insert INTO test (id, val) VALUES (1, 'test')").is_insert()) + self.assertFalse(Statement(None, "SELECT * FROM test").is_insert()) + self.assertFalse(Statement(None, "DELETE FROM test").is_insert()) + + def test_is_select(self): + self.assertTrue(Statement(None, "SELECT * FROM test").is_select()) + self.assertTrue(Statement(None, "select * FROM test").is_select()) + self.assertFalse(Statement(None, "DELETE FROM test").is_select()) + self.assertFalse(Statement(None, "INSERT INTO test (id, val) VALUES (1, 'test')").is_select()) + + def test_is_update(self): + self.assertTrue(Statement(None, "UPDATE test SET id = 2").is_update()) + self.assertTrue(Statement(None, "update test SET id = 2").is_update()) + self.assertFalse(Statement(None, "SELECT * FROM test").is_update()) + self.assertFalse(Statement(None, "INSERT INTO test (id, val) VALUES (1, 'test')").is_update()) + + def test_is_transaction_start(self): + self.assertTrue(Statement(None, "START TRANSACTION").is_transaction_start()) + self.assertTrue(Statement(None, "start TRANSACTION").is_transaction_start()) + self.assertTrue(Statement(None, "BEGIN").is_transaction_start()) + self.assertTrue(Statement(None, "begin").is_transaction_start()) + self.assertFalse(Statement(None, "SELECT * FROM test").is_transaction_start()) + self.assertFalse(Statement(None, "DELETE FROM test").is_transaction_start()) + + def test_is_transaction_end(self): + self.assertTrue(Statement(None, "COMMIT").is_transaction_end()) + self.assertTrue(Statement(None, "commit").is_transaction_end()) + self.assertTrue(Statement(None, "ROLLBACK").is_transaction_end()) + self.assertTrue(Statement(None, "rollback").is_transaction_end()) + self.assertFalse(Statement(None, "SELECT * FROM test").is_transaction_end()) + self.assertFalse(Statement(None, "DELETE FROM test").is_transaction_end()) From f6912d27ba1ed250519eff6c434720a092679783 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Tue, 13 Apr 2021 22:31:51 -0400 Subject: [PATCH 076/159] use statement factory --- src/cs50/_statement.py | 13 ++- src/cs50/sql.py | 10 +-- tests/test_cs50.py | 9 --- tests/test_statement.py | 174 ++++++++++++++++++++-------------------- 4 files changed, 105 insertions(+), 101 deletions(-) diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index 2502284..cc4cdb8 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -16,14 +16,23 @@ ) +def statement_factory(dialect): + sql_sanitizer = SQLSanitizer(dialect) + + def statement(sql, *args, **kwargs): + return Statement(sql_sanitizer, sql, *args, **kwargs) + + return statement + + class Statement: """Parses a SQL statement and replaces the placeholders with the corresponding parameters""" - def __init__(self, dialect, sql, *args, **kwargs): + def __init__(self, sql_sanitizer, sql, *args, **kwargs): if len(args) > 0 and len(kwargs) > 0: raise RuntimeError("cannot pass both positional and named parameters") - self._sql_sanitizer = SQLSanitizer(dialect) + self._sql_sanitizer = sql_sanitizer self._args = self._get_escaped_args(args) self._kwargs = self._get_escaped_kwargs(kwargs) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index c0e41fd..10bffd6 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -7,7 +7,7 @@ import termcolor from ._session import Session -from ._statement import Statement +from ._statement import statement_factory from ._sql_util import fetch_select_result _logger = logging.getLogger("cs50") @@ -18,13 +18,14 @@ class SQL: def __init__(self, url, **engine_kwargs): self._session = Session(url, **engine_kwargs) - self._dialect = self._session.get_bind().dialect - self._is_postgres = self._dialect.name in {"postgres", "postgresql"} + dialect = self._session.get_bind().dialect + self._is_postgres = dialect.name in {"postgres", "postgresql"} + self._sanitized_statement = statement_factory(dialect) self._autocommit = False def execute(self, sql, *args, **kwargs): """Execute a SQL statement.""" - statement = Statement(self._dialect, sql, *args, **kwargs) + statement = self._sanitized_statement(sql, *args, **kwargs) if statement.is_transaction_start(): self._autocommit = False @@ -53,7 +54,6 @@ def execute(self, sql, *args, **kwargs): return ret - def _execute(self, statement): # Catch SQLAlchemy warnings with warnings.catch_warnings(): diff --git a/tests/test_cs50.py b/tests/test_cs50.py index a58424d..dd0f14b 100644 --- a/tests/test_cs50.py +++ b/tests/test_cs50.py @@ -14,34 +14,29 @@ def test_get_string_empty_input(self, mock_get_input): self.assertEqual(get_string("Answer: "), "") mock_get_input.assert_called_with("Answer: ") - @patch("cs50.cs50._get_input", return_value="test") def test_get_string_nonempty_input(self, mock_get_input): """Returns the provided non-empty input""" self.assertEqual(get_string("Answer: "), "test") mock_get_input.assert_called_with("Answer: ") - @patch("cs50.cs50._get_input", side_effect=EOFError) def test_get_string_eof(self, mock_get_input): """Returns None on EOF""" self.assertIs(get_string("Answer: "), None) mock_get_input.assert_called_with("Answer: ") - def test_get_string_invalid_prompt(self): """Raises TypeError when prompt is not str""" with self.assertRaises(TypeError): get_string(1) - @patch("cs50.cs50.get_string", return_value=None) def test_get_int_eof(self, mock_get_string): """Returns None on EOF""" self.assertIs(_get_int("Answer: "), None) mock_get_string.assert_called_with("Answer: ") - def test_get_int_valid_input(self): """Returns the provided integer input""" @@ -62,7 +57,6 @@ def assert_equal(return_value, expected_value): for return_value, expected_value in values: assert_equal(return_value, expected_value) - def test_get_int_invalid_input(self): """Raises ValueError when input is invalid base-10 int""" @@ -90,14 +84,12 @@ def assert_raises_valueerror(return_value): for return_value in return_values: assert_raises_valueerror(return_value) - @patch("cs50.cs50.get_string", return_value=None) def test_get_float_eof(self, mock_get_string): """Returns None on EOF""" self.assertIs(_get_float("Answer: "), None) mock_get_string.assert_called_with("Answer: ") - def test_get_float_valid_input(self): """Returns the provided integer input""" def assert_equal(return_value, expected_value): @@ -121,7 +113,6 @@ def assert_equal(return_value, expected_value): for return_value, expected_value in values: assert_equal(return_value, expected_value) - def test_get_float_invalid_input(self): """Raises ValueError when input is invalid float""" diff --git a/tests/test_statement.py b/tests/test_statement.py index fcee3b9..91261cd 100644 --- a/tests/test_statement.py +++ b/tests/test_statement.py @@ -5,28 +5,29 @@ from cs50._statement import Statement from cs50._sql_sanitizer import SQLSanitizer + +@patch.object(SQLSanitizer, "escape", return_value="test") class TestStatement(unittest.TestCase): # TODO assert correct exception messages - def test_mutex_args_and_kwargs(self): + def test_mutex_args_and_kwargs(self, MockSQLSanitizer): with self.assertRaises(RuntimeError): - Statement(None, None, "test", foo="foo") + Statement(MockSQLSanitizer(), "SELECT * FROM test WHERE id = ? AND val = :val", 1, val="test") with self.assertRaises(RuntimeError): - Statement(None, None, "test", 1, 2, foo="foo", bar="bar") + Statement(MockSQLSanitizer(), "SELECT * FROM test", "test", 1, 2, foo="foo", bar="bar") - @patch.object(SQLSanitizer, "escape", return_value="test") @patch.object(Statement, "_escape_verbatim_colons") - def test_valid_qmark_count(self, *_): - Statement(None, "SELECT * FROM test WHERE id = ?", 1) - Statement(None, "SELECT * FROM test WHERE id = ? and val = ?", 1, 'test') - Statement(None, "INSERT INTO test (id, val, is_valid) VALUES (?, ?, ?)", 1, 'test', True) + def test_valid_qmark_count(self, MockSQLSanitizer, *_): + Statement(MockSQLSanitizer(), "SELECT * FROM test WHERE id = ?", 1) + Statement(MockSQLSanitizer(), "SELECT * FROM test WHERE id = ? and val = ?", 1, 'test') + Statement(MockSQLSanitizer(), + "INSERT INTO test (id, val, is_valid) VALUES (?, ?, ?)", 1, 'test', True) - @patch.object(SQLSanitizer, "escape", return_value="test") @patch.object(Statement, "_escape_verbatim_colons") - def test_invalid_qmark_count(self, *_): + def test_invalid_qmark_count(self, MockSQLSanitizer, *_): def assert_invalid_count(sql, *args): with self.assertRaises(RuntimeError, msg=f"{sql} {str(args)}"): - Statement(None, sql, *args) + Statement(MockSQLSanitizer(), sql, *args) statements = [ ("SELECT * FROM test WHERE id = ?", ()), @@ -39,20 +40,18 @@ def assert_invalid_count(sql, *args): for sql, args in statements: assert_invalid_count(sql, *args) - - @patch.object(SQLSanitizer, "escape", return_value="test") @patch.object(Statement, "_escape_verbatim_colons") - def test_valid_format_count(self, *_): - Statement(None, "SELECT * FROM test WHERE id = %s", 1) - Statement(None, "SELECT * FROM test WHERE id = %s and val = %s", 1, 'test') - Statement(None, "INSERT INTO test (id, val, is_valid) VALUES (%s, %s, %s)", 1, 'test', True) + def test_valid_format_count(self, MockSQLSanitizer, *_): + Statement(MockSQLSanitizer(), "SELECT * FROM test WHERE id = %s", 1) + Statement(MockSQLSanitizer(), "SELECT * FROM test WHERE id = %s and val = %s", 1, 'test') + Statement(MockSQLSanitizer(), + "INSERT INTO test (id, val, is_valid) VALUES (%s, %s, %s)", 1, 'test', True) - @patch.object(SQLSanitizer, "escape", return_value="test") @patch.object(Statement, "_escape_verbatim_colons") - def test_invalid_format_count(self, *_): + def test_invalid_format_count(self, MockSQLSanitizer, *_): def assert_invalid_count(sql, *args): with self.assertRaises(RuntimeError, msg=f"{sql} {str(args)}"): - Statement(None, sql, *args) + Statement(MockSQLSanitizer(), sql, *args) statements = [ ("SELECT * FROM test WHERE id = %s", ()), @@ -65,12 +64,11 @@ def assert_invalid_count(sql, *args): for sql, args in statements: assert_invalid_count(sql, *args) - @patch.object(SQLSanitizer, "escape", return_value="test") @patch.object(Statement, "_escape_verbatim_colons") - def test_missing_numeric(self, *_): + def test_missing_numeric(self, MockSQLSanitizer, *_): def assert_missing_numeric(sql, *args): with self.assertRaises(RuntimeError, msg=f"{sql} {str(args)}"): - Statement(None, sql, *args) + Statement(MockSQLSanitizer(), sql, *args) statements = [ ("SELECT * FROM test WHERE id = :1", ()), @@ -84,12 +82,11 @@ def assert_missing_numeric(sql, *args): for sql, args in statements: assert_missing_numeric(sql, *args) - @patch.object(SQLSanitizer, "escape", return_value="test") @patch.object(Statement, "_escape_verbatim_colons") - def test_unused_numeric(self, *_): + def test_unused_numeric(self, MockSQLSanitizer, *_): def assert_unused_numeric(sql, *args): with self.assertRaises(RuntimeError, msg=f"{sql} {str(args)}"): - Statement(None, sql, *args) + Statement(MockSQLSanitizer(), sql, *args) statements = [ ("SELECT * FROM test WHERE id = :1", (1, "test")), @@ -100,80 +97,82 @@ def assert_unused_numeric(sql, *args): for sql, args in statements: assert_unused_numeric(sql, *args) - @patch.object(SQLSanitizer, "escape", return_value="test") @patch.object(Statement, "_escape_verbatim_colons") - def test_missing_named(self, *_): + def test_missing_named(self, MockSQLSanitizer, *_): def assert_missing_named(sql, **kwargs): with self.assertRaises(RuntimeError, msg=f"{sql} {str(kwargs)}"): - Statement(None, sql, **kwargs) + Statement(MockSQLSanitizer(), sql, **kwargs) statements = [ ("SELECT * FROM test WHERE id = :id", {}), ("SELECT * FROM test WHERE id = :id AND val = :val", {}), ("SELECT * FROM test WHERE id = :id AND val = :val", {"id": 1}), ("SELECT * FROM test WHERE id = :id AND val = :val AND is_valid = :is_valid", {}), - ("SELECT * FROM test WHERE id = :id AND val = :val AND is_valid = :is_valid", {"id": 1}), - ("SELECT * FROM test WHERE id = :id AND val = :val AND is_valid = :is_valid", {"id": 1, "val": "test"}), + ("SELECT * FROM test WHERE id = :id AND val = :val AND is_valid = :is_valid", + {"id": 1}), + ("SELECT * FROM test WHERE id = :id AND val = :val AND is_valid = :is_valid", + {"id": 1, "val": "test"}), ] for sql, kwargs in statements: assert_missing_named(sql, **kwargs) - @patch.object(SQLSanitizer, "escape", return_value="test") @patch.object(Statement, "_escape_verbatim_colons") - def test_unused_named(self, *_): + def test_unused_named(self, MockSQLSanitizer, *_): def assert_unused_named(sql, **kwargs): with self.assertRaises(RuntimeError, msg=f"{sql} {str(kwargs)}"): - Statement(None, sql, **kwargs) + Statement(MockSQLSanitizer(), sql, **kwargs) statements = [ ("SELECT * FROM test WHERE id = :id", {"id": 1, "val": "test"}), ("SELECT * FROM test WHERE id = :id", {"id": 1, "val": "test", "is_valid": True}), - ("SELECT * FROM test WHERE id = :id AND val = :val", {"id": 1, "val": "test", "is_valid": True}), + ("SELECT * FROM test WHERE id = :id AND val = :val", + {"id": 1, "val": "test", "is_valid": True}), ] for sql, kwargs in statements: assert_unused_named(sql, **kwargs) - @patch.object(SQLSanitizer, "escape", return_value="test") @patch.object(Statement, "_escape_verbatim_colons") - def test_missing_pyformat(self, *_): + def test_missing_pyformat(self, MockSQLSanitizer, *_): def assert_missing_pyformat(sql, **kwargs): with self.assertRaises(RuntimeError, msg=f"{sql} {str(kwargs)}"): - Statement(None, sql, **kwargs) + Statement(MockSQLSanitizer(), sql, **kwargs) statements = [ ("SELECT * FROM test WHERE id = %(id)s", {}), ("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s", {}), ("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s", {"id": 1}), ("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s AND is_valid = %(is_valid)s", {}), - ("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s AND is_valid = %(is_valid)s", {"id": 1}), - ("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s AND is_valid = %(is_valid)s", {"id": 1, "val": "test"}), + ("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s AND is_valid = %(is_valid)s", + {"id": 1}), + ("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s AND is_valid = %(is_valid)s", + {"id": 1, "val": "test"}), ] for sql, kwargs in statements: assert_missing_pyformat(sql, **kwargs) - @patch.object(SQLSanitizer, "escape", return_value="test") @patch.object(Statement, "_escape_verbatim_colons") - def test_unused_pyformat(self, *_): + def test_unused_pyformat(self, MockSQLSanitizer, *_): def assert_unused_pyformat(sql, **kwargs): with self.assertRaises(RuntimeError, msg=f"{sql} {str(kwargs)}"): - Statement(None, sql, **kwargs) + Statement(MockSQLSanitizer(), sql, **kwargs) statements = [ ("SELECT * FROM test WHERE id = %(id)s", {"id": 1, "val": "test"}), ("SELECT * FROM test WHERE id = %(id)s", {"id": 1, "val": "test", "is_valid": True}), - ("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s", {"id": 1, "val": "test", "is_valid": True}), + ("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s", + {"id": 1, "val": "test", "is_valid": True}), ] for sql, kwargs in statements: assert_unused_pyformat(sql, **kwargs) - def test_multiple_statements(self): + def test_multiple_statements(self, MockSQLSanitizer): def assert_raises_runtimeerror(sql): with self.assertRaises(RuntimeError): - Statement(None, sql) + Statement(MockSQLSanitizer(), sql) statements = [ "SELECT 1; SELECT 2;", @@ -189,42 +188,47 @@ def assert_raises_runtimeerror(sql): for sql in statements: assert_raises_runtimeerror(sql) - def test_is_delete(self): - self.assertTrue(Statement(None, "DELETE FROM test").is_delete()) - self.assertTrue(Statement(None, "delete FROM test").is_delete()) - self.assertFalse(Statement(None, "SELECT * FROM test").is_delete()) - self.assertFalse(Statement(None, "INSERT INTO test (id, val) VALUES (1, 'test')").is_delete()) - - def test_is_insert(self): - self.assertTrue(Statement(None, "INSERT INTO test (id, val) VALUES (1, 'test')").is_insert()) - self.assertTrue(Statement(None, "insert INTO test (id, val) VALUES (1, 'test')").is_insert()) - self.assertFalse(Statement(None, "SELECT * FROM test").is_insert()) - self.assertFalse(Statement(None, "DELETE FROM test").is_insert()) - - def test_is_select(self): - self.assertTrue(Statement(None, "SELECT * FROM test").is_select()) - self.assertTrue(Statement(None, "select * FROM test").is_select()) - self.assertFalse(Statement(None, "DELETE FROM test").is_select()) - self.assertFalse(Statement(None, "INSERT INTO test (id, val) VALUES (1, 'test')").is_select()) - - def test_is_update(self): - self.assertTrue(Statement(None, "UPDATE test SET id = 2").is_update()) - self.assertTrue(Statement(None, "update test SET id = 2").is_update()) - self.assertFalse(Statement(None, "SELECT * FROM test").is_update()) - self.assertFalse(Statement(None, "INSERT INTO test (id, val) VALUES (1, 'test')").is_update()) - - def test_is_transaction_start(self): - self.assertTrue(Statement(None, "START TRANSACTION").is_transaction_start()) - self.assertTrue(Statement(None, "start TRANSACTION").is_transaction_start()) - self.assertTrue(Statement(None, "BEGIN").is_transaction_start()) - self.assertTrue(Statement(None, "begin").is_transaction_start()) - self.assertFalse(Statement(None, "SELECT * FROM test").is_transaction_start()) - self.assertFalse(Statement(None, "DELETE FROM test").is_transaction_start()) - - def test_is_transaction_end(self): - self.assertTrue(Statement(None, "COMMIT").is_transaction_end()) - self.assertTrue(Statement(None, "commit").is_transaction_end()) - self.assertTrue(Statement(None, "ROLLBACK").is_transaction_end()) - self.assertTrue(Statement(None, "rollback").is_transaction_end()) - self.assertFalse(Statement(None, "SELECT * FROM test").is_transaction_end()) - self.assertFalse(Statement(None, "DELETE FROM test").is_transaction_end()) + def test_is_delete(self, MockSQLSanitizer): + self.assertTrue(Statement(MockSQLSanitizer(), "DELETE FROM test").is_delete()) + self.assertTrue(Statement(MockSQLSanitizer(), "delete FROM test").is_delete()) + self.assertFalse(Statement(MockSQLSanitizer(), "SELECT * FROM test").is_delete()) + self.assertFalse(Statement(MockSQLSanitizer(), + "INSERT INTO test (id, val) VALUES (1, 'test')").is_delete()) + + def test_is_insert(self, MockSQLSanitizer): + self.assertTrue(Statement(MockSQLSanitizer(), + "INSERT INTO test (id, val) VALUES (1, 'test')").is_insert()) + self.assertTrue(Statement(MockSQLSanitizer(), + "insert INTO test (id, val) VALUES (1, 'test')").is_insert()) + self.assertFalse(Statement(MockSQLSanitizer(), "SELECT * FROM test").is_insert()) + self.assertFalse(Statement(MockSQLSanitizer(), "DELETE FROM test").is_insert()) + + def test_is_select(self, MockSQLSanitizer): + self.assertTrue(Statement(MockSQLSanitizer(), "SELECT * FROM test").is_select()) + self.assertTrue(Statement(MockSQLSanitizer(), "select * FROM test").is_select()) + self.assertFalse(Statement(MockSQLSanitizer(), "DELETE FROM test").is_select()) + self.assertFalse(Statement(MockSQLSanitizer(), + "INSERT INTO test (id, val) VALUES (1, 'test')").is_select()) + + def test_is_update(self, MockSQLSanitizer): + self.assertTrue(Statement(MockSQLSanitizer(), "UPDATE test SET id = 2").is_update()) + self.assertTrue(Statement(MockSQLSanitizer(), "update test SET id = 2").is_update()) + self.assertFalse(Statement(MockSQLSanitizer(), "SELECT * FROM test").is_update()) + self.assertFalse(Statement(MockSQLSanitizer(), + "INSERT INTO test (id, val) VALUES (1, 'test')").is_update()) + + def test_is_transaction_start(self, MockSQLSanitizer): + self.assertTrue(Statement(MockSQLSanitizer(), "START TRANSACTION").is_transaction_start()) + self.assertTrue(Statement(MockSQLSanitizer(), "start TRANSACTION").is_transaction_start()) + self.assertTrue(Statement(MockSQLSanitizer(), "BEGIN").is_transaction_start()) + self.assertTrue(Statement(MockSQLSanitizer(), "begin").is_transaction_start()) + self.assertFalse(Statement(MockSQLSanitizer(), "SELECT * FROM test").is_transaction_start()) + self.assertFalse(Statement(MockSQLSanitizer(), "DELETE FROM test").is_transaction_start()) + + def test_is_transaction_end(self, MockSQLSanitizer): + self.assertTrue(Statement(MockSQLSanitizer(), "COMMIT").is_transaction_end()) + self.assertTrue(Statement(MockSQLSanitizer(), "commit").is_transaction_end()) + self.assertTrue(Statement(MockSQLSanitizer(), "ROLLBACK").is_transaction_end()) + self.assertTrue(Statement(MockSQLSanitizer(), "rollback").is_transaction_end()) + self.assertFalse(Statement(MockSQLSanitizer(), "SELECT * FROM test").is_transaction_end()) + self.assertFalse(Statement(MockSQLSanitizer(), "DELETE FROM test").is_transaction_end()) From 08a66362335e7e832bb1d366711556a2aed9fe37 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Wed, 14 Apr 2021 08:05:45 -0400 Subject: [PATCH 077/159] use remove instead of rollback --- src/cs50/sql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 10bffd6..0e7ee8a 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -64,7 +64,7 @@ def _execute(self, statement): except sqlalchemy.exc.IntegrityError as exc: _logger.debug(termcolor.colored(str(statement), "yellow")) if self._autocommit: - self._session.execute("ROLLBACK") + self._session.remove() raise ValueError(exc.orig) from None except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as exc: self._session.remove() From 7f9b77c364c5893427db8e27b77427e7d5872f90 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Wed, 14 Apr 2021 08:06:24 -0400 Subject: [PATCH 078/159] rename _sanitized_statement --- src/cs50/sql.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 0e7ee8a..4bf82e5 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -20,12 +20,12 @@ def __init__(self, url, **engine_kwargs): self._session = Session(url, **engine_kwargs) dialect = self._session.get_bind().dialect self._is_postgres = dialect.name in {"postgres", "postgresql"} - self._sanitized_statement = statement_factory(dialect) + self._sanitize_statement = statement_factory(dialect) self._autocommit = False def execute(self, sql, *args, **kwargs): """Execute a SQL statement.""" - statement = self._sanitized_statement(sql, *args, **kwargs) + statement = self._sanitize_statement(sql, *args, **kwargs) if statement.is_transaction_start(): self._autocommit = False From 944934fa1048081293c385b7cd988d71cdc1496d Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Wed, 14 Apr 2021 08:21:01 -0400 Subject: [PATCH 079/159] abstract away catch_warnings --- src/cs50/_sql_util.py | 9 +++++++++ src/cs50/sql.py | 8 ++------ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/cs50/_sql_util.py b/src/cs50/_sql_util.py index 238d979..52538ad 100644 --- a/src/cs50/_sql_util.py +++ b/src/cs50/_sql_util.py @@ -1,6 +1,8 @@ """Utility functions used by sql.py""" +import contextlib import decimal +import warnings def fetch_select_result(result): @@ -17,3 +19,10 @@ def fetch_select_result(result): row[column] = bytes(row[column]) return rows + + +@contextlib.contextmanager +def raise_errors_for_warnings(): + with warnings.catch_warnings(): + warnings.simplefilter("error") + yield diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 4bf82e5..0486214 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -1,14 +1,13 @@ """Wraps SQLAlchemy""" import logging -import warnings import sqlalchemy import termcolor from ._session import Session from ._statement import statement_factory -from ._sql_util import fetch_select_result +from ._sql_util import fetch_select_result, raise_errors_for_warnings _logger = logging.getLogger("cs50") @@ -55,10 +54,7 @@ def execute(self, sql, *args, **kwargs): return ret def _execute(self, statement): - # Catch SQLAlchemy warnings - with warnings.catch_warnings(): - # Raise exceptions for warnings - warnings.simplefilter("error") + with raise_errors_for_warnings(): try: result = self._session.execute(statement) except sqlalchemy.exc.IntegrityError as exc: From 006657e0a50aeba1f80a76a7016dc1b187f7a4d2 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Wed, 14 Apr 2021 09:48:20 -0400 Subject: [PATCH 080/159] remove BEGIN and COMMIT --- src/cs50/sql.py | 46 ++++++++++++++++++---------------------------- 1 file changed, 18 insertions(+), 28 deletions(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 0486214..9aab897 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -20,24 +20,31 @@ def __init__(self, url, **engine_kwargs): dialect = self._session.get_bind().dialect self._is_postgres = dialect.name in {"postgres", "postgresql"} self._sanitize_statement = statement_factory(dialect) - self._autocommit = False + self._outside_transaction = True def execute(self, sql, *args, **kwargs): """Execute a SQL statement.""" statement = self._sanitize_statement(sql, *args, **kwargs) - if statement.is_transaction_start(): - self._autocommit = False - - if self._autocommit: - self._session.execute("BEGIN") + try: + with raise_errors_for_warnings(): + result = self._session.execute(statement) + except sqlalchemy.exc.IntegrityError as exc: + _logger.debug(termcolor.colored(str(statement), "yellow")) + if self._outside_transaction: + self._session.remove() + raise ValueError(exc.orig) from None + except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as exc: + self._session.remove() + _logger.debug(termcolor.colored(statement, "red")) + raise RuntimeError(exc.orig) from None - result = self._execute(statement) + if statement.is_transaction_start(): + self._outside_transaction = False - if self._autocommit: - self._session.execute("COMMIT") + _logger.debug(termcolor.colored(str(statement), "green")) if statement.is_transaction_end(): - self._autocommit = True + self._outside_transaction = True if statement.is_select(): ret = fetch_select_result(result) @@ -48,28 +55,11 @@ def execute(self, sql, *args, **kwargs): else: ret = True - if self._autocommit: + if self._outside_transaction: self._session.remove() return ret - def _execute(self, statement): - with raise_errors_for_warnings(): - try: - result = self._session.execute(statement) - except sqlalchemy.exc.IntegrityError as exc: - _logger.debug(termcolor.colored(str(statement), "yellow")) - if self._autocommit: - self._session.remove() - raise ValueError(exc.orig) from None - except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as exc: - self._session.remove() - _logger.debug(termcolor.colored(statement, "red")) - raise RuntimeError(exc.orig) from None - - _logger.debug(termcolor.colored(str(statement), "green")) - return result - def _last_row_id_or_none(self, result): if self._is_postgres: return self._get_last_val() From 36fb280771e1985d95bdd75e0410afcee1352035 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Wed, 14 Apr 2021 10:28:46 -0400 Subject: [PATCH 081/159] Revert "remove BEGIN and COMMIT" This reverts commit 006657e0a50aeba1f80a76a7016dc1b187f7a4d2. --- src/cs50/sql.py | 46 ++++++++++++++++++++++++++++------------------ 1 file changed, 28 insertions(+), 18 deletions(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 9aab897..0486214 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -20,31 +20,24 @@ def __init__(self, url, **engine_kwargs): dialect = self._session.get_bind().dialect self._is_postgres = dialect.name in {"postgres", "postgresql"} self._sanitize_statement = statement_factory(dialect) - self._outside_transaction = True + self._autocommit = False def execute(self, sql, *args, **kwargs): """Execute a SQL statement.""" statement = self._sanitize_statement(sql, *args, **kwargs) - try: - with raise_errors_for_warnings(): - result = self._session.execute(statement) - except sqlalchemy.exc.IntegrityError as exc: - _logger.debug(termcolor.colored(str(statement), "yellow")) - if self._outside_transaction: - self._session.remove() - raise ValueError(exc.orig) from None - except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as exc: - self._session.remove() - _logger.debug(termcolor.colored(statement, "red")) - raise RuntimeError(exc.orig) from None - if statement.is_transaction_start(): - self._outside_transaction = False + self._autocommit = False + + if self._autocommit: + self._session.execute("BEGIN") - _logger.debug(termcolor.colored(str(statement), "green")) + result = self._execute(statement) + + if self._autocommit: + self._session.execute("COMMIT") if statement.is_transaction_end(): - self._outside_transaction = True + self._autocommit = True if statement.is_select(): ret = fetch_select_result(result) @@ -55,11 +48,28 @@ def execute(self, sql, *args, **kwargs): else: ret = True - if self._outside_transaction: + if self._autocommit: self._session.remove() return ret + def _execute(self, statement): + with raise_errors_for_warnings(): + try: + result = self._session.execute(statement) + except sqlalchemy.exc.IntegrityError as exc: + _logger.debug(termcolor.colored(str(statement), "yellow")) + if self._autocommit: + self._session.remove() + raise ValueError(exc.orig) from None + except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as exc: + self._session.remove() + _logger.debug(termcolor.colored(statement, "red")) + raise RuntimeError(exc.orig) from None + + _logger.debug(termcolor.colored(str(statement), "green")) + return result + def _last_row_id_or_none(self, result): if self._is_postgres: return self._get_last_val() From a6668c093cbe005337aae64147582c1bb26e30ef Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Wed, 14 Apr 2021 12:33:50 -0400 Subject: [PATCH 082/159] rename methods --- src/cs50/_statement.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index cc4cdb8..2e286e9 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -44,7 +44,7 @@ def __init__(self, sql_sanitizer, sql, *args, **kwargs): self._paramstyle = self._get_paramstyle() self._placeholders = self._get_placeholders() - self._plugin_escaped_params() + self._substitute_markers_with_escaped_params() self._escape_verbatim_colons() def _get_escaped_args(self, args): @@ -100,15 +100,15 @@ def _get_placeholders(self): return placeholders - def _plugin_escaped_params(self): + def _substitute_markers_with_escaped_params(self): if self._paramstyle in {Paramstyle.FORMAT, Paramstyle.QMARK}: - self._plugin_format_or_qmark_params() + self._substitute_format_or_qmark_markers() elif self._paramstyle == Paramstyle.NUMERIC: - self._plugin_numeric_params() + self._substitue_numeric_markers() if self._paramstyle in {Paramstyle.NAMED, Paramstyle.PYFORMAT}: - self._plugin_named_or_pyformat_params() + self._substitute_named_or_pyformat_markers() - def _plugin_format_or_qmark_params(self): + def _substitute_format_or_qmark_markers(self): self._assert_valid_arg_count() for arg_index, token_index in enumerate(self._placeholders.keys()): self._tokens[token_index] = self._args[arg_index] @@ -122,7 +122,7 @@ def _assert_valid_arg_count(self): raise RuntimeError(f"more placeholders ({placeholders}) than values ({args})") - def _plugin_numeric_params(self): + def _substitue_numeric_markers(self): unused_arg_indices = set(range(len(self._args))) for token_index, num in self._placeholders.items(): if num >= len(self._args): @@ -137,7 +137,7 @@ def _plugin_numeric_params(self): raise RuntimeError( f"unused value{'' if len(unused_args) == 1 else 's'} ({unused_args})") - def _plugin_named_or_pyformat_params(self): + def _substitute_named_or_pyformat_markers(self): unused_params = set(self._kwargs.keys()) for token_index, param_name in self._placeholders.items(): if param_name not in self._kwargs: From 1440ae57a6437c2b1c0389958122c5524d43f5b9 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Wed, 14 Apr 2021 14:57:35 -0400 Subject: [PATCH 083/159] add docstrings --- src/cs50/__init__.py | 2 - src/cs50/_logger.py | 38 ++++++++-- src/cs50/_session.py | 14 ++-- src/cs50/_session_util.py | 10 ++- src/cs50/_sql_sanitizer.py | 25 +++++-- src/cs50/_sql_util.py | 14 +++- src/cs50/_statement.py | 72 ++++++++++++++++++- src/cs50/_statement_util.py | 19 ++++- src/cs50/sql.py | 137 +++++++++++++++++++++++++----------- 9 files changed, 261 insertions(+), 70 deletions(-) diff --git a/src/cs50/__init__.py b/src/cs50/__init__.py index fa07171..e5ec787 100644 --- a/src/cs50/__init__.py +++ b/src/cs50/__init__.py @@ -1,5 +1,3 @@ -"""Exposes API and sets up logging""" - from .cs50 import get_float, get_int, get_string from .sql import SQL from ._logger import _setup_logger diff --git a/src/cs50/_logger.py b/src/cs50/_logger.py index 1307e19..e7b03ca 100644 --- a/src/cs50/_logger.py +++ b/src/cs50/_logger.py @@ -1,4 +1,5 @@ -"""Sets up logging for cs50 library""" +"""Sets up logging for the library. +""" import logging import os.path @@ -9,6 +10,22 @@ import termcolor +def green(msg): + return _colored(msg, "green") + + +def red(msg): + return _colored(msg, "red") + + +def yellow(msg): + return _colored(msg, "yellow") + + +def _colored(msg, color): + return termcolor.colored(str(msg), color) + + def _setup_logger(): _configure_default_logger() _patch_root_handler_format_exception() @@ -17,11 +34,16 @@ def _setup_logger(): def _configure_default_logger(): - """Configure default handler and formatter to prevent flask and werkzeug from adding theirs""" + """Configures a default handler and formatter to prevent flask and werkzeug from adding theirs. + """ + logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.DEBUG) def _patch_root_handler_format_exception(): + """Patches formatException for the root handler to use ``_format_exception``. + """ + try: formatter = logging.root.handlers[0].formatter formatter.formatException = lambda exc_info: _format_exception(*exc_info) @@ -30,6 +52,10 @@ def _patch_root_handler_format_exception(): def _configure_cs50_logger(): + """Disables the cs50 logger by default. Disables logging propagation to prevent messages from + being logged more than once. Sets the logging handler and formatter. + """ + _logger = logging.getLogger("cs50") _logger.disabled = True _logger.setLevel(logging.DEBUG) @@ -52,9 +78,8 @@ def _patch_excepthook(): def _format_exception(type_, value, exc_tb): - """ - Format traceback, darkening entries from global site-packages directories - and user-specific site-packages directory. + """Formats traceback, darkening entries from global site-packages directories and user-specific + site-packages directory. https://stackoverflow.com/a/46071447/5156190 """ @@ -69,6 +94,5 @@ def _format_exception(type_, value, exc_tb): lines += line else: matches = re.search(r"^(\s*)(.*?)(\s*)$", line, re.DOTALL) - lines.append( - matches.group(1) + termcolor.colored(matches.group(2), "yellow") + matches.group(3)) + lines.append(matches.group(1) + yellow(matches.group(2)) + matches.group(3)) return "".join(lines).rstrip() diff --git a/src/cs50/_session.py b/src/cs50/_session.py index c1ea426..f28c30a 100644 --- a/src/cs50/_session.py +++ b/src/cs50/_session.py @@ -1,5 +1,3 @@ -"""Wraps a SQLAlchemy scoped session""" - import sqlalchemy import sqlalchemy.orm @@ -11,7 +9,8 @@ class Session: - """Wraps a SQLAlchemy scoped session""" + """Wraps a SQLAlchemy scoped session. + """ def __init__(self, url, **engine_kwargs): if is_sqlite_url(url): @@ -20,9 +19,16 @@ def __init__(self, url, **engine_kwargs): self._session = create_session(url, **engine_kwargs) def execute(self, statement): - """Converts statement to str and executes it""" + """Converts statement to str and executes it. + + :param statement: The SQL statement to be executed + """ + # pylint: disable=no-member return self._session.execute(sqlalchemy.text(str(statement))) def __getattr__(self, attr): + """Proxies any attributes to the underlying SQLAlchemy scoped session. + """ + return getattr(self._session, attr) diff --git a/src/cs50/_session_util.py b/src/cs50/_session_util.py index ed44eaa..01983b5 100644 --- a/src/cs50/_session_util.py +++ b/src/cs50/_session_util.py @@ -1,4 +1,5 @@ -"""Utility functions used by _session.py""" +"""Utility functions used by _session.py. +""" import os import sqlite3 @@ -49,8 +50,11 @@ def _create_scoped_session(engine): def _disable_auto_begin_commit(dbapi_connection): - # Disable underlying API's own emitting of BEGIN and COMMIT so we can ourselves - # https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl + """Disables the underlying API's own emitting of BEGIN and COMMIT so we can support manual + transactions. + https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl + """ + dbapi_connection.isolation_level = None diff --git a/src/cs50/_sql_sanitizer.py b/src/cs50/_sql_sanitizer.py index f4ff3e0..17fc5fa 100644 --- a/src/cs50/_sql_sanitizer.py +++ b/src/cs50/_sql_sanitizer.py @@ -1,5 +1,3 @@ -"""Escapes SQL values""" - import datetime import re @@ -8,15 +6,19 @@ class SQLSanitizer: - """Escapes SQL values""" + """Sanitizes SQL values. + """ def __init__(self, dialect): self._dialect = dialect def escape(self, value): - """ - Escapes value using engine's conversion function. + """Escapes value using engine's conversion function. https://docs.sqlalchemy.org/en/latest/core/type_api.html#sqlalchemy.types.TypeEngine.literal_processor + + :param value: The value to be sanitized + + :returns: The sanitized value """ # pylint: disable=too-many-return-statements if isinstance(value, (list, tuple)): @@ -71,13 +73,22 @@ def escape(self, value): raise RuntimeError(f"unsupported value: {value}") def escape_iterable(self, iterable): - """Escapes a collection of values (e.g., list, tuple)""" + """Escapes each value in iterable and joins all the escaped values with ", ", formatted for + SQL's ``IN`` operator. + + :param: An iterable of values to be escaped + + :returns: A comma-separated list of escaped values from ``iterable`` + :rtype: :class:`sqlparse.sql.TokenList` + """ + return sqlparse.sql.TokenList( sqlparse.parse(", ".join([str(self.escape(v)) for v in iterable]))) def escape_verbatim_colon(value): - """Escapes verbatim colon from a value so as it is not confused with a placeholder""" + """Escapes verbatim colon from a value so as it is not confused with a parameter marker. + """ # E.g., ':foo, ":foo, :foo will be replaced with # '\:foo, "\:foo, \:foo respectively diff --git a/src/cs50/_sql_util.py b/src/cs50/_sql_util.py index 52538ad..0b0c27b 100644 --- a/src/cs50/_sql_util.py +++ b/src/cs50/_sql_util.py @@ -1,11 +1,18 @@ -"""Utility functions used by sql.py""" +"""Utility functions used by sql.py. +""" import contextlib import decimal import warnings -def fetch_select_result(result): +def process_select_result(result): + """Converts a SQLAlchemy result to a ``list`` of ``dict`` objects, each of which represents a + row in the result set. + + :param result: A SQLAlchemy result + :type result: :class:`sqlalchemy.engine.Result` + """ rows = [dict(row) for row in result.fetchall()] for row in rows: for column in row: @@ -23,6 +30,9 @@ def fetch_select_result(result): @contextlib.contextmanager def raise_errors_for_warnings(): + """Catches warnings and raises errors instead. + """ + with warnings.catch_warnings(): warnings.simplefilter("error") yield diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index 2e286e9..79e77d8 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -1,5 +1,3 @@ -"""Parses a SQL statement and replaces the placeholders with the corresponding parameters""" - import collections from ._sql_sanitizer import SQLSanitizer, escape_verbatim_colon @@ -17,6 +15,13 @@ def statement_factory(dialect): + """Creates a sanitizer for ``dialect`` and injects it into ``Statement``, exposing a simpler + interface for ``Statement``. + + :param dialect: a SQLAlchemy dialect + :type dialect: :class:`sqlalchemy.engine.Dialect` + """ + sql_sanitizer = SQLSanitizer(dialect) def statement(sql, *args, **kwargs): @@ -26,9 +31,23 @@ def statement(sql, *args, **kwargs): class Statement: - """Parses a SQL statement and replaces the placeholders with the corresponding parameters""" + """Parses a SQL statement and substitutes any parameter markers with their corresponding + placeholders. + """ def __init__(self, sql_sanitizer, sql, *args, **kwargs): + """ + :param sql_sanitizer: The SQL sanitizer used to sanitize the parameters + :type sql_sanitizer: :class:`_sql_sanitizer.SQLSanitizer` + + :param sql: The SQL statement + :type sql: str + + :param *args: Zero or more positional parameters to be substituted for the parameter markers + + :param *kwargs: Zero or more keyword arguments to be substituted for the parameter markers + """ + if len(args) > 0 and len(kwargs) > 0: raise RuntimeError("cannot pass both positional and named parameters") @@ -54,9 +73,18 @@ def _get_escaped_kwargs(self, kwargs): return {k: self._sql_sanitizer.escape(v) for k, v in kwargs.items()} def _tokenize(self): + """ + :returns: A flattened list of SQLParse tokens that represent the SQL statement + """ + return list(self._statement.flatten()) def _get_operation_keyword(self): + """ + :returns: The operation keyword of the SQL statement (e.g., ``SELECT``, ``DELETE``, etc) + :rtype: str + """ + for token in self._statement: if is_operation_token(token.ttype): token_value = token.value.upper() @@ -69,6 +97,11 @@ def _get_operation_keyword(self): return operation_keyword def _get_paramstyle(self): + """ + :returns: The paramstyle used in the SQL statement (if any) + :rtype: :class:_statement_util.Paramstyle`` + """ + paramstyle = None for token in self._tokens: if is_placeholder(token.ttype): @@ -80,6 +113,11 @@ def _get_paramstyle(self): return paramstyle def _default_paramstyle(self): + """ + :returns: If positional args were passed, returns ``Paramstyle.QMARK``; if keyword arguments + were passed, returns ``Paramstyle.NAMED``; otherwise, returns ``None`` + """ + paramstyle = None if self._args: paramstyle = Paramstyle.QMARK @@ -89,6 +127,12 @@ def _default_paramstyle(self): return paramstyle def _get_placeholders(self): + """ + :returns: A dict that maps the index of each parameter marker in the tokens list to the name + of that parameter marker (if applicable) or ``None`` + :rtype: dict + """ + placeholders = collections.OrderedDict() for index, token in enumerate(self._tokens): if is_placeholder(token.ttype): @@ -109,11 +153,18 @@ def _substitute_markers_with_escaped_params(self): self._substitute_named_or_pyformat_markers() def _substitute_format_or_qmark_markers(self): + """Substitutes format or qmark parameter markers with their corresponding parameters. + """ + self._assert_valid_arg_count() for arg_index, token_index in enumerate(self._placeholders.keys()): self._tokens[token_index] = self._args[arg_index] def _assert_valid_arg_count(self): + """Raises a ``RuntimeError`` if the number of arguments does not match the number of + placeholders. + """ + if len(self._placeholders) != len(self._args): placeholders = get_human_readable_list(self._placeholders.values()) args = get_human_readable_list(self._args) @@ -123,6 +174,10 @@ def _assert_valid_arg_count(self): raise RuntimeError(f"more placeholders ({placeholders}) than values ({args})") def _substitue_numeric_markers(self): + """Substitutes numeric parameter markers with their corresponding parameters. Raises a + ``RuntimeError`` if any parameters are missing or unused. + """ + unused_arg_indices = set(range(len(self._args))) for token_index, num in self._placeholders.items(): if num >= len(self._args): @@ -138,6 +193,10 @@ def _substitue_numeric_markers(self): f"unused value{'' if len(unused_args) == 1 else 's'} ({unused_args})") def _substitute_named_or_pyformat_markers(self): + """Substitutes named or pyformat parameter markers with their corresponding parameters. + Raises a ``RuntimeError`` if any parameters are missing or unused. + """ + unused_params = set(self._kwargs.keys()) for token_index, param_name in self._placeholders.items(): if param_name not in self._kwargs: @@ -152,6 +211,10 @@ def _substitute_named_or_pyformat_markers(self): f"unused value{'' if len(unused_params) == 1 else 's'} ({joined_unused_params})") def _escape_verbatim_colons(self): + """Escapes verbatim colons from string literal and identifier tokens so they aren't treated + as parameter markers. + """ + for token in self._tokens: if is_string_literal(token.ttype) or is_identifier(token.ttype): token.value = escape_verbatim_colon(token.value) @@ -175,4 +238,7 @@ def is_update(self): return self._operation_keyword == "UPDATE" def __str__(self): + """Joins the statement tokens into a string. + """ + return "".join([str(token) for token in self._tokens]) diff --git a/src/cs50/_statement_util.py b/src/cs50/_statement_util.py index 4ef092a..34ca6ff 100644 --- a/src/cs50/_statement_util.py +++ b/src/cs50/_statement_util.py @@ -1,4 +1,5 @@ -"""Utility functions used by _statement.py""" +"""Utility functions used by _statement.py. +""" import enum import re @@ -19,6 +20,9 @@ class Paramstyle(enum.Enum): + """Represents the supported parameter marker styles. + """ + FORMAT = enum.auto() NAMED = enum.auto() NUMERIC = enum.auto() @@ -27,6 +31,15 @@ class Paramstyle(enum.Enum): def format_and_parse(sql): + """Formats and parses a SQL statement. Raises ``RuntimeError`` if ``sql`` represents more than + one statement. + + :param sql: The SQL statement to be formatted and parsed + :type sql: str + + :returns: A list of unflattened SQLParse tokens that represent the parsed statement + """ + formatted_statements = sqlparse.format(sql, strip_comments=True).strip() parsed_statements = sqlparse.parse(formatted_statements) statement_count = len(parsed_statements) @@ -43,6 +56,10 @@ def is_placeholder(ttype): def parse_placeholder(value): + """ + :returns: A tuple of the paramstyle and the name of the parameter marker (if any) or ``None`` + :rtype: tuple + """ if value == "?": return Paramstyle.QMARK, None diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 0486214..974137c 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -1,46 +1,60 @@ -"""Wraps SQLAlchemy""" - import logging import sqlalchemy -import termcolor +from ._logger import green, red, yellow from ._session import Session from ._statement import statement_factory -from ._sql_util import fetch_select_result, raise_errors_for_warnings +from ._sql_util import process_select_result, raise_errors_for_warnings + _logger = logging.getLogger("cs50") class SQL: - """Wraps SQLAlchemy""" + """An API for executing SQL Statements. + """ + + def __init__(self, url): + """ + :param url: The database URL + """ - def __init__(self, url, **engine_kwargs): - self._session = Session(url, **engine_kwargs) - dialect = self._session.get_bind().dialect + self._session = Session(url) + dialect = self._get_dialect() self._is_postgres = dialect.name in {"postgres", "postgresql"} - self._sanitize_statement = statement_factory(dialect) + self._substitute_markers_with_params = statement_factory(dialect) self._autocommit = False + def _get_dialect(self): + return self._session.get_bind().dialect + def execute(self, sql, *args, **kwargs): - """Execute a SQL statement.""" - statement = self._sanitize_statement(sql, *args, **kwargs) - if statement.is_transaction_start(): - self._autocommit = False + """Executes a SQL statement. - if self._autocommit: - self._session.execute("BEGIN") + :param sql: a SQL statement, possibly with parameters markers + :type sql: str + :param *args: zero or more positional arguments to substitute the parameter markers with + :param **kwargs: zero or more keyword arguments to substitute the parameter markers with - result = self._execute(statement) + :returns: For ``SELECT``, a :py:class:`list` of :py:class:`dict` objects, each of which + represents a row in the result set; for ``INSERT``, the primary key of a newly inserted row + (or ``None`` if none); for ``UPDATE``, the number of rows updated; for ``DELETE``, the + number of rows deleted; for other statements, ``True``; on integrity errors, a + :py:class:`ValueError` is raised, on other errors, a :py:class:`RuntimeError` is raised - if self._autocommit: - self._session.execute("COMMIT") + """ - if statement.is_transaction_end(): - self._autocommit = True + statement = self._substitute_markers_with_params(sql, *args, **kwargs) + if statement.is_transaction_start(): + self._disable_autocommit() + + self._begin_transaction_in_autocommit_mode() + result = self._execute(statement) + self._commit_transaction_in_autocommit_mode() if statement.is_select(): - ret = fetch_select_result(result) + ret = process_select_result(result) elif statement.is_insert(): ret = self._last_row_id_or_none(result) elif statement.is_delete() or statement.is_update(): @@ -48,43 +62,84 @@ def execute(self, sql, *args, **kwargs): else: ret = True - if self._autocommit: - self._session.remove() + if statement.is_transaction_end(): + self._enable_autocommit() + self._shutdown_session_in_autocommit_mode() return ret + def _disable_autocommit(self): + self._autocommit = False + + def _begin_transaction_in_autocommit_mode(self): + if self._autocommit: + self._session.execute("BEGIN") + def _execute(self, statement): - with raise_errors_for_warnings(): - try: + """ + :param statement: a SQL statement represented as a ``str`` or a + :class:`_statement.Statement` + + :rtype: :class:`sqlalchemy.engine.Result` + """ + try: + with raise_errors_for_warnings(): result = self._session.execute(statement) - except sqlalchemy.exc.IntegrityError as exc: - _logger.debug(termcolor.colored(str(statement), "yellow")) - if self._autocommit: - self._session.remove() - raise ValueError(exc.orig) from None - except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as exc: - self._session.remove() - _logger.debug(termcolor.colored(statement, "red")) - raise RuntimeError(exc.orig) from None - - _logger.debug(termcolor.colored(str(statement), "green")) - return result + # E.g., failed constraint + except sqlalchemy.exc.IntegrityError as exc: + _logger.debug(yellow(statement)) + self._shutdown_session_in_autocommit_mode() + raise ValueError(exc.orig) from None + # E.g., connection error or syntax error + except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as exc: + self._shutdown_session() + _logger.debug(red(statement)) + raise RuntimeError(exc.orig) from None + + _logger.debug(green(statement)) + return result + + def _shutdown_session_in_autocommit_mode(self): + if self._autocommit: + self._shutdown_session() + + def _shutdown_session(self): + self._session.remove() + + def _commit_transaction_in_autocommit_mode(self): + if self._autocommit: + self._session.execute("COMMIT") + + def _enable_autocommit(self): + self._autocommit = True def _last_row_id_or_none(self, result): + """ + :param result: A SQLAlchemy result object + :type result: :class:`sqlalchemy.engine.Result` + + :returns: The ID of the last inserted row or ``None`` + """ + if self._is_postgres: - return self._get_last_val() + return self._postgres_lastval() return result.lastrowid if result.rowcount == 1 else None - def _get_last_val(self): + def _postgres_lastval(self): try: return self._session.execute("SELECT LASTVAL()").first()[0] except sqlalchemy.exc.OperationalError: # If lastval is not yet defined in this session return None def init_app(self, app): - """Registers a teardown_appcontext listener to remove session and enables logging""" + """Enables logging and registers a ``teardown_appcontext`` listener to remove the session. + + :param app: a Flask application instance + :type app: :class:`flask.Flask` + """ + @app.teardown_appcontext def _(_): - self._session.remove() + self._shutdown_session() logging.getLogger("cs50").disabled = False From 9e9cb0b34ebd44394f7406a97a2113adae015c4c Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Thu, 15 Apr 2021 19:06:14 -0400 Subject: [PATCH 084/159] update cs50 docstrings --- src/cs50/cs50.py | 37 +++++++++++++++++++++---------------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/src/cs50/cs50.py b/src/cs50/cs50.py index 30d3515..11fa20a 100644 --- a/src/cs50/cs50.py +++ b/src/cs50/cs50.py @@ -5,11 +5,14 @@ def get_float(prompt): + """Reads a line of text from standard input and returns the equivalent float as precisely as + possible; if text does not represent a float, user is prompted to retry. If line can't be read, + returns None. + + :type prompt: str + """ - Read a line of text from standard input and return the equivalent float - as precisely as possible; if text does not represent a double, user is - prompted to retry. If line can't be read, return None. - """ + while True: try: return _get_float(prompt) @@ -29,11 +32,12 @@ def _get_float(prompt): def get_int(prompt): + """Reads a line of text from standard input and return the equivalent int; if text does not + represent an int, user is prompted to retry. If line can't be read, returns None. + + :type prompt: str """ - Read a line of text from standard input and return the equivalent int; - if text does not represent an int, user is prompted to retry. If line - can't be read, return None. - """ + while True: try: return _get_int(prompt) @@ -53,12 +57,13 @@ def _get_int(prompt): def get_string(prompt): + """Reads a line of text from standard input and returns it as a string, sans trailing line + ending. Supports CR (\r), LF (\n), and CRLF (\r\n) as line endings. If user inputs only a line + ending, returns "", not None. Returns None upon error or no input whatsoever (i.e., just EOF). + + :type prompt: str """ - Read a line of text from standard input and return it as a string, - sans trailing line ending. Supports CR (\r), LF (\n), and CRLF (\r\n) - as line endings. If user inputs only a line ending, returns "", not None. - Returns None upon error or no input whatsoever (i.e., just EOF). - """ + if not isinstance(prompt, str): raise TypeError("prompt must be of type str") @@ -73,8 +78,7 @@ def _get_input(prompt): class _flushfile(): - """ - Disable buffering for standard output and standard error. + """ Disable buffering for standard output and standard error. http://stackoverflow.com/a/231216 """ @@ -91,7 +95,8 @@ def write(self, data): def disable_output_buffering(): - """Disables output buffering to prevent prompts from being buffered""" + """Disables output buffering to prevent prompts from being buffered. + """ sys.stderr = _flushfile(sys.stderr) sys.stdout = _flushfile(sys.stdout) From 789bb4015ae38c1e58942f92dd54a6ba5ce65200 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Thu, 15 Apr 2021 21:15:41 -0400 Subject: [PATCH 085/159] use assertAlmostEqual --- tests/test_cs50.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_cs50.py b/tests/test_cs50.py index dd0f14b..9a0faca 100644 --- a/tests/test_cs50.py +++ b/tests/test_cs50.py @@ -1,4 +1,3 @@ -import math import sys import unittest @@ -95,7 +94,7 @@ def test_get_float_valid_input(self): def assert_equal(return_value, expected_value): with patch("cs50.cs50.get_string", return_value=return_value) as mock_get_string: f = _get_float("Answer: ") - self.assertTrue(math.isclose(f, expected_value)) + self.assertAlmostEqual(f, expected_value) mock_get_string.assert_called_with("Answer: ") values = [ From 05a4d9df5f59415592dcad413ed2c91a63cf7bfb Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Fri, 16 Apr 2021 02:15:05 -0400 Subject: [PATCH 086/159] enable autocommit by default --- src/cs50/sql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 974137c..c38ce25 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -24,7 +24,7 @@ def __init__(self, url): dialect = self._get_dialect() self._is_postgres = dialect.name in {"postgres", "postgresql"} self._substitute_markers_with_params = statement_factory(dialect) - self._autocommit = False + self._autocommit = True def _get_dialect(self): return self._session.get_bind().dialect From 7bf7e1688c5162e7e1e1004841639d62fa38c4c7 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Fri, 16 Apr 2021 09:49:31 -0400 Subject: [PATCH 087/159] avoid using sessions --- setup.py | 2 +- src/cs50/_engine.py | 66 +++++++++++ src/cs50/_engine_util.py | 43 +++++++ src/cs50/_session.py | 34 ------ src/cs50/_session_util.py | 68 ----------- src/cs50/_sql_util.py | 13 +++ src/cs50/_statement.py | 2 +- src/cs50/sql.py | 132 ++++++++------------- tests/test_statement.py | 234 -------------------------------------- 9 files changed, 174 insertions(+), 420 deletions(-) create mode 100644 src/cs50/_engine.py create mode 100644 src/cs50/_engine_util.py delete mode 100644 src/cs50/_session.py delete mode 100644 src/cs50/_session_util.py delete mode 100644 tests/test_statement.py diff --git a/setup.py b/setup.py index a5b8fb7..de271f8 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ "Topic :: Software Development :: Libraries :: Python Modules" ], description="CS50 library for Python", - install_requires=["Flask>=1.0", "SQLAlchemy", "sqlparse", "termcolor"], + install_requires=["Flask>=1.0", "SQLAlchemy<2", "sqlparse", "termcolor"], keywords="cs50", name="cs50", package_dir={"": "src"}, diff --git a/src/cs50/_engine.py b/src/cs50/_engine.py new file mode 100644 index 0000000..d74992c --- /dev/null +++ b/src/cs50/_engine.py @@ -0,0 +1,66 @@ +import threading + +from ._engine_util import create_engine + + +thread_local_data = threading.local() + + +class Engine: + """Wraps a SQLAlchemy engine. + """ + + def __init__(self, url): + self._engine = create_engine(url) + + def get_transaction_connection(self): + """ + :returns: A new connection with autocommit disabled (to be used for transactions). + """ + + _thread_local_connections()[self] = self._engine.connect().execution_options( + autocommit=False) + return self.get_existing_transaction_connection() + + def get_connection(self): + """ + :returns: A new connection with autocommit enabled + """ + + return self._engine.connect().execution_options(autocommit=True) + + def get_existing_transaction_connection(self): + """ + :returns: The transaction connection bound to this Engine instance, if one exists, or None. + """ + + return _thread_local_connections().get(self) + + def close_transaction_connection(self): + """Closes the transaction connection bound to this Engine instance, if one exists and + removes it. + """ + + connection = self.get_existing_transaction_connection() + if connection: + connection.close() + del _thread_local_connections()[self] + + def is_postgres(self): + return self._engine.dialect.name in {"postgres", "postgresql"} + + def __getattr__(self, attr): + return getattr(self._engine, attr) + +def _thread_local_connections(): + """ + :returns: A thread local dict to keep track of transaction connection. If one does not exist, + creates one. + """ + + try: + connections = thread_local_data.connections + except AttributeError: + connections = thread_local_data.connections = {} + + return connections diff --git a/src/cs50/_engine_util.py b/src/cs50/_engine_util.py new file mode 100644 index 0000000..c55b8f2 --- /dev/null +++ b/src/cs50/_engine_util.py @@ -0,0 +1,43 @@ +"""Utility functions used by _session.py. +""" + +import os +import sqlite3 + +import sqlalchemy + +sqlite_url_prefix = "sqlite:///" + + +def create_engine(url, **kwargs): + """Creates a new SQLAlchemy engine. If ``url`` is a URL for a SQLite database, makes sure that + the SQLite file exits and enables foreign key constraints. + """ + + try: + engine = sqlalchemy.create_engine(url, **kwargs) + except sqlalchemy.exc.ArgumentError: + raise RuntimeError(f"invalid URL: {url}") from None + + if _is_sqlite_url(url): + _assert_sqlite_file_exists(url) + sqlalchemy.event.listen(engine, "connect", _enable_sqlite_foreign_key_constraints) + + return engine + +def _is_sqlite_url(url): + return url.startswith(sqlite_url_prefix) + + +def _assert_sqlite_file_exists(url): + path = url[len(sqlite_url_prefix):] + if not os.path.exists(path): + raise RuntimeError(f"does not exist: {path}") + if not os.path.isfile(path): + raise RuntimeError(f"not a file: {path}") + + +def _enable_sqlite_foreign_key_constraints(dbapi_connection, _): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() diff --git a/src/cs50/_session.py b/src/cs50/_session.py deleted file mode 100644 index f28c30a..0000000 --- a/src/cs50/_session.py +++ /dev/null @@ -1,34 +0,0 @@ -import sqlalchemy -import sqlalchemy.orm - -from ._session_util import ( - assert_sqlite_file_exists, - create_session, - is_sqlite_url, -) - - -class Session: - """Wraps a SQLAlchemy scoped session. - """ - - def __init__(self, url, **engine_kwargs): - if is_sqlite_url(url): - assert_sqlite_file_exists(url) - - self._session = create_session(url, **engine_kwargs) - - def execute(self, statement): - """Converts statement to str and executes it. - - :param statement: The SQL statement to be executed - """ - - # pylint: disable=no-member - return self._session.execute(sqlalchemy.text(str(statement))) - - def __getattr__(self, attr): - """Proxies any attributes to the underlying SQLAlchemy scoped session. - """ - - return getattr(self._session, attr) diff --git a/src/cs50/_session_util.py b/src/cs50/_session_util.py deleted file mode 100644 index 01983b5..0000000 --- a/src/cs50/_session_util.py +++ /dev/null @@ -1,68 +0,0 @@ -"""Utility functions used by _session.py. -""" - -import os -import sqlite3 - -import sqlalchemy - - -def is_sqlite_url(url): - return url.startswith("sqlite:///") - - -def assert_sqlite_file_exists(url): - path = url[len("sqlite:///"):] - if not os.path.exists(path): - raise RuntimeError(f"does not exist: {path}") - if not os.path.isfile(path): - raise RuntimeError(f"not a file: {path}") - - -def create_session(url, **engine_kwargs): - engine = _create_engine(url, **engine_kwargs) - _setup_on_connect(engine) - return _create_scoped_session(engine) - - -def _create_engine(url, **kwargs): - try: - engine = sqlalchemy.create_engine(url, **kwargs) - except sqlalchemy.exc.ArgumentError: - raise RuntimeError(f"invalid URL: {url}") from None - - engine.execution_options(autocommit=False) - return engine - - -def _setup_on_connect(engine): - def connect(dbapi_connection, _): - _disable_auto_begin_commit(dbapi_connection) - if _is_sqlite_connection(dbapi_connection): - _enable_sqlite_foreign_key_constraints(dbapi_connection) - - sqlalchemy.event.listen(engine, "connect", connect) - - -def _create_scoped_session(engine): - session_factory = sqlalchemy.orm.sessionmaker(bind=engine) - return sqlalchemy.orm.scoping.scoped_session(session_factory) - - -def _disable_auto_begin_commit(dbapi_connection): - """Disables the underlying API's own emitting of BEGIN and COMMIT so we can support manual - transactions. - https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl - """ - - dbapi_connection.isolation_level = None - - -def _is_sqlite_connection(dbapi_connection): - return isinstance(dbapi_connection, sqlite3.Connection) - - -def _enable_sqlite_foreign_key_constraints(dbapi_connection): - cursor = dbapi_connection.cursor() - cursor.execute("PRAGMA foreign_keys=ON") - cursor.close() diff --git a/src/cs50/_sql_util.py b/src/cs50/_sql_util.py index 0b0c27b..2dbfecf 100644 --- a/src/cs50/_sql_util.py +++ b/src/cs50/_sql_util.py @@ -5,6 +5,8 @@ import decimal import warnings +import sqlalchemy + def process_select_result(result): """Converts a SQLAlchemy result to a ``list`` of ``dict`` objects, each of which represents a @@ -36,3 +38,14 @@ def raise_errors_for_warnings(): with warnings.catch_warnings(): warnings.simplefilter("error") yield + + +def postgres_lastval(connection): + """ + :returns: The ID of the last inserted row, if defined in this session, or None + """ + + try: + return connection.execute("SELECT LASTVAL()").first()[0] + except sqlalchemy.exc.OperationalError: + return None diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index 79e77d8..2de956a 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -64,7 +64,7 @@ def __init__(self, sql_sanitizer, sql, *args, **kwargs): self._paramstyle = self._get_paramstyle() self._placeholders = self._get_placeholders() self._substitute_markers_with_escaped_params() - self._escape_verbatim_colons() + # self._escape_verbatim_colons() def _get_escaped_args(self, args): return [self._sql_sanitizer.escape(arg) for arg in args] diff --git a/src/cs50/sql.py b/src/cs50/sql.py index c38ce25..d32c319 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -3,9 +3,9 @@ import sqlalchemy from ._logger import green, red, yellow -from ._session import Session +from ._engine import Engine from ._statement import statement_factory -from ._sql_util import process_select_result, raise_errors_for_warnings +from ._sql_util import postgres_lastval, process_select_result, raise_errors_for_warnings _logger = logging.getLogger("cs50") @@ -20,14 +20,8 @@ def __init__(self, url): :param url: The database URL """ - self._session = Session(url) - dialect = self._get_dialect() - self._is_postgres = dialect.name in {"postgres", "postgresql"} - self._substitute_markers_with_params = statement_factory(dialect) - self._autocommit = True - - def _get_dialect(self): - return self._session.get_bind().dialect + self._engine = Engine(url) + self._substitute_markers_with_params = statement_factory(self._engine.dialect) def execute(self, sql, *args, **kwargs): """Executes a SQL statement. @@ -46,73 +40,52 @@ def execute(self, sql, *args, **kwargs): """ statement = self._substitute_markers_with_params(sql, *args, **kwargs) - if statement.is_transaction_start(): - self._disable_autocommit() - - self._begin_transaction_in_autocommit_mode() - result = self._execute(statement) - self._commit_transaction_in_autocommit_mode() - - if statement.is_select(): - ret = process_select_result(result) - elif statement.is_insert(): - ret = self._last_row_id_or_none(result) - elif statement.is_delete() or statement.is_update(): - ret = result.rowcount + connection = self._engine.get_existing_transaction_connection() + if connection is None: + if statement.is_transaction_start(): + connection = self._engine.get_transaction_connection() + else: + connection = self._engine.get_connection() + elif statement.is_transaction_start(): + raise RuntimeError("nested transactions are not supported") + + return self._execute(statement, connection) + + def _execute(self, statement, connection): + with raise_errors_for_warnings(): + try: + result = connection.execute(str(statement)) + # E.g., failed constraint + except sqlalchemy.exc.IntegrityError as exc: + _logger.debug(yellow(statement)) + if self._engine.get_existing_transaction_connection() is None: + connection.close() + raise ValueError(exc.orig) from None + # E.g., connection error or syntax error + except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as exc: + connection.close() + _logger.debug(red(statement)) + raise RuntimeError(exc.orig) from None + + _logger.debug(green(statement)) + + if statement.is_select(): + ret = process_select_result(result) + elif statement.is_insert(): + ret = self._last_row_id_or_none(result) + elif statement.is_delete() or statement.is_update(): + ret = result.rowcount + else: + ret = True + + if self._engine.get_existing_transaction_connection(): + if statement.is_transaction_end(): + self._engine.close_transaction_connection() else: - ret = True - - if statement.is_transaction_end(): - self._enable_autocommit() + connection.close() - self._shutdown_session_in_autocommit_mode() return ret - def _disable_autocommit(self): - self._autocommit = False - - def _begin_transaction_in_autocommit_mode(self): - if self._autocommit: - self._session.execute("BEGIN") - - def _execute(self, statement): - """ - :param statement: a SQL statement represented as a ``str`` or a - :class:`_statement.Statement` - - :rtype: :class:`sqlalchemy.engine.Result` - """ - try: - with raise_errors_for_warnings(): - result = self._session.execute(statement) - # E.g., failed constraint - except sqlalchemy.exc.IntegrityError as exc: - _logger.debug(yellow(statement)) - self._shutdown_session_in_autocommit_mode() - raise ValueError(exc.orig) from None - # E.g., connection error or syntax error - except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as exc: - self._shutdown_session() - _logger.debug(red(statement)) - raise RuntimeError(exc.orig) from None - - _logger.debug(green(statement)) - return result - - def _shutdown_session_in_autocommit_mode(self): - if self._autocommit: - self._shutdown_session() - - def _shutdown_session(self): - self._session.remove() - - def _commit_transaction_in_autocommit_mode(self): - if self._autocommit: - self._session.execute("COMMIT") - - def _enable_autocommit(self): - self._autocommit = True - def _last_row_id_or_none(self, result): """ :param result: A SQLAlchemy result object @@ -121,16 +94,10 @@ def _last_row_id_or_none(self, result): :returns: The ID of the last inserted row or ``None`` """ - if self._is_postgres: - return self._postgres_lastval() + if self._engine.is_postgres(): + return postgres_lastval(result.connection) return result.lastrowid if result.rowcount == 1 else None - def _postgres_lastval(self): - try: - return self._session.execute("SELECT LASTVAL()").first()[0] - except sqlalchemy.exc.OperationalError: # If lastval is not yet defined in this session - return None - def init_app(self, app): """Enables logging and registers a ``teardown_appcontext`` listener to remove the session. @@ -140,6 +107,7 @@ def init_app(self, app): @app.teardown_appcontext def _(_): - self._shutdown_session() + self._engine.close_transaction_connection() + logging.getLogger("cs50").disabled = False diff --git a/tests/test_statement.py b/tests/test_statement.py deleted file mode 100644 index 91261cd..0000000 --- a/tests/test_statement.py +++ /dev/null @@ -1,234 +0,0 @@ -import unittest - -from unittest.mock import patch - -from cs50._statement import Statement -from cs50._sql_sanitizer import SQLSanitizer - - -@patch.object(SQLSanitizer, "escape", return_value="test") -class TestStatement(unittest.TestCase): - # TODO assert correct exception messages - def test_mutex_args_and_kwargs(self, MockSQLSanitizer): - with self.assertRaises(RuntimeError): - Statement(MockSQLSanitizer(), "SELECT * FROM test WHERE id = ? AND val = :val", 1, val="test") - - with self.assertRaises(RuntimeError): - Statement(MockSQLSanitizer(), "SELECT * FROM test", "test", 1, 2, foo="foo", bar="bar") - - @patch.object(Statement, "_escape_verbatim_colons") - def test_valid_qmark_count(self, MockSQLSanitizer, *_): - Statement(MockSQLSanitizer(), "SELECT * FROM test WHERE id = ?", 1) - Statement(MockSQLSanitizer(), "SELECT * FROM test WHERE id = ? and val = ?", 1, 'test') - Statement(MockSQLSanitizer(), - "INSERT INTO test (id, val, is_valid) VALUES (?, ?, ?)", 1, 'test', True) - - @patch.object(Statement, "_escape_verbatim_colons") - def test_invalid_qmark_count(self, MockSQLSanitizer, *_): - def assert_invalid_count(sql, *args): - with self.assertRaises(RuntimeError, msg=f"{sql} {str(args)}"): - Statement(MockSQLSanitizer(), sql, *args) - - statements = [ - ("SELECT * FROM test WHERE id = ?", ()), - ("SELECT * FROM test WHERE id = ?", (1, "test")), - ("SELECT * FROM test WHERE id = ? AND val = ?", (1,)), - ("SELECT * FROM test WHERE id = ? AND val = ?", ()), - ("SELECT * FROM test WHERE id = ? AND val = ?", (1, "test", True)), - ] - - for sql, args in statements: - assert_invalid_count(sql, *args) - - @patch.object(Statement, "_escape_verbatim_colons") - def test_valid_format_count(self, MockSQLSanitizer, *_): - Statement(MockSQLSanitizer(), "SELECT * FROM test WHERE id = %s", 1) - Statement(MockSQLSanitizer(), "SELECT * FROM test WHERE id = %s and val = %s", 1, 'test') - Statement(MockSQLSanitizer(), - "INSERT INTO test (id, val, is_valid) VALUES (%s, %s, %s)", 1, 'test', True) - - @patch.object(Statement, "_escape_verbatim_colons") - def test_invalid_format_count(self, MockSQLSanitizer, *_): - def assert_invalid_count(sql, *args): - with self.assertRaises(RuntimeError, msg=f"{sql} {str(args)}"): - Statement(MockSQLSanitizer(), sql, *args) - - statements = [ - ("SELECT * FROM test WHERE id = %s", ()), - ("SELECT * FROM test WHERE id = %s", (1, "test")), - ("SELECT * FROM test WHERE id = %s AND val = ?", (1,)), - ("SELECT * FROM test WHERE id = %s AND val = ?", ()), - ("SELECT * FROM test WHERE id = %s AND val = ?", (1, "test", True)), - ] - - for sql, args in statements: - assert_invalid_count(sql, *args) - - @patch.object(Statement, "_escape_verbatim_colons") - def test_missing_numeric(self, MockSQLSanitizer, *_): - def assert_missing_numeric(sql, *args): - with self.assertRaises(RuntimeError, msg=f"{sql} {str(args)}"): - Statement(MockSQLSanitizer(), sql, *args) - - statements = [ - ("SELECT * FROM test WHERE id = :1", ()), - ("SELECT * FROM test WHERE id = :1 AND val = :2", ()), - ("SELECT * FROM test WHERE id = :1 AND val = :2", (1,)), - ("SELECT * FROM test WHERE id = :1 AND val = :2 AND is_valid = :3", ()), - ("SELECT * FROM test WHERE id = :1 AND val = :2 AND is_valid = :3", (1,)), - ("SELECT * FROM test WHERE id = :1 AND val = :2 AND is_valid = :3", (1, "test")), - ] - - for sql, args in statements: - assert_missing_numeric(sql, *args) - - @patch.object(Statement, "_escape_verbatim_colons") - def test_unused_numeric(self, MockSQLSanitizer, *_): - def assert_unused_numeric(sql, *args): - with self.assertRaises(RuntimeError, msg=f"{sql} {str(args)}"): - Statement(MockSQLSanitizer(), sql, *args) - - statements = [ - ("SELECT * FROM test WHERE id = :1", (1, "test")), - ("SELECT * FROM test WHERE id = :1", (1, "test", True)), - ("SELECT * FROM test WHERE id = :1 AND val = :2", (1, "test", True)), - ] - - for sql, args in statements: - assert_unused_numeric(sql, *args) - - @patch.object(Statement, "_escape_verbatim_colons") - def test_missing_named(self, MockSQLSanitizer, *_): - def assert_missing_named(sql, **kwargs): - with self.assertRaises(RuntimeError, msg=f"{sql} {str(kwargs)}"): - Statement(MockSQLSanitizer(), sql, **kwargs) - - statements = [ - ("SELECT * FROM test WHERE id = :id", {}), - ("SELECT * FROM test WHERE id = :id AND val = :val", {}), - ("SELECT * FROM test WHERE id = :id AND val = :val", {"id": 1}), - ("SELECT * FROM test WHERE id = :id AND val = :val AND is_valid = :is_valid", {}), - ("SELECT * FROM test WHERE id = :id AND val = :val AND is_valid = :is_valid", - {"id": 1}), - ("SELECT * FROM test WHERE id = :id AND val = :val AND is_valid = :is_valid", - {"id": 1, "val": "test"}), - ] - - for sql, kwargs in statements: - assert_missing_named(sql, **kwargs) - - @patch.object(Statement, "_escape_verbatim_colons") - def test_unused_named(self, MockSQLSanitizer, *_): - def assert_unused_named(sql, **kwargs): - with self.assertRaises(RuntimeError, msg=f"{sql} {str(kwargs)}"): - Statement(MockSQLSanitizer(), sql, **kwargs) - - statements = [ - ("SELECT * FROM test WHERE id = :id", {"id": 1, "val": "test"}), - ("SELECT * FROM test WHERE id = :id", {"id": 1, "val": "test", "is_valid": True}), - ("SELECT * FROM test WHERE id = :id AND val = :val", - {"id": 1, "val": "test", "is_valid": True}), - ] - - for sql, kwargs in statements: - assert_unused_named(sql, **kwargs) - - @patch.object(Statement, "_escape_verbatim_colons") - def test_missing_pyformat(self, MockSQLSanitizer, *_): - def assert_missing_pyformat(sql, **kwargs): - with self.assertRaises(RuntimeError, msg=f"{sql} {str(kwargs)}"): - Statement(MockSQLSanitizer(), sql, **kwargs) - - statements = [ - ("SELECT * FROM test WHERE id = %(id)s", {}), - ("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s", {}), - ("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s", {"id": 1}), - ("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s AND is_valid = %(is_valid)s", {}), - ("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s AND is_valid = %(is_valid)s", - {"id": 1}), - ("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s AND is_valid = %(is_valid)s", - {"id": 1, "val": "test"}), - ] - - for sql, kwargs in statements: - assert_missing_pyformat(sql, **kwargs) - - @patch.object(Statement, "_escape_verbatim_colons") - def test_unused_pyformat(self, MockSQLSanitizer, *_): - def assert_unused_pyformat(sql, **kwargs): - with self.assertRaises(RuntimeError, msg=f"{sql} {str(kwargs)}"): - Statement(MockSQLSanitizer(), sql, **kwargs) - - statements = [ - ("SELECT * FROM test WHERE id = %(id)s", {"id": 1, "val": "test"}), - ("SELECT * FROM test WHERE id = %(id)s", {"id": 1, "val": "test", "is_valid": True}), - ("SELECT * FROM test WHERE id = %(id)s AND val = %(val)s", - {"id": 1, "val": "test", "is_valid": True}), - ] - - for sql, kwargs in statements: - assert_unused_pyformat(sql, **kwargs) - - def test_multiple_statements(self, MockSQLSanitizer): - def assert_raises_runtimeerror(sql): - with self.assertRaises(RuntimeError): - Statement(MockSQLSanitizer(), sql) - - statements = [ - "SELECT 1; SELECT 2;", - "SELECT 1; SELECT 2", - "SELECT 1; SELECT 2; SELECT 3", - "SELECT 1; SELECT 2; SELECT 3;", - "SELECT 1;SELECT 2", - "select 1; select 2", - "select 1;select 2", - "DELETE FROM test; SELECT * FROM test", - ] - - for sql in statements: - assert_raises_runtimeerror(sql) - - def test_is_delete(self, MockSQLSanitizer): - self.assertTrue(Statement(MockSQLSanitizer(), "DELETE FROM test").is_delete()) - self.assertTrue(Statement(MockSQLSanitizer(), "delete FROM test").is_delete()) - self.assertFalse(Statement(MockSQLSanitizer(), "SELECT * FROM test").is_delete()) - self.assertFalse(Statement(MockSQLSanitizer(), - "INSERT INTO test (id, val) VALUES (1, 'test')").is_delete()) - - def test_is_insert(self, MockSQLSanitizer): - self.assertTrue(Statement(MockSQLSanitizer(), - "INSERT INTO test (id, val) VALUES (1, 'test')").is_insert()) - self.assertTrue(Statement(MockSQLSanitizer(), - "insert INTO test (id, val) VALUES (1, 'test')").is_insert()) - self.assertFalse(Statement(MockSQLSanitizer(), "SELECT * FROM test").is_insert()) - self.assertFalse(Statement(MockSQLSanitizer(), "DELETE FROM test").is_insert()) - - def test_is_select(self, MockSQLSanitizer): - self.assertTrue(Statement(MockSQLSanitizer(), "SELECT * FROM test").is_select()) - self.assertTrue(Statement(MockSQLSanitizer(), "select * FROM test").is_select()) - self.assertFalse(Statement(MockSQLSanitizer(), "DELETE FROM test").is_select()) - self.assertFalse(Statement(MockSQLSanitizer(), - "INSERT INTO test (id, val) VALUES (1, 'test')").is_select()) - - def test_is_update(self, MockSQLSanitizer): - self.assertTrue(Statement(MockSQLSanitizer(), "UPDATE test SET id = 2").is_update()) - self.assertTrue(Statement(MockSQLSanitizer(), "update test SET id = 2").is_update()) - self.assertFalse(Statement(MockSQLSanitizer(), "SELECT * FROM test").is_update()) - self.assertFalse(Statement(MockSQLSanitizer(), - "INSERT INTO test (id, val) VALUES (1, 'test')").is_update()) - - def test_is_transaction_start(self, MockSQLSanitizer): - self.assertTrue(Statement(MockSQLSanitizer(), "START TRANSACTION").is_transaction_start()) - self.assertTrue(Statement(MockSQLSanitizer(), "start TRANSACTION").is_transaction_start()) - self.assertTrue(Statement(MockSQLSanitizer(), "BEGIN").is_transaction_start()) - self.assertTrue(Statement(MockSQLSanitizer(), "begin").is_transaction_start()) - self.assertFalse(Statement(MockSQLSanitizer(), "SELECT * FROM test").is_transaction_start()) - self.assertFalse(Statement(MockSQLSanitizer(), "DELETE FROM test").is_transaction_start()) - - def test_is_transaction_end(self, MockSQLSanitizer): - self.assertTrue(Statement(MockSQLSanitizer(), "COMMIT").is_transaction_end()) - self.assertTrue(Statement(MockSQLSanitizer(), "commit").is_transaction_end()) - self.assertTrue(Statement(MockSQLSanitizer(), "ROLLBACK").is_transaction_end()) - self.assertTrue(Statement(MockSQLSanitizer(), "rollback").is_transaction_end()) - self.assertFalse(Statement(MockSQLSanitizer(), "SELECT * FROM test").is_transaction_end()) - self.assertFalse(Statement(MockSQLSanitizer(), "DELETE FROM test").is_transaction_end()) From 0674b7c086946d0c87a912482934b4ffabaa1c04 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Fri, 16 Apr 2021 10:39:53 -0400 Subject: [PATCH 088/159] delete transaction connection on failure --- src/cs50/sql.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index d32c319..64d30e3 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -63,7 +63,10 @@ def _execute(self, statement, connection): raise ValueError(exc.orig) from None # E.g., connection error or syntax error except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as exc: - connection.close() + if self._engine.get_existing_transaction_connection(): + self._engine.close_transaction_connection() + else: + connection.close() _logger.debug(red(statement)) raise RuntimeError(exc.orig) from None From 8dfb7068961015da1aa9cb52fd23e63f3c8863aa Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Tue, 20 Jul 2021 11:29:10 -0400 Subject: [PATCH 089/159] Deploy with GitHub Actions --- .github/workflows/main.yml | 46 ++++++++++++++++++++++++++++++++++++++ .gitignore | 1 + .travis.yml | 30 ------------------------- setup.py | 2 +- tests/sql.py | 4 ++-- 5 files changed, 50 insertions(+), 33 deletions(-) create mode 100644 .github/workflows/main.yml delete mode 100644 .travis.yml diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml new file mode 100644 index 0000000..0eb0e2c --- /dev/null +++ b/.github/workflows/main.yml @@ -0,0 +1,46 @@ +on: push +jobs: + deploy: + runs-on: ubuntu-latest + services: + mysql: + image: mysql + env: + MYSQL_DATABASE: test + MYSQL_ALLOW_EMPTY_PASSWORD: yes + options: --health-cmd="mysqladmin ping" --health-interval=10s --health-timeout=5s --health-retries=3 + ports: + - 3306:3306 + postgres: + image: postgres + env: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: test + ports: + - 5432:5432 + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + with: + python-version: '3.6' + - name: Setup databases + run: | + python setup.py install + pip install mysqlclient + pip install psycopg2-binary + touch test.db test1.db + - name: Run tests + run: python tests/sql.py + - name: Install pypa/build + run: | + python -m pip install build --user + - name: Build a binary wheel and a source tarball + run: | + python -m build --sdist --wheel --outdir dist/ . + - name: Deploy to PyPI + if: ${{ github.ref == 'refs/heads/main' }} + uses: pypa/gh-action-pypi-publish@release/v1 + with: + user: __token__ + password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/.gitignore b/.gitignore index 65f1e1f..0a2a684 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ .* +!.github !.gitignore !.travis.yml *.db diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 0433f6a..0000000 --- a/.travis.yml +++ /dev/null @@ -1,30 +0,0 @@ -language: python -python: '3.6' -branches: - except: "/^v\\d/" -services: - - mysql - - postgresql -install: - - python setup.py install - - pip install mysqlclient - - pip install psycopg2-binary -before_script: - - mysql -e 'CREATE DATABASE IF NOT EXISTS test;' - - psql -c 'create database test;' -U postgres - - touch test.db test1.db -script: python tests/sql.py -deploy: - - provider: script - script: 'curl --fail --data "{ \"tag_name\": \"v$(python setup.py --version)\", - \"target_commitish\": \"$TRAVIS_COMMIT\", \"name\": \"v$(python setup.py --version)\" - }" --user bot50:$GITHUB_TOKEN https://api.github.com/repos/$TRAVIS_REPO_SLUG/releases' - on: - branch: main - - provider: pypi - user: "$PYPI_USERNAME" - password: "$PYPI_PASSWORD" - on: main -notifications: - slack: - secure: lJklhcBVjDT6KzUNa3RFHXdXSeH7ytuuGrkZ5ZcR72CXMoTf2pMJTzPwRLWOp6lCSdDC9Y8MWLrcg/e33dJga4Jlp9alOmWqeqesaFjfee4st8vAsgNbv8/RajPH1gD2bnkt8oIwUzdHItdb5AucKFYjbH2g0d8ndoqYqUeBLrnsT1AP5G/Vi9OHC9OWNpR0FKaZIJE0Wt52vkPMH3sV2mFeIskByPB+56U5y547mualKxn61IVR/dhYBEtZQJuSvnwKHPOn9Pkk7cCa+SSSeTJ4w5LboY8T17otaYNauXo46i1bKIoGiBcCcrJyQHHiPQmcq/YU540MC5Wzt9YXUycmJzRi347oyQeDee27wV3XJlWMXuuhbtJiKCFny7BTQ160VATlj/dbwIzN99Ra6/BtTumv/6LyTdKIuVjdAkcN8dtdDW1nlrQ29zuPNCcXXzJ7zX7kQaOCUV1c2OrsbiH/0fE9nknUORn97txqhlYVi0QMS7764wFo6kg0vpmFQRkkQySsJl+TmgcZ01AlsJc2EMMWVuaj9Af9JU4/4yalqDiXIh1fOYYUZnLfOfWS+MsnI+/oLfqJFyMbrsQQTIjs+kTzbiEdhd2R4EZgusU/xRFWokS2NAvahexrRhRQ6tpAI+LezPrkNOR3aHiykBf+P9BkUa0wPp6V2Ayc6q0= diff --git a/setup.py b/setup.py index 550e65d..c96cf23 100644 --- a/setup.py +++ b/setup.py @@ -16,5 +16,5 @@ package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="6.0.4" + version="6.0.5" ) diff --git a/tests/sql.py b/tests/sql.py index e4757c7..89853a7 100644 --- a/tests/sql.py +++ b/tests/sql.py @@ -150,7 +150,7 @@ def tearDownClass(self): class MySQLTests(SQLTests): @classmethod def setUpClass(self): - self.db = SQL("mysql://root@localhost/test") + self.db = SQL("mysql://root@127.0.0.1/test") def setUp(self): self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), bin BLOB, PRIMARY KEY (id))") @@ -160,7 +160,7 @@ def setUp(self): class PostgresTests(SQLTests): @classmethod def setUpClass(self): - self.db = SQL("postgresql://postgres@localhost/test") + self.db = SQL("postgresql://postgres:postgres@127.0.0.1/test") def setUp(self): self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)") From ff9e69f5ac7519498b6552a4e3180b5fadee4b78 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Mon, 26 Jul 2021 15:58:21 -0400 Subject: [PATCH 090/159] Deploy with GitHub Actions --- .github/workflows/main.yml | 46 ++++++++++++++++++++++++++++++++++++++ .travis.yml | 30 ------------------------- 2 files changed, 46 insertions(+), 30 deletions(-) create mode 100644 .github/workflows/main.yml delete mode 100644 .travis.yml diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml new file mode 100644 index 0000000..0eb0e2c --- /dev/null +++ b/.github/workflows/main.yml @@ -0,0 +1,46 @@ +on: push +jobs: + deploy: + runs-on: ubuntu-latest + services: + mysql: + image: mysql + env: + MYSQL_DATABASE: test + MYSQL_ALLOW_EMPTY_PASSWORD: yes + options: --health-cmd="mysqladmin ping" --health-interval=10s --health-timeout=5s --health-retries=3 + ports: + - 3306:3306 + postgres: + image: postgres + env: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: test + ports: + - 5432:5432 + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + with: + python-version: '3.6' + - name: Setup databases + run: | + python setup.py install + pip install mysqlclient + pip install psycopg2-binary + touch test.db test1.db + - name: Run tests + run: python tests/sql.py + - name: Install pypa/build + run: | + python -m pip install build --user + - name: Build a binary wheel and a source tarball + run: | + python -m build --sdist --wheel --outdir dist/ . + - name: Deploy to PyPI + if: ${{ github.ref == 'refs/heads/main' }} + uses: pypa/gh-action-pypi-publish@release/v1 + with: + user: __token__ + password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 0433f6a..0000000 --- a/.travis.yml +++ /dev/null @@ -1,30 +0,0 @@ -language: python -python: '3.6' -branches: - except: "/^v\\d/" -services: - - mysql - - postgresql -install: - - python setup.py install - - pip install mysqlclient - - pip install psycopg2-binary -before_script: - - mysql -e 'CREATE DATABASE IF NOT EXISTS test;' - - psql -c 'create database test;' -U postgres - - touch test.db test1.db -script: python tests/sql.py -deploy: - - provider: script - script: 'curl --fail --data "{ \"tag_name\": \"v$(python setup.py --version)\", - \"target_commitish\": \"$TRAVIS_COMMIT\", \"name\": \"v$(python setup.py --version)\" - }" --user bot50:$GITHUB_TOKEN https://api.github.com/repos/$TRAVIS_REPO_SLUG/releases' - on: - branch: main - - provider: pypi - user: "$PYPI_USERNAME" - password: "$PYPI_PASSWORD" - on: main -notifications: - slack: - secure: lJklhcBVjDT6KzUNa3RFHXdXSeH7ytuuGrkZ5ZcR72CXMoTf2pMJTzPwRLWOp6lCSdDC9Y8MWLrcg/e33dJga4Jlp9alOmWqeqesaFjfee4st8vAsgNbv8/RajPH1gD2bnkt8oIwUzdHItdb5AucKFYjbH2g0d8ndoqYqUeBLrnsT1AP5G/Vi9OHC9OWNpR0FKaZIJE0Wt52vkPMH3sV2mFeIskByPB+56U5y547mualKxn61IVR/dhYBEtZQJuSvnwKHPOn9Pkk7cCa+SSSeTJ4w5LboY8T17otaYNauXo46i1bKIoGiBcCcrJyQHHiPQmcq/YU540MC5Wzt9YXUycmJzRi347oyQeDee27wV3XJlWMXuuhbtJiKCFny7BTQ160VATlj/dbwIzN99Ra6/BtTumv/6LyTdKIuVjdAkcN8dtdDW1nlrQ29zuPNCcXXzJ7zX7kQaOCUV1c2OrsbiH/0fE9nknUORn97txqhlYVi0QMS7764wFo6kg0vpmFQRkkQySsJl+TmgcZ01AlsJc2EMMWVuaj9Af9JU4/4yalqDiXIh1fOYYUZnLfOfWS+MsnI+/oLfqJFyMbrsQQTIjs+kTzbiEdhd2R4EZgusU/xRFWokS2NAvahexrRhRQ6tpAI+LezPrkNOR3aHiykBf+P9BkUa0wPp6V2Ayc6q0= From 6aabf0c6ebbbb7ccfe59416a500223b02b710f3d Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Mon, 26 Jul 2021 16:01:03 -0400 Subject: [PATCH 091/159] Update sql.py --- tests/sql.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/sql.py b/tests/sql.py index e4757c7..89853a7 100644 --- a/tests/sql.py +++ b/tests/sql.py @@ -150,7 +150,7 @@ def tearDownClass(self): class MySQLTests(SQLTests): @classmethod def setUpClass(self): - self.db = SQL("mysql://root@localhost/test") + self.db = SQL("mysql://root@127.0.0.1/test") def setUp(self): self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), bin BLOB, PRIMARY KEY (id))") @@ -160,7 +160,7 @@ def setUp(self): class PostgresTests(SQLTests): @classmethod def setUpClass(self): - self.db = SQL("postgresql://postgres@localhost/test") + self.db = SQL("postgresql://postgres:postgres@127.0.0.1/test") def setUp(self): self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)") From c2d12ee7a3fb60c330036ab0a0bc3fcaf8a39fb5 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Tue, 27 Jul 2021 06:46:51 -0400 Subject: [PATCH 092/159] Support postgres:// scheme --- .github/workflows/main.yml | 12 ++++-------- .gitignore | 1 + src/cs50/_engine.py | 22 ++++++++++++++++++++++ tests/sql.py | 3 +++ 4 files changed, 30 insertions(+), 8 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 0eb0e2c..30d894b 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -26,18 +26,14 @@ jobs: python-version: '3.6' - name: Setup databases run: | - python setup.py install - pip install mysqlclient - pip install psycopg2-binary - touch test.db test1.db + pip install . + pip install mysqlclient psycopg2-binary - name: Run tests run: python tests/sql.py - name: Install pypa/build - run: | - python -m pip install build --user + run: python -m pip install build --user - name: Build a binary wheel and a source tarball - run: | - python -m build --sdist --wheel --outdir dist/ . + run: python -m build --sdist --wheel --outdir dist/ . - name: Deploy to PyPI if: ${{ github.ref == 'refs/heads/main' }} uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/.gitignore b/.gitignore index 0a2a684..0ce3062 100644 --- a/.gitignore +++ b/.gitignore @@ -5,5 +5,6 @@ *.db *.egg-info/ *.pyc +build/ dist/ test.db diff --git a/src/cs50/_engine.py b/src/cs50/_engine.py index d74992c..55489d1 100644 --- a/src/cs50/_engine.py +++ b/src/cs50/_engine.py @@ -1,4 +1,5 @@ import threading +import warnings from ._engine_util import create_engine @@ -11,6 +12,7 @@ class Engine: """ def __init__(self, url): + url = _replace_scheme_if_postgres(url) self._engine = create_engine(url) def get_transaction_connection(self): @@ -64,3 +66,23 @@ def _thread_local_connections(): connections = thread_local_data.connections = {} return connections + +def _replace_scheme_if_postgres(url): + """ + Replaces the postgres scheme with the postgresql scheme if possible since the postgres scheme + is deprecated. + + :returns: url with postgresql scheme if the scheme was postgres; otherwise returns url as is + """ + + if url.startswith("postgres://"): + with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "The postgres:// scheme is deprecated and will not be supported in the next major" + + " release of the library. Please use the postgresql:// scheme instead.", + DeprecationWarning + ) + url = f"postgresql{url[len('postgres'):]}" + + return url diff --git a/tests/sql.py b/tests/sql.py index 89853a7..c473a66 100644 --- a/tests/sql.py +++ b/tests/sql.py @@ -169,6 +169,9 @@ def setUp(self): def test_cte(self): self.assertEqual(self.db.execute("WITH foo AS ( SELECT 1 AS bar ) SELECT bar FROM foo"), [{"bar": 1}]) + def test_postgres_scheme(self): + db = SQL("postgres://postgres:postgres@127.0.0.1/test") + db.execute("SELECT 1") class SQLiteTests(SQLTests): From 0bcf6b35ea6477cd156db72ad414447ab328433c Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Tue, 27 Jul 2021 06:48:53 -0400 Subject: [PATCH 093/159] Remove extra newlines --- tests/sql.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/sql.py b/tests/sql.py index c473a66..14cc035 100644 --- a/tests/sql.py +++ b/tests/sql.py @@ -9,7 +9,6 @@ class SQLTests(unittest.TestCase): - def test_multiple_statements(self): self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO cs50(val) VALUES('baz'); INSERT INTO cs50(val) VALUES('qux')") @@ -146,7 +145,6 @@ def tearDownClass(self): if not str(e).startswith("(1051"): raise e - class MySQLTests(SQLTests): @classmethod def setUpClass(self): @@ -156,7 +154,6 @@ def setUp(self): self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), bin BLOB, PRIMARY KEY (id))") self.db.execute("DELETE FROM cs50") - class PostgresTests(SQLTests): @classmethod def setUpClass(self): @@ -174,7 +171,6 @@ def test_postgres_scheme(self): db.execute("SELECT 1") class SQLiteTests(SQLTests): - @classmethod def setUpClass(self): open("test.db", "w").close() @@ -286,7 +282,6 @@ def test_named(self): self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (:bar, :baz)", bar='bar', baz='baz', qux='qux') self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (:bar, :baz)", 'baz', bar='bar') - def test_numeric(self): self.db.execute("CREATE TABLE foo (firstname STRING, lastname STRING)") @@ -322,7 +317,6 @@ def test_numeric(self): def test_cte(self): self.assertEqual(self.db.execute("WITH foo AS ( SELECT 1 AS bar ) SELECT bar FROM foo"), [{"bar": 1}]) - if __name__ == "__main__": suite = unittest.TestSuite([ unittest.TestLoader().loadTestsFromTestCase(SQLiteTests), From 7ed6c52035ed2e2fcd56f9fb11a62ddfe08f6028 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Tue, 27 Jul 2021 07:04:23 -0400 Subject: [PATCH 094/159] Support None parameters --- setup.py | 2 +- src/cs50/_sql_sanitizer.py | 4 +--- tests/sql.py | 15 +++++---------- 3 files changed, 7 insertions(+), 14 deletions(-) diff --git a/setup.py b/setup.py index de271f8..e5f01ce 100644 --- a/setup.py +++ b/setup.py @@ -16,5 +16,5 @@ package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="7.0.0" + version="7.0.1" ) diff --git a/src/cs50/_sql_sanitizer.py b/src/cs50/_sql_sanitizer.py index 17fc5fa..3803bb8 100644 --- a/src/cs50/_sql_sanitizer.py +++ b/src/cs50/_sql_sanitizer.py @@ -66,9 +66,7 @@ def escape(self, value): return sqlparse.sql.Token(sqlparse.tokens.String, string_processor(value)) if value is None: - return sqlparse.sql.Token( - sqlparse.tokens.Keyword, - sqlalchemy.types.NullType().literal_processor(self._dialect)(value)) + return sqlparse.sql.Token(sqlparse.tokens.Keyword, "NULL") raise RuntimeError(f"unsupported value: {value}") diff --git a/tests/sql.py b/tests/sql.py index 14cc035..cf8c5ae 100644 --- a/tests/sql.py +++ b/tests/sql.py @@ -132,19 +132,10 @@ def test_identifier_case(self): self.assertIn("count", self.db.execute("SELECT 1 AS count")[0]) def tearDown(self): - self.db.execute("DROP TABLE cs50") + self.db.execute("DROP TABLE IF EXISTS cs50") self.db.execute("DROP TABLE IF EXISTS foo") self.db.execute("DROP TABLE IF EXISTS bar") - @classmethod - def tearDownClass(self): - try: - self.db.execute("DROP TABLE IF EXISTS cs50") - except Warning as e: - # suppress "unknown table" - if not str(e).startswith("(1051"): - raise e - class MySQLTests(SQLTests): @classmethod def setUpClass(self): @@ -317,6 +308,10 @@ def test_numeric(self): def test_cte(self): self.assertEqual(self.db.execute("WITH foo AS ( SELECT 1 AS bar ) SELECT bar FROM foo"), [{"bar": 1}]) + def test_none(self): + self.db.execute("CREATE TABLE foo (val INTEGER)") + self.db.execute("SELECT * FROM foo WHERE val = ?", None) + if __name__ == "__main__": suite = unittest.TestSuite([ unittest.TestLoader().loadTestsFromTestCase(SQLiteTests), From a06139ddf2e17362f03e61bc119ba7f2f391d43b Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Tue, 27 Jul 2021 08:44:39 -0400 Subject: [PATCH 095/159] Replace hard-coded NULL [skip ci] --- src/cs50/_sql_sanitizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cs50/_sql_sanitizer.py b/src/cs50/_sql_sanitizer.py index 3803bb8..388cbe9 100644 --- a/src/cs50/_sql_sanitizer.py +++ b/src/cs50/_sql_sanitizer.py @@ -66,7 +66,7 @@ def escape(self, value): return sqlparse.sql.Token(sqlparse.tokens.String, string_processor(value)) if value is None: - return sqlparse.sql.Token(sqlparse.tokens.Keyword, "NULL") + return sqlparse.sql.Token(sqlparse.tokens.Keyword, sqlalchemy.null()) raise RuntimeError(f"unsupported value: {value}") From a078a4419101f78cada544194fb4c64e69d9345c Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Wed, 28 Jul 2021 06:18:46 -0400 Subject: [PATCH 096/159] Move common tests to SQLTests --- tests/sql.py | 111 ++++++++++++++++++++++++++------------------------- 1 file changed, 56 insertions(+), 55 deletions(-) diff --git a/tests/sql.py b/tests/sql.py index cf8c5ae..34d718c 100644 --- a/tests/sql.py +++ b/tests/sql.py @@ -27,6 +27,7 @@ def test_delete_returns_affected_rows(self): def test_insert_returns_last_row_id(self): self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('foo')"), 1) self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('bar')"), 2) + self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('qux')"), 3) def test_select_all(self): self.assertEqual(self.db.execute("SELECT * FROM cs50"), []) @@ -131,55 +132,13 @@ def test_rollback(self): def test_identifier_case(self): self.assertIn("count", self.db.execute("SELECT 1 AS count")[0]) - def tearDown(self): - self.db.execute("DROP TABLE IF EXISTS cs50") - self.db.execute("DROP TABLE IF EXISTS foo") - self.db.execute("DROP TABLE IF EXISTS bar") - -class MySQLTests(SQLTests): - @classmethod - def setUpClass(self): - self.db = SQL("mysql://root@127.0.0.1/test") - - def setUp(self): - self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), bin BLOB, PRIMARY KEY (id))") - self.db.execute("DELETE FROM cs50") - -class PostgresTests(SQLTests): - @classmethod - def setUpClass(self): - self.db = SQL("postgresql://postgres:postgres@127.0.0.1/test") - - def setUp(self): - self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)") - self.db.execute("DELETE FROM cs50") - - def test_cte(self): - self.assertEqual(self.db.execute("WITH foo AS ( SELECT 1 AS bar ) SELECT bar FROM foo"), [{"bar": 1}]) - - def test_postgres_scheme(self): - db = SQL("postgres://postgres:postgres@127.0.0.1/test") - db.execute("SELECT 1") - -class SQLiteTests(SQLTests): - @classmethod - def setUpClass(self): - open("test.db", "w").close() - self.db = SQL("sqlite:///test.db") - - def setUp(self): - self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id INTEGER PRIMARY KEY, val TEXT, bin BLOB)") - self.db.execute("DELETE FROM cs50") - - def test_lastrowid(self): - self.db.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY AUTOINCREMENT, firstname TEXT, lastname TEXT)") - self.assertEqual(self.db.execute("INSERT INTO foo (firstname, lastname) VALUES('firstname', 'lastname')"), 1) - self.assertRaises(ValueError, self.db.execute, "INSERT INTO foo (id, firstname, lastname) VALUES(1, 'firstname', 'lastname')") - self.assertEqual(self.db.execute("INSERT OR IGNORE INTO foo (id, firstname, lastname) VALUES(1, 'firstname', 'lastname')"), None) + def test_none(self): + self.db.execute("CREATE TABLE foo (val INTEGER)") + self.db.execute("SELECT * FROM foo WHERE val = ?", None) def test_integrity_constraints(self): self.db.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY)") - self.assertEqual(self.db.execute("INSERT INTO foo VALUES(1)"), 1) + self.db.execute("INSERT INTO foo VALUES(1)") self.assertRaises(ValueError, self.db.execute, "INSERT INTO foo VALUES(1)") def test_foreign_key_support(self): @@ -188,7 +147,7 @@ def test_foreign_key_support(self): self.assertRaises(ValueError, self.db.execute, "INSERT INTO bar VALUES(50)") def test_qmark(self): - self.db.execute("CREATE TABLE foo (firstname STRING, lastname STRING)") + self.db.execute("CREATE TABLE foo (firstname VARCHAR(255), lastname VARCHAR(255))") self.db.execute("INSERT INTO foo VALUES (?, 'bar')", "baz") self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "baz", "lastname": "bar"}]) @@ -218,7 +177,7 @@ def test_qmark(self): self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) self.db.execute("DELETE FROM foo") - self.db.execute("CREATE TABLE bar (firstname STRING)") + self.db.execute("CREATE TABLE bar (firstname VARCHAR(255))") self.db.execute("INSERT INTO bar VALUES (?)", "baz") self.assertEqual(self.db.execute("SELECT * FROM bar"), [{"firstname": "baz"}]) @@ -242,7 +201,7 @@ def test_qmark(self): self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (?, ?)", 'bar', baz='baz') def test_named(self): - self.db.execute("CREATE TABLE foo (firstname STRING, lastname STRING)") + self.db.execute("CREATE TABLE foo (firstname VARCHAR(255), lastname VARCHAR(255))") self.db.execute("INSERT INTO foo VALUES (:baz, 'bar')", baz="baz") self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "baz", "lastname": "bar"}]) @@ -264,7 +223,7 @@ def test_named(self): self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) self.db.execute("DELETE FROM foo") - self.db.execute("CREATE TABLE bar (firstname STRING)") + self.db.execute("CREATE TABLE bar (firstname VARCHAR(255))") self.db.execute("INSERT INTO bar VALUES (:baz)", baz="baz") self.assertEqual(self.db.execute("SELECT * FROM bar"), [{"firstname": "baz"}]) @@ -274,7 +233,7 @@ def test_named(self): self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (:bar, :baz)", 'baz', bar='bar') def test_numeric(self): - self.db.execute("CREATE TABLE foo (firstname STRING, lastname STRING)") + self.db.execute("CREATE TABLE foo (firstname VARCHAR(255), lastname VARCHAR(255))") self.db.execute("INSERT INTO foo VALUES (:1, 'bar')", "baz") self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "baz", "lastname": "bar"}]) @@ -296,7 +255,7 @@ def test_numeric(self): self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) self.db.execute("DELETE FROM foo") - self.db.execute("CREATE TABLE bar (firstname STRING)") + self.db.execute("CREATE TABLE bar (firstname VARCHAR(255))") self.db.execute("INSERT INTO bar VALUES (:1)", "baz") self.assertEqual(self.db.execute("SELECT * FROM bar"), [{"firstname": "baz"}]) @@ -308,9 +267,51 @@ def test_numeric(self): def test_cte(self): self.assertEqual(self.db.execute("WITH foo AS ( SELECT 1 AS bar ) SELECT bar FROM foo"), [{"bar": 1}]) - def test_none(self): - self.db.execute("CREATE TABLE foo (val INTEGER)") - self.db.execute("SELECT * FROM foo WHERE val = ?", None) + def tearDown(self): + self.db.execute("DROP TABLE IF EXISTS cs50") + self.db.execute("DROP TABLE IF EXISTS bar") + self.db.execute("DROP TABLE IF EXISTS foo") + +class MySQLTests(SQLTests): + @classmethod + def setUpClass(self): + self.db = SQL("mysql://root@127.0.0.1/test") + + def setUp(self): + self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), bin BLOB, PRIMARY KEY (id))") + self.db.execute("DELETE FROM cs50") + +class PostgresTests(SQLTests): + @classmethod + def setUpClass(self): + self.db = SQL("postgresql://postgres:postgres@127.0.0.1/test") + + def setUp(self): + self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)") + self.db.execute("DELETE FROM cs50") + + def test_cte(self): + self.assertEqual(self.db.execute("WITH foo AS ( SELECT 1 AS bar ) SELECT bar FROM foo"), [{"bar": 1}]) + + def test_postgres_scheme(self): + db = SQL("postgres://postgres:postgres@127.0.0.1/test") + db.execute("SELECT 1") + +class SQLiteTests(SQLTests): + @classmethod + def setUpClass(self): + open("test.db", "w").close() + self.db = SQL("sqlite:///test.db") + + def setUp(self): + self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id INTEGER PRIMARY KEY, val TEXT, bin BLOB)") + self.db.execute("DELETE FROM cs50") + + def test_lastrowid(self): + self.db.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY AUTOINCREMENT, firstname TEXT, lastname TEXT)") + self.assertEqual(self.db.execute("INSERT INTO foo (firstname, lastname) VALUES('firstname', 'lastname')"), 1) + self.assertRaises(ValueError, self.db.execute, "INSERT INTO foo (id, firstname, lastname) VALUES(1, 'firstname', 'lastname')") + self.assertEqual(self.db.execute("INSERT OR IGNORE INTO foo (id, firstname, lastname) VALUES(1, 'firstname', 'lastname')"), None) if __name__ == "__main__": suite = unittest.TestSuite([ From b346ed899d3da2bde6a26adf464a2318d6b59893 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Wed, 28 Jul 2021 06:37:34 -0400 Subject: [PATCH 097/159] Fix using named param more than once --- src/cs50/_statement.py | 13 ++++++++----- tests/sql.py | 4 ++++ 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py index 2de956a..a96a282 100644 --- a/src/cs50/_statement.py +++ b/src/cs50/_statement.py @@ -197,18 +197,21 @@ def _substitute_named_or_pyformat_markers(self): Raises a ``RuntimeError`` if any parameters are missing or unused. """ - unused_params = set(self._kwargs.keys()) + unused_params = {param_name: True for param_name in self._kwargs.keys()} for token_index, param_name in self._placeholders.items(): if param_name not in self._kwargs: raise RuntimeError(f"missing value for placeholder ({param_name})") self._tokens[token_index] = self._kwargs[param_name] - unused_params.remove(param_name) + unused_params[param_name] = False - if len(unused_params) > 0: - joined_unused_params = get_human_readable_list(sorted(unused_params)) + sorted_unique_unused_param_names = sorted(set( + param_name for param_name, unused in unused_params.items() if unused)) + if len(sorted_unique_unused_param_names) > 0: + joined_unused_params = get_human_readable_list(sorted_unique_unused_param_names) raise RuntimeError( - f"unused value{'' if len(unused_params) == 1 else 's'} ({joined_unused_params})") + f"unused value{'' if len(sorted_unique_unused_param_names) == 1 else 's'}" + + " ({joined_unused_params})") def _escape_verbatim_colons(self): """Escapes verbatim colons from string literal and identifier tokens so they aren't treated diff --git a/tests/sql.py b/tests/sql.py index 34d718c..7caaf63 100644 --- a/tests/sql.py +++ b/tests/sql.py @@ -223,6 +223,10 @@ def test_named(self): self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) self.db.execute("DELETE FROM foo") + self.db.execute("INSERT INTO foo VALUES (:baz, :baz)", baz="baz") + self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "baz", "lastname": "baz"}]) + self.db.execute("DELETE FROM foo") + self.db.execute("CREATE TABLE bar (firstname VARCHAR(255))") self.db.execute("INSERT INTO bar VALUES (:baz)", baz="baz") self.assertEqual(self.db.execute("SELECT * FROM bar"), [{"firstname": "baz"}]) From eb57b43c8bd6587214f8abc85b33a3368c6e50c4 Mon Sep 17 00:00:00 2001 From: Kareem Zidane <kzidane@cs50.harvard.edu> Date: Wed, 28 Jul 2021 06:37:43 -0400 Subject: [PATCH 098/159] Up version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index e5f01ce..f38557d 100644 --- a/setup.py +++ b/setup.py @@ -16,5 +16,5 @@ package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="7.0.1" + version="7.0.2" ) From 98840ee15eae1148879594a86049a73d34986e66 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Wed, 8 Dec 2021 18:12:31 -0500 Subject: [PATCH 099/159] porting thread-local storage to v6 --- .gitignore | 6 +- setup.py | 4 +- src/cs50/__init__.py | 23 +- src/cs50/_engine.py | 88 ----- src/cs50/_engine_util.py | 43 --- src/cs50/_logger.py | 98 ----- src/cs50/_sql_sanitizer.py | 93 ----- src/cs50/_sql_util.py | 51 --- src/cs50/_statement.py | 247 ------------- src/cs50/_statement_util.py | 101 ----- src/cs50/cs50.py | 171 +++++---- src/cs50/flask.py | 38 ++ src/cs50/sql.py | 586 +++++++++++++++++++++++++----- tests/flask/application.py | 22 ++ tests/flask/requirements.txt | 2 + tests/flask/templates/error.html | 10 + tests/flask/templates/index.html | 10 + tests/foo.py | 48 +++ tests/mysql.py | 8 + tests/python.py | 8 + tests/redirect/application.py | 12 + tests/redirect/templates/foo.html | 1 + tests/sql.py | 123 ++++--- tests/sqlite.py | 44 +++ tests/tb.py | 10 + tests/test_cs50.py | 141 ------- 26 files changed, 906 insertions(+), 1082 deletions(-) delete mode 100644 src/cs50/_engine.py delete mode 100644 src/cs50/_engine_util.py delete mode 100644 src/cs50/_logger.py delete mode 100644 src/cs50/_sql_sanitizer.py delete mode 100644 src/cs50/_sql_util.py delete mode 100644 src/cs50/_statement.py delete mode 100644 src/cs50/_statement_util.py create mode 100644 src/cs50/flask.py create mode 100644 tests/flask/application.py create mode 100644 tests/flask/requirements.txt create mode 100644 tests/flask/templates/error.html create mode 100644 tests/flask/templates/index.html create mode 100644 tests/foo.py create mode 100644 tests/mysql.py create mode 100644 tests/python.py create mode 100644 tests/redirect/application.py create mode 100644 tests/redirect/templates/foo.html create mode 100644 tests/sqlite.py create mode 100644 tests/tb.py delete mode 100644 tests/test_cs50.py diff --git a/.gitignore b/.gitignore index 0ce3062..4286ed6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,10 +1,6 @@ .* -!.github +!/.github/ !.gitignore -!.travis.yml *.db *.egg-info/ *.pyc -build/ -dist/ -test.db diff --git a/setup.py b/setup.py index f38557d..faf7abd 100644 --- a/setup.py +++ b/setup.py @@ -10,11 +10,11 @@ "Topic :: Software Development :: Libraries :: Python Modules" ], description="CS50 library for Python", - install_requires=["Flask>=1.0", "SQLAlchemy<2", "sqlparse", "termcolor"], + install_requires=["Flask>=1.0", "SQLAlchemy", "sqlparse", "termcolor"], keywords="cs50", name="cs50", package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="7.0.2" + version="7.1.0" ) diff --git a/src/cs50/__init__.py b/src/cs50/__init__.py index e5ec787..aaec161 100644 --- a/src/cs50/__init__.py +++ b/src/cs50/__init__.py @@ -1,5 +1,20 @@ -from .cs50 import get_float, get_int, get_string -from .sql import SQL -from ._logger import _setup_logger +import logging +import os +import sys + + +# Disable cs50 logger by default +logging.getLogger("cs50").disabled = True -_setup_logger() +# Import cs50_* +from .cs50 import get_char, get_float, get_int, get_string +try: + from .cs50 import get_long +except ImportError: + pass + +# Hook into flask importing +from . import flask + +# Wrap SQLAlchemy +from .sql import SQL diff --git a/src/cs50/_engine.py b/src/cs50/_engine.py deleted file mode 100644 index 55489d1..0000000 --- a/src/cs50/_engine.py +++ /dev/null @@ -1,88 +0,0 @@ -import threading -import warnings - -from ._engine_util import create_engine - - -thread_local_data = threading.local() - - -class Engine: - """Wraps a SQLAlchemy engine. - """ - - def __init__(self, url): - url = _replace_scheme_if_postgres(url) - self._engine = create_engine(url) - - def get_transaction_connection(self): - """ - :returns: A new connection with autocommit disabled (to be used for transactions). - """ - - _thread_local_connections()[self] = self._engine.connect().execution_options( - autocommit=False) - return self.get_existing_transaction_connection() - - def get_connection(self): - """ - :returns: A new connection with autocommit enabled - """ - - return self._engine.connect().execution_options(autocommit=True) - - def get_existing_transaction_connection(self): - """ - :returns: The transaction connection bound to this Engine instance, if one exists, or None. - """ - - return _thread_local_connections().get(self) - - def close_transaction_connection(self): - """Closes the transaction connection bound to this Engine instance, if one exists and - removes it. - """ - - connection = self.get_existing_transaction_connection() - if connection: - connection.close() - del _thread_local_connections()[self] - - def is_postgres(self): - return self._engine.dialect.name in {"postgres", "postgresql"} - - def __getattr__(self, attr): - return getattr(self._engine, attr) - -def _thread_local_connections(): - """ - :returns: A thread local dict to keep track of transaction connection. If one does not exist, - creates one. - """ - - try: - connections = thread_local_data.connections - except AttributeError: - connections = thread_local_data.connections = {} - - return connections - -def _replace_scheme_if_postgres(url): - """ - Replaces the postgres scheme with the postgresql scheme if possible since the postgres scheme - is deprecated. - - :returns: url with postgresql scheme if the scheme was postgres; otherwise returns url as is - """ - - if url.startswith("postgres://"): - with warnings.catch_warnings(): - warnings.simplefilter("always") - warnings.warn( - "The postgres:// scheme is deprecated and will not be supported in the next major" - + " release of the library. Please use the postgresql:// scheme instead.", - DeprecationWarning - ) - url = f"postgresql{url[len('postgres'):]}" - - return url diff --git a/src/cs50/_engine_util.py b/src/cs50/_engine_util.py deleted file mode 100644 index c55b8f2..0000000 --- a/src/cs50/_engine_util.py +++ /dev/null @@ -1,43 +0,0 @@ -"""Utility functions used by _session.py. -""" - -import os -import sqlite3 - -import sqlalchemy - -sqlite_url_prefix = "sqlite:///" - - -def create_engine(url, **kwargs): - """Creates a new SQLAlchemy engine. If ``url`` is a URL for a SQLite database, makes sure that - the SQLite file exits and enables foreign key constraints. - """ - - try: - engine = sqlalchemy.create_engine(url, **kwargs) - except sqlalchemy.exc.ArgumentError: - raise RuntimeError(f"invalid URL: {url}") from None - - if _is_sqlite_url(url): - _assert_sqlite_file_exists(url) - sqlalchemy.event.listen(engine, "connect", _enable_sqlite_foreign_key_constraints) - - return engine - -def _is_sqlite_url(url): - return url.startswith(sqlite_url_prefix) - - -def _assert_sqlite_file_exists(url): - path = url[len(sqlite_url_prefix):] - if not os.path.exists(path): - raise RuntimeError(f"does not exist: {path}") - if not os.path.isfile(path): - raise RuntimeError(f"not a file: {path}") - - -def _enable_sqlite_foreign_key_constraints(dbapi_connection, _): - cursor = dbapi_connection.cursor() - cursor.execute("PRAGMA foreign_keys=ON") - cursor.close() diff --git a/src/cs50/_logger.py b/src/cs50/_logger.py deleted file mode 100644 index e7b03ca..0000000 --- a/src/cs50/_logger.py +++ /dev/null @@ -1,98 +0,0 @@ -"""Sets up logging for the library. -""" - -import logging -import os.path -import re -import sys -import traceback - -import termcolor - - -def green(msg): - return _colored(msg, "green") - - -def red(msg): - return _colored(msg, "red") - - -def yellow(msg): - return _colored(msg, "yellow") - - -def _colored(msg, color): - return termcolor.colored(str(msg), color) - - -def _setup_logger(): - _configure_default_logger() - _patch_root_handler_format_exception() - _configure_cs50_logger() - _patch_excepthook() - - -def _configure_default_logger(): - """Configures a default handler and formatter to prevent flask and werkzeug from adding theirs. - """ - - logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.DEBUG) - - -def _patch_root_handler_format_exception(): - """Patches formatException for the root handler to use ``_format_exception``. - """ - - try: - formatter = logging.root.handlers[0].formatter - formatter.formatException = lambda exc_info: _format_exception(*exc_info) - except IndexError: - pass - - -def _configure_cs50_logger(): - """Disables the cs50 logger by default. Disables logging propagation to prevent messages from - being logged more than once. Sets the logging handler and formatter. - """ - - _logger = logging.getLogger("cs50") - _logger.disabled = True - _logger.setLevel(logging.DEBUG) - - # Log messages once - _logger.propagate = False - - handler = logging.StreamHandler() - handler.setLevel(logging.DEBUG) - - formatter = logging.Formatter("%(levelname)s: %(message)s") - formatter.formatException = lambda exc_info: _format_exception(*exc_info) - handler.setFormatter(formatter) - _logger.addHandler(handler) - - -def _patch_excepthook(): - sys.excepthook = lambda type_, value, exc_tb: print( - _format_exception(type_, value, exc_tb), file=sys.stderr) - - -def _format_exception(type_, value, exc_tb): - """Formats traceback, darkening entries from global site-packages directories and user-specific - site-packages directory. - https://stackoverflow.com/a/46071447/5156190 - """ - - # Absolute paths to site-packages - packages = tuple(os.path.join(os.path.abspath(p), "") for p in sys.path[1:]) - - # Highlight lines not referring to files in site-packages - lines = [] - for line in traceback.format_exception(type_, value, exc_tb): - matches = re.search(r"^ File \"([^\"]+)\", line \d+, in .+", line) - if matches and matches.group(1).startswith(packages): - lines += line - else: - matches = re.search(r"^(\s*)(.*?)(\s*)$", line, re.DOTALL) - lines.append(matches.group(1) + yellow(matches.group(2)) + matches.group(3)) - return "".join(lines).rstrip() diff --git a/src/cs50/_sql_sanitizer.py b/src/cs50/_sql_sanitizer.py deleted file mode 100644 index 388cbe9..0000000 --- a/src/cs50/_sql_sanitizer.py +++ /dev/null @@ -1,93 +0,0 @@ -import datetime -import re - -import sqlalchemy -import sqlparse - - -class SQLSanitizer: - """Sanitizes SQL values. - """ - - def __init__(self, dialect): - self._dialect = dialect - - def escape(self, value): - """Escapes value using engine's conversion function. - https://docs.sqlalchemy.org/en/latest/core/type_api.html#sqlalchemy.types.TypeEngine.literal_processor - - :param value: The value to be sanitized - - :returns: The sanitized value - """ - # pylint: disable=too-many-return-statements - if isinstance(value, (list, tuple)): - return self.escape_iterable(value) - - if isinstance(value, bool): - return sqlparse.sql.Token( - sqlparse.tokens.Number, - sqlalchemy.types.Boolean().literal_processor(self._dialect)(value)) - - if isinstance(value, bytes): - if self._dialect.name in {"mysql", "sqlite"}: - # https://dev.mysql.com/doc/refman/8.0/en/hexadecimal-literals.html - return sqlparse.sql.Token(sqlparse.tokens.Other, f"x'{value.hex()}'") - if self._dialect.name in {"postgres", "postgresql"}: - # https://dba.stackexchange.com/a/203359 - return sqlparse.sql.Token(sqlparse.tokens.Other, f"'\\x{value.hex()}'") - - raise RuntimeError(f"unsupported value: {value}") - - string_processor = sqlalchemy.types.String().literal_processor(self._dialect) - if isinstance(value, datetime.date): - return sqlparse.sql.Token( - sqlparse.tokens.String, string_processor(value.strftime("%Y-%m-%d"))) - - if isinstance(value, datetime.datetime): - return sqlparse.sql.Token( - sqlparse.tokens.String, string_processor(value.strftime("%Y-%m-%d %H:%M:%S"))) - - if isinstance(value, datetime.time): - return sqlparse.sql.Token( - sqlparse.tokens.String, string_processor(value.strftime("%H:%M:%S"))) - - if isinstance(value, float): - return sqlparse.sql.Token( - sqlparse.tokens.Number, - sqlalchemy.types.Float().literal_processor(self._dialect)(value)) - - if isinstance(value, int): - return sqlparse.sql.Token( - sqlparse.tokens.Number, - sqlalchemy.types.Integer().literal_processor(self._dialect)(value)) - - if isinstance(value, str): - return sqlparse.sql.Token(sqlparse.tokens.String, string_processor(value)) - - if value is None: - return sqlparse.sql.Token(sqlparse.tokens.Keyword, sqlalchemy.null()) - - raise RuntimeError(f"unsupported value: {value}") - - def escape_iterable(self, iterable): - """Escapes each value in iterable and joins all the escaped values with ", ", formatted for - SQL's ``IN`` operator. - - :param: An iterable of values to be escaped - - :returns: A comma-separated list of escaped values from ``iterable`` - :rtype: :class:`sqlparse.sql.TokenList` - """ - - return sqlparse.sql.TokenList( - sqlparse.parse(", ".join([str(self.escape(v)) for v in iterable]))) - - -def escape_verbatim_colon(value): - """Escapes verbatim colon from a value so as it is not confused with a parameter marker. - """ - - # E.g., ':foo, ":foo, :foo will be replaced with - # '\:foo, "\:foo, \:foo respectively - return re.sub(r"(^(?:'|\")|\s+):", r"\1\:", value) diff --git a/src/cs50/_sql_util.py b/src/cs50/_sql_util.py deleted file mode 100644 index 2dbfecf..0000000 --- a/src/cs50/_sql_util.py +++ /dev/null @@ -1,51 +0,0 @@ -"""Utility functions used by sql.py. -""" - -import contextlib -import decimal -import warnings - -import sqlalchemy - - -def process_select_result(result): - """Converts a SQLAlchemy result to a ``list`` of ``dict`` objects, each of which represents a - row in the result set. - - :param result: A SQLAlchemy result - :type result: :class:`sqlalchemy.engine.Result` - """ - rows = [dict(row) for row in result.fetchall()] - for row in rows: - for column in row: - # Coerce decimal.Decimal objects to float objects - # https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ - if isinstance(row[column], decimal.Decimal): - row[column] = float(row[column]) - - # Coerce memoryview objects (as from PostgreSQL's bytea columns) to bytes - elif isinstance(row[column], memoryview): - row[column] = bytes(row[column]) - - return rows - - -@contextlib.contextmanager -def raise_errors_for_warnings(): - """Catches warnings and raises errors instead. - """ - - with warnings.catch_warnings(): - warnings.simplefilter("error") - yield - - -def postgres_lastval(connection): - """ - :returns: The ID of the last inserted row, if defined in this session, or None - """ - - try: - return connection.execute("SELECT LASTVAL()").first()[0] - except sqlalchemy.exc.OperationalError: - return None diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py deleted file mode 100644 index a96a282..0000000 --- a/src/cs50/_statement.py +++ /dev/null @@ -1,247 +0,0 @@ -import collections - -from ._sql_sanitizer import SQLSanitizer, escape_verbatim_colon -from ._statement_util import ( - format_and_parse, - get_human_readable_list, - is_identifier, - is_operation_token, - is_placeholder, - is_string_literal, - operation_keywords, - Paramstyle, - parse_placeholder, -) - - -def statement_factory(dialect): - """Creates a sanitizer for ``dialect`` and injects it into ``Statement``, exposing a simpler - interface for ``Statement``. - - :param dialect: a SQLAlchemy dialect - :type dialect: :class:`sqlalchemy.engine.Dialect` - """ - - sql_sanitizer = SQLSanitizer(dialect) - - def statement(sql, *args, **kwargs): - return Statement(sql_sanitizer, sql, *args, **kwargs) - - return statement - - -class Statement: - """Parses a SQL statement and substitutes any parameter markers with their corresponding - placeholders. - """ - - def __init__(self, sql_sanitizer, sql, *args, **kwargs): - """ - :param sql_sanitizer: The SQL sanitizer used to sanitize the parameters - :type sql_sanitizer: :class:`_sql_sanitizer.SQLSanitizer` - - :param sql: The SQL statement - :type sql: str - - :param *args: Zero or more positional parameters to be substituted for the parameter markers - - :param *kwargs: Zero or more keyword arguments to be substituted for the parameter markers - """ - - if len(args) > 0 and len(kwargs) > 0: - raise RuntimeError("cannot pass both positional and named parameters") - - self._sql_sanitizer = sql_sanitizer - - self._args = self._get_escaped_args(args) - self._kwargs = self._get_escaped_kwargs(kwargs) - - self._statement = format_and_parse(sql) - self._tokens = self._tokenize() - - self._operation_keyword = self._get_operation_keyword() - - self._paramstyle = self._get_paramstyle() - self._placeholders = self._get_placeholders() - self._substitute_markers_with_escaped_params() - # self._escape_verbatim_colons() - - def _get_escaped_args(self, args): - return [self._sql_sanitizer.escape(arg) for arg in args] - - def _get_escaped_kwargs(self, kwargs): - return {k: self._sql_sanitizer.escape(v) for k, v in kwargs.items()} - - def _tokenize(self): - """ - :returns: A flattened list of SQLParse tokens that represent the SQL statement - """ - - return list(self._statement.flatten()) - - def _get_operation_keyword(self): - """ - :returns: The operation keyword of the SQL statement (e.g., ``SELECT``, ``DELETE``, etc) - :rtype: str - """ - - for token in self._statement: - if is_operation_token(token.ttype): - token_value = token.value.upper() - if token_value in operation_keywords: - operation_keyword = token_value - break - else: - operation_keyword = None - - return operation_keyword - - def _get_paramstyle(self): - """ - :returns: The paramstyle used in the SQL statement (if any) - :rtype: :class:_statement_util.Paramstyle`` - """ - - paramstyle = None - for token in self._tokens: - if is_placeholder(token.ttype): - paramstyle, _ = parse_placeholder(token.value) - break - else: - paramstyle = self._default_paramstyle() - - return paramstyle - - def _default_paramstyle(self): - """ - :returns: If positional args were passed, returns ``Paramstyle.QMARK``; if keyword arguments - were passed, returns ``Paramstyle.NAMED``; otherwise, returns ``None`` - """ - - paramstyle = None - if self._args: - paramstyle = Paramstyle.QMARK - elif self._kwargs: - paramstyle = Paramstyle.NAMED - - return paramstyle - - def _get_placeholders(self): - """ - :returns: A dict that maps the index of each parameter marker in the tokens list to the name - of that parameter marker (if applicable) or ``None`` - :rtype: dict - """ - - placeholders = collections.OrderedDict() - for index, token in enumerate(self._tokens): - if is_placeholder(token.ttype): - paramstyle, name = parse_placeholder(token.value) - if paramstyle != self._paramstyle: - raise RuntimeError("inconsistent paramstyle") - - placeholders[index] = name - - return placeholders - - def _substitute_markers_with_escaped_params(self): - if self._paramstyle in {Paramstyle.FORMAT, Paramstyle.QMARK}: - self._substitute_format_or_qmark_markers() - elif self._paramstyle == Paramstyle.NUMERIC: - self._substitue_numeric_markers() - if self._paramstyle in {Paramstyle.NAMED, Paramstyle.PYFORMAT}: - self._substitute_named_or_pyformat_markers() - - def _substitute_format_or_qmark_markers(self): - """Substitutes format or qmark parameter markers with their corresponding parameters. - """ - - self._assert_valid_arg_count() - for arg_index, token_index in enumerate(self._placeholders.keys()): - self._tokens[token_index] = self._args[arg_index] - - def _assert_valid_arg_count(self): - """Raises a ``RuntimeError`` if the number of arguments does not match the number of - placeholders. - """ - - if len(self._placeholders) != len(self._args): - placeholders = get_human_readable_list(self._placeholders.values()) - args = get_human_readable_list(self._args) - if len(self._placeholders) < len(self._args): - raise RuntimeError(f"fewer placeholders ({placeholders}) than values ({args})") - - raise RuntimeError(f"more placeholders ({placeholders}) than values ({args})") - - def _substitue_numeric_markers(self): - """Substitutes numeric parameter markers with their corresponding parameters. Raises a - ``RuntimeError`` if any parameters are missing or unused. - """ - - unused_arg_indices = set(range(len(self._args))) - for token_index, num in self._placeholders.items(): - if num >= len(self._args): - raise RuntimeError(f"missing value for placeholder ({num + 1})") - - self._tokens[token_index] = self._args[num] - unused_arg_indices.remove(num) - - if len(unused_arg_indices) > 0: - unused_args = get_human_readable_list( - [self._args[i] for i in sorted(unused_arg_indices)]) - raise RuntimeError( - f"unused value{'' if len(unused_args) == 1 else 's'} ({unused_args})") - - def _substitute_named_or_pyformat_markers(self): - """Substitutes named or pyformat parameter markers with their corresponding parameters. - Raises a ``RuntimeError`` if any parameters are missing or unused. - """ - - unused_params = {param_name: True for param_name in self._kwargs.keys()} - for token_index, param_name in self._placeholders.items(): - if param_name not in self._kwargs: - raise RuntimeError(f"missing value for placeholder ({param_name})") - - self._tokens[token_index] = self._kwargs[param_name] - unused_params[param_name] = False - - sorted_unique_unused_param_names = sorted(set( - param_name for param_name, unused in unused_params.items() if unused)) - if len(sorted_unique_unused_param_names) > 0: - joined_unused_params = get_human_readable_list(sorted_unique_unused_param_names) - raise RuntimeError( - f"unused value{'' if len(sorted_unique_unused_param_names) == 1 else 's'}" - + " ({joined_unused_params})") - - def _escape_verbatim_colons(self): - """Escapes verbatim colons from string literal and identifier tokens so they aren't treated - as parameter markers. - """ - - for token in self._tokens: - if is_string_literal(token.ttype) or is_identifier(token.ttype): - token.value = escape_verbatim_colon(token.value) - - def is_transaction_start(self): - return self._operation_keyword in {"BEGIN", "START"} - - def is_transaction_end(self): - return self._operation_keyword in {"COMMIT", "ROLLBACK"} - - def is_delete(self): - return self._operation_keyword == "DELETE" - - def is_insert(self): - return self._operation_keyword == "INSERT" - - def is_select(self): - return self._operation_keyword == "SELECT" - - def is_update(self): - return self._operation_keyword == "UPDATE" - - def __str__(self): - """Joins the statement tokens into a string. - """ - - return "".join([str(token) for token in self._tokens]) diff --git a/src/cs50/_statement_util.py b/src/cs50/_statement_util.py deleted file mode 100644 index 34ca6ff..0000000 --- a/src/cs50/_statement_util.py +++ /dev/null @@ -1,101 +0,0 @@ -"""Utility functions used by _statement.py. -""" - -import enum -import re - -import sqlparse - - -operation_keywords = { - "BEGIN", - "COMMIT", - "DELETE", - "INSERT", - "ROLLBACK", - "SELECT", - "START", - "UPDATE" -} - - -class Paramstyle(enum.Enum): - """Represents the supported parameter marker styles. - """ - - FORMAT = enum.auto() - NAMED = enum.auto() - NUMERIC = enum.auto() - PYFORMAT = enum.auto() - QMARK = enum.auto() - - -def format_and_parse(sql): - """Formats and parses a SQL statement. Raises ``RuntimeError`` if ``sql`` represents more than - one statement. - - :param sql: The SQL statement to be formatted and parsed - :type sql: str - - :returns: A list of unflattened SQLParse tokens that represent the parsed statement - """ - - formatted_statements = sqlparse.format(sql, strip_comments=True).strip() - parsed_statements = sqlparse.parse(formatted_statements) - statement_count = len(parsed_statements) - if statement_count == 0: - raise RuntimeError("missing statement") - if statement_count > 1: - raise RuntimeError("too many statements at once") - - return parsed_statements[0] - - -def is_placeholder(ttype): - return ttype == sqlparse.tokens.Name.Placeholder - - -def parse_placeholder(value): - """ - :returns: A tuple of the paramstyle and the name of the parameter marker (if any) or ``None`` - :rtype: tuple - """ - if value == "?": - return Paramstyle.QMARK, None - - # E.g., :1 - matches = re.search(r"^:([1-9]\d*)$", value) - if matches: - return Paramstyle.NUMERIC, int(matches.group(1)) - 1 - - # E.g., :foo - matches = re.search(r"^:([a-zA-Z]\w*)$", value) - if matches: - return Paramstyle.NAMED, matches.group(1) - - if value == "%s": - return Paramstyle.FORMAT, None - - # E.g., %(foo)s - matches = re.search(r"%\((\w+)\)s$", value) - if matches: - return Paramstyle.PYFORMAT, matches.group(1) - - raise RuntimeError(f"{value}: invalid placeholder") - - -def is_operation_token(ttype): - return ttype in { - sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML} - - -def is_string_literal(ttype): - return ttype in [sqlparse.tokens.Literal.String, sqlparse.tokens.Literal.String.Single] - - -def is_identifier(ttype): - return ttype == sqlparse.tokens.Literal.String.Symbol - - -def get_human_readable_list(iterable): - return ", ".join(str(v) for v in iterable) diff --git a/src/cs50/cs50.py b/src/cs50/cs50.py index 11fa20a..1d7b6ea 100644 --- a/src/cs50/cs50.py +++ b/src/cs50/cs50.py @@ -1,104 +1,143 @@ -"""Exposes simple API for getting and validating user input""" +from __future__ import print_function +import inspect +import logging +import os import re import sys +from distutils.sysconfig import get_python_lib +from os.path import abspath, join +from termcolor import colored +from traceback import format_exception -def get_float(prompt): - """Reads a line of text from standard input and returns the equivalent float as precisely as - possible; if text does not represent a float, user is prompted to retry. If line can't be read, - returns None. - :type prompt: str +# Configure default logging handler and formatter +# Prevent flask, werkzeug, etc from adding default handler +logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.DEBUG) - """ +try: + # Patch formatException + logging.root.handlers[0].formatter.formatException = lambda exc_info: _formatException(*exc_info) +except IndexError: + pass - while True: - try: - return _get_float(prompt) - except (OverflowError, ValueError): - pass +# Configure cs50 logger +_logger = logging.getLogger("cs50") +_logger.setLevel(logging.DEBUG) +# Log messages once +_logger.propagate = False -def _get_float(prompt): - user_input = get_string(prompt) - if user_input is None: - return None - - if len(user_input) > 0 and re.search(r"^[+-]?\d*(?:\.\d*)?$", user_input): - return float(user_input) +handler = logging.StreamHandler() +handler.setLevel(logging.DEBUG) - raise ValueError(f"invalid float literal: {user_input}") +formatter = logging.Formatter("%(levelname)s: %(message)s") +formatter.formatException = lambda exc_info: _formatException(*exc_info) +handler.setFormatter(formatter) +_logger.addHandler(handler) -def get_int(prompt): - """Reads a line of text from standard input and return the equivalent int; if text does not - represent an int, user is prompted to retry. If line can't be read, returns None. +class _flushfile(): + """ + Disable buffering for standard output and standard error. - :type prompt: str + http://stackoverflow.com/a/231216 """ - while True: - try: - return _get_int(prompt) - except (MemoryError, ValueError): - pass + def __init__(self, f): + self.f = f + def __getattr__(self, name): + return object.__getattribute__(self.f, name) -def _get_int(prompt): - user_input = get_string(prompt) - if user_input is None: - return None + def write(self, x): + self.f.write(x) + self.f.flush() - if re.search(r"^[+-]?\d+$", user_input): - return int(user_input, 10) - raise ValueError(f"invalid int literal for base 10: {user_input}") +sys.stderr = _flushfile(sys.stderr) +sys.stdout = _flushfile(sys.stdout) -def get_string(prompt): - """Reads a line of text from standard input and returns it as a string, sans trailing line - ending. Supports CR (\r), LF (\n), and CRLF (\r\n) as line endings. If user inputs only a line - ending, returns "", not None. Returns None upon error or no input whatsoever (i.e., just EOF). +def _formatException(type, value, tb): + """ + Format traceback, darkening entries from global site-packages directories + and user-specific site-packages directory. - :type prompt: str + https://stackoverflow.com/a/46071447/5156190 """ - if not isinstance(prompt, str): - raise TypeError("prompt must be of type str") + # Absolute paths to site-packages + packages = tuple(join(abspath(p), "") for p in sys.path[1:]) - try: - return _get_input(prompt) - except EOFError: - return None + # Highlight lines not referring to files in site-packages + lines = [] + for line in format_exception(type, value, tb): + matches = re.search(r"^ File \"([^\"]+)\", line \d+, in .+", line) + if matches and matches.group(1).startswith(packages): + lines += line + else: + matches = re.search(r"^(\s*)(.*?)(\s*)$", line, re.DOTALL) + lines.append(matches.group(1) + colored(matches.group(2), "yellow") + matches.group(3)) + return "".join(lines).rstrip() -def _get_input(prompt): - return input(prompt) +sys.excepthook = lambda type, value, tb: print(_formatException(type, value, tb), file=sys.stderr) -class _flushfile(): - """ Disable buffering for standard output and standard error. - http://stackoverflow.com/a/231216 - """ +def eprint(*args, **kwargs): + raise RuntimeError("The CS50 Library for Python no longer supports eprint, but you can use print instead!") - def __init__(self, stream): - self.stream = stream - def __getattr__(self, name): - return object.__getattribute__(self.stream, name) +def get_char(prompt): + raise RuntimeError("The CS50 Library for Python no longer supports get_char, but you can use get_string instead!") - def write(self, data): - """Writes data to stream""" - self.stream.write(data) - self.stream.flush() +def get_float(prompt): + """ + Read a line of text from standard input and return the equivalent float + as precisely as possible; if text does not represent a double, user is + prompted to retry. If line can't be read, return None. + """ + while True: + s = get_string(prompt) + if s is None: + return None + if len(s) > 0 and re.search(r"^[+-]?\d*(?:\.\d*)?$", s): + try: + return float(s) + except (OverflowError, ValueError): + pass -def disable_output_buffering(): - """Disables output buffering to prevent prompts from being buffered. + +def get_int(prompt): """ - sys.stderr = _flushfile(sys.stderr) - sys.stdout = _flushfile(sys.stdout) + Read a line of text from standard input and return the equivalent int; + if text does not represent an int, user is prompted to retry. If line + can't be read, return None. + """ + while True: + s = get_string(prompt) + if s is None: + return None + if re.search(r"^[+-]?\d+$", s): + try: + return int(s, 10) + except ValueError: + pass -disable_output_buffering() +def get_string(prompt): + """ + Read a line of text from standard input and return it as a string, + sans trailing line ending. Supports CR (\r), LF (\n), and CRLF (\r\n) + as line endings. If user inputs only a line ending, returns "", not None. + Returns None upon error or no input whatsoever (i.e., just EOF). + """ + if type(prompt) is not str: + raise TypeError("prompt must be of type str") + try: + return input(prompt) + except EOFError: + return None diff --git a/src/cs50/flask.py b/src/cs50/flask.py new file mode 100644 index 0000000..324ec30 --- /dev/null +++ b/src/cs50/flask.py @@ -0,0 +1,38 @@ +import os +import pkgutil +import sys + +def _wrap_flask(f): + if f is None: + return + + from distutils.version import StrictVersion + from .cs50 import _formatException + + if f.__version__ < StrictVersion("1.0"): + return + + if os.getenv("CS50_IDE_TYPE") == "online": + from werkzeug.middleware.proxy_fix import ProxyFix + _flask_init_before = f.Flask.__init__ + def _flask_init_after(self, *args, **kwargs): + _flask_init_before(self, *args, **kwargs) + self.wsgi_app = ProxyFix(self.wsgi_app, x_proto=1) # For HTTPS-to-HTTP proxy + f.Flask.__init__ = _flask_init_after + + +# If Flask was imported before cs50 +if "flask" in sys.modules: + _wrap_flask(sys.modules["flask"]) + +# If Flask wasn't imported +else: + flask_loader = pkgutil.get_loader('flask') + if flask_loader: + _exec_module_before = flask_loader.exec_module + + def _exec_module_after(*args, **kwargs): + _exec_module_before(*args, **kwargs) + _wrap_flask(sys.modules["flask"]) + + flask_loader.exec_module = _exec_module_after diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 64d30e3..2188c6e 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -1,116 +1,536 @@ -import logging +import threading -import sqlalchemy +# Thread-local data +_data = threading.local() -from ._logger import green, red, yellow -from ._engine import Engine -from ._statement import statement_factory -from ._sql_util import postgres_lastval, process_select_result, raise_errors_for_warnings +def _enable_logging(f): + """Enable logging of SQL statements when Flask is in use.""" -_logger = logging.getLogger("cs50") + import logging + import functools + @functools.wraps(f) + def decorator(*args, **kwargs): -class SQL: - """An API for executing SQL Statements. - """ + # Infer whether Flask is installed + try: + import flask + except ModuleNotFoundError: + return f(*args, **kwargs) - def __init__(self, url): + # Enable logging + disabled = logging.getLogger("cs50").disabled + if flask.current_app: + logging.getLogger("cs50").disabled = False + try: + return f(*args, **kwargs) + finally: + logging.getLogger("cs50").disabled = disabled + + return decorator + + +class SQL(object): + """Wrap SQLAlchemy to provide a simple SQL API.""" + + def __init__(self, url, **kwargs): """ - :param url: The database URL + Create instance of sqlalchemy.engine.Engine. + + URL should be a string that indicates database dialect and connection arguments. + + http://docs.sqlalchemy.org/en/latest/core/engines.html#sqlalchemy.create_engine + http://docs.sqlalchemy.org/en/latest/dialects/index.html """ - self._engine = Engine(url) - self._substitute_markers_with_params = statement_factory(self._engine.dialect) + # Lazily import + import logging + import os + import re + import sqlalchemy + import sqlalchemy.orm + import sqlite3 + import threading + + # Require that file already exist for SQLite + matches = re.search(r"^sqlite:///(.+)$", url) + if matches: + if not os.path.exists(matches.group(1)): + raise RuntimeError("does not exist: {}".format(matches.group(1))) + if not os.path.isfile(matches.group(1)): + raise RuntimeError("not a file: {}".format(matches.group(1))) + + # Create engine, disabling SQLAlchemy's own autocommit mode, raising exception if back end's module not installed + self._engine = sqlalchemy.create_engine(url, **kwargs).execution_options(autocommit=False) + + # Get logger + self._logger = logging.getLogger("cs50") + + # Listener for connections + def connect(dbapi_connection, connection_record): + + # Disable underlying API's own emitting of BEGIN and COMMIT so we can ourselves + # https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl + dbapi_connection.isolation_level = None + + # Enable foreign key constraints + if type(dbapi_connection) is sqlite3.Connection: # If back end is sqlite + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + + # Register listener + sqlalchemy.event.listen(self._engine, "connect", connect) + + # Autocommit by default + self._autocommit = True + # Test database + disabled = self._logger.disabled + self._logger.disabled = True + try: + connection = self._engine.connect() + connection.execute("SELECT 1") + connection.close() + except sqlalchemy.exc.OperationalError as e: + e = RuntimeError(_parse_exception(e)) + e.__cause__ = None + raise e + finally: + self._logger.disabled = disabled + + def __del__(self): + """Disconnect from database.""" + self._disconnect() + + def _disconnect(self): + """Close database connection.""" + if hasattr(_data, self._name()): + getattr(_data, self._name()).close() + delattr(_data, self._name()) + + def _name(self): + """Return object's hash as a str.""" + return str(hash(self)) + + @_enable_logging def execute(self, sql, *args, **kwargs): - """Executes a SQL statement. + """Execute a SQL statement.""" + + # Lazily import + import decimal + import re + import sqlalchemy + import sqlparse + import termcolor + import warnings + + # Parse statement, stripping comments and then leading/trailing whitespace + statements = sqlparse.parse(sqlparse.format(sql, strip_comments=True).strip()) + + # Allow only one statement at a time, since SQLite doesn't support multiple + # https://docs.python.org/3/library/sqlite3.html#sqlite3.Cursor.execute + if len(statements) > 1: + raise RuntimeError("too many statements at once") + elif len(statements) == 0: + raise RuntimeError("missing statement") + + # Ensure named and positional parameters are mutually exclusive + if len(args) > 0 and len(kwargs) > 0: + raise RuntimeError("cannot pass both positional and named parameters") + + # Infer command from (unflattened) statement + for token in statements[0]: + if token.ttype in [sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML]: + token_value = token.value.upper() + if token_value in ["BEGIN", "DELETE", "INSERT", "SELECT", "START", "UPDATE"]: + command = token_value + break + else: + command = None + + # Flatten statement + tokens = list(statements[0].flatten()) + + # Validate paramstyle + placeholders = {} + paramstyle = None + for index, token in enumerate(tokens): + + # If token is a placeholder + if token.ttype == sqlparse.tokens.Name.Placeholder: + + # Determine paramstyle, name + _paramstyle, name = _parse_placeholder(token) + + # Remember paramstyle + if not paramstyle: + paramstyle = _paramstyle + + # Ensure paramstyle is consistent + elif _paramstyle != paramstyle: + raise RuntimeError("inconsistent paramstyle") + + # Remember placeholder's index, name + placeholders[index] = name + + # If no placeholders + if not paramstyle: + + # Error-check like qmark if args + if args: + paramstyle = "qmark" + + # Error-check like named if kwargs + elif kwargs: + paramstyle = "named" + + # In case of errors + _placeholders = ", ".join([str(tokens[index]) for index in placeholders]) + _args = ", ".join([str(self._escape(arg)) for arg in args]) + + # qmark + if paramstyle == "qmark": + + # Validate number of placeholders + if len(placeholders) != len(args): + if len(placeholders) < len(args): + raise RuntimeError("fewer placeholders ({}) than values ({})".format(_placeholders, _args)) + else: + raise RuntimeError("more placeholders ({}) than values ({})".format(_placeholders, _args)) + + # Escape values + for i, index in enumerate(placeholders.keys()): + tokens[index] = self._escape(args[i]) + + # numeric + elif paramstyle == "numeric": + + # Escape values + for index, i in placeholders.items(): + if i >= len(args): + raise RuntimeError("missing value for placeholder (:{})".format(i + 1, len(args))) + tokens[index] = self._escape(args[i]) + + # Check if any values unused + indices = set(range(len(args))) - set(placeholders.values()) + if indices: + raise RuntimeError("unused {} ({})".format( + "value" if len(indices) == 1 else "values", + ", ".join([str(self._escape(args[index])) for index in indices]))) + + # named + elif paramstyle == "named": + + # Escape values + for index, name in placeholders.items(): + if name not in kwargs: + raise RuntimeError("missing value for placeholder (:{})".format(name)) + tokens[index] = self._escape(kwargs[name]) + + # Check if any keys unused + keys = kwargs.keys() - placeholders.values() + if keys: + raise RuntimeError("unused values ({})".format(", ".join(keys))) + + # format + elif paramstyle == "format": + + # Validate number of placeholders + if len(placeholders) != len(args): + if len(placeholders) < len(args): + raise RuntimeError("fewer placeholders ({}) than values ({})".format(_placeholders, _args)) + else: + raise RuntimeError("more placeholders ({}) than values ({})".format(_placeholders, _args)) + + # Escape values + for i, index in enumerate(placeholders.keys()): + tokens[index] = self._escape(args[i]) + + # pyformat + elif paramstyle == "pyformat": + + # Escape values + for index, name in placeholders.items(): + if name not in kwargs: + raise RuntimeError("missing value for placeholder (%{}s)".format(name)) + tokens[index] = self._escape(kwargs[name]) + + # Check if any keys unused + keys = kwargs.keys() - placeholders.values() + if keys: + raise RuntimeError("unused {} ({})".format( + "value" if len(keys) == 1 else "values", + ", ".join(keys))) - :param sql: a SQL statement, possibly with parameters markers - :type sql: str - :param *args: zero or more positional arguments to substitute the parameter markers with - :param **kwargs: zero or more keyword arguments to substitute the parameter markers with + # For SQL statements where a colon is required verbatim, as within an inline string, use a backslash to escape + # https://docs.sqlalchemy.org/en/13/core/sqlelement.html?highlight=text#sqlalchemy.sql.expression.text + for index, token in enumerate(tokens): - :returns: For ``SELECT``, a :py:class:`list` of :py:class:`dict` objects, each of which - represents a row in the result set; for ``INSERT``, the primary key of a newly inserted row - (or ``None`` if none); for ``UPDATE``, the number of rows updated; for ``DELETE``, the - number of rows deleted; for other statements, ``True``; on integrity errors, a - :py:class:`ValueError` is raised, on other errors, a :py:class:`RuntimeError` is raised + # In string literal + # https://www.sqlite.org/lang_keywords.html + if token.ttype in [sqlparse.tokens.Literal.String, sqlparse.tokens.Literal.String.Single]: + token.value = re.sub("(^'|\s+):", r"\1\:", token.value) + # In identifier + # https://www.sqlite.org/lang_keywords.html + elif token.ttype == sqlparse.tokens.Literal.String.Symbol: + token.value = re.sub("(^\"|\s+):", r"\1\:", token.value) + + # Join tokens into statement + statement = "".join([str(token) for token in tokens]) + + # If no connection yet + if not hasattr(_data, self._name()): + + # Connect to database + setattr(_data, self._name(), self._engine.connect()) + + # Use this connection + connection = getattr(_data, self._name()) + + """TODO + try: + import flask + assert flask.current_app + def teardown_appcontext(exception): + self._disconnect() + if teardown_appcontext not in flask.current_app.teardown_appcontext_funcs: + flask.current_app.teardown_appcontext(teardown_appcontext) + except (ModuleNotFoundError, AssertionError): + pass """ - statement = self._substitute_markers_with_params(sql, *args, **kwargs) - connection = self._engine.get_existing_transaction_connection() - if connection is None: - if statement.is_transaction_start(): - connection = self._engine.get_transaction_connection() - else: - connection = self._engine.get_connection() - elif statement.is_transaction_start(): - raise RuntimeError("nested transactions are not supported") + # Catch SQLAlchemy warnings + with warnings.catch_warnings(): - return self._execute(statement, connection) + # Raise exceptions for warnings + warnings.simplefilter("error") - def _execute(self, statement, connection): - with raise_errors_for_warnings(): + # Prepare, execute statement try: - result = connection.execute(str(statement)) - # E.g., failed constraint - except sqlalchemy.exc.IntegrityError as exc: - _logger.debug(yellow(statement)) - if self._engine.get_existing_transaction_connection() is None: - connection.close() - raise ValueError(exc.orig) from None - # E.g., connection error or syntax error - except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as exc: - if self._engine.get_existing_transaction_connection(): - self._engine.close_transaction_connection() - else: - connection.close() - _logger.debug(red(statement)) - raise RuntimeError(exc.orig) from None - - _logger.debug(green(statement)) - - if statement.is_select(): - ret = process_select_result(result) - elif statement.is_insert(): - ret = self._last_row_id_or_none(result) - elif statement.is_delete() or statement.is_update(): - ret = result.rowcount - else: + + # Join tokens into statement, abbreviating binary data as <class 'bytes'> + _statement = "".join([str(bytes) if token.ttype == sqlparse.tokens.Other else str(token) for token in tokens]) + + # Check for start of transaction + if command in ["BEGIN", "START"]: + self._autocommit = False + + # Execute statement + if self._autocommit: + connection.execute(sqlalchemy.text("BEGIN")) + result = connection.execute(sqlalchemy.text(statement)) + if self._autocommit: + connection.execute(sqlalchemy.text("COMMIT")) + + # Check for end of transaction + if command in ["COMMIT", "ROLLBACK"]: + self._autocommit = True + + # Return value ret = True - if self._engine.get_existing_transaction_connection(): - if statement.is_transaction_end(): - self._engine.close_transaction_connection() - else: - connection.close() + # If SELECT, return result set as list of dict objects + if command == "SELECT": + + # Coerce types + rows = [dict(row) for row in result.fetchall()] + for row in rows: + for column in row: + + # Coerce decimal.Decimal objects to float objects + # https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ + if type(row[column]) is decimal.Decimal: + row[column] = float(row[column]) + + # Coerce memoryview objects (as from PostgreSQL's bytea columns) to bytes + elif type(row[column]) is memoryview: + row[column] = bytes(row[column]) + + # Rows to be returned + ret = rows + + # If INSERT, return primary key value for a newly inserted row (or None if none) + elif command == "INSERT": + if self._engine.url.get_backend_name() in ["postgres", "postgresql"]: + try: + result = connection.execute("SELECT LASTVAL()") + ret = result.first()[0] + except sqlalchemy.exc.OperationalError: # If lastval is not yet defined for this connection + ret = None + else: + ret = result.lastrowid if result.rowcount == 1 else None + + # If DELETE or UPDATE, return number of rows matched + elif command in ["DELETE", "UPDATE"]: + ret = result.rowcount - return ret + # If constraint violated, return None + except sqlalchemy.exc.IntegrityError as e: + self._logger.debug(termcolor.colored(statement, "yellow")) + e = ValueError(e.orig) + e.__cause__ = None + raise e - def _last_row_id_or_none(self, result): + # If user error + except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as e: + self._disconnect() + self._logger.debug(termcolor.colored(statement, "red")) + e = RuntimeError(e.orig) + e.__cause__ = None + raise e + + # Return value + else: + self._logger.debug(termcolor.colored(_statement, "green")) + if self._autocommit: # Don't stay connected unnecessarily + self._disconnect() + return ret + + def _escape(self, value): """ - :param result: A SQLAlchemy result object - :type result: :class:`sqlalchemy.engine.Result` + Escapes value using engine's conversion function. - :returns: The ID of the last inserted row or ``None`` + https://docs.sqlalchemy.org/en/latest/core/type_api.html#sqlalchemy.types.TypeEngine.literal_processor """ - if self._engine.is_postgres(): - return postgres_lastval(result.connection) - return result.lastrowid if result.rowcount == 1 else None + # Lazily import + import sqlparse - def init_app(self, app): - """Enables logging and registers a ``teardown_appcontext`` listener to remove the session. + def __escape(value): - :param app: a Flask application instance - :type app: :class:`flask.Flask` - """ + # Lazily import + import datetime + import sqlalchemy + + # bool + if type(value) is bool: + return sqlparse.sql.Token( + sqlparse.tokens.Number, + sqlalchemy.types.Boolean().literal_processor(self._engine.dialect)(value)) + + # bytes + elif type(value) is bytes: + if self._engine.url.get_backend_name() in ["mysql", "sqlite"]: + return sqlparse.sql.Token(sqlparse.tokens.Other, f"x'{value.hex()}'") # https://dev.mysql.com/doc/refman/8.0/en/hexadecimal-literals.html + elif self._engine.url.get_backend_name() == "postgresql": + return sqlparse.sql.Token(sqlparse.tokens.Other, f"'\\x{value.hex()}'") # https://dba.stackexchange.com/a/203359 + else: + raise RuntimeError("unsupported value: {}".format(value)) + + # datetime.date + elif type(value) is datetime.date: + return sqlparse.sql.Token( + sqlparse.tokens.String, + sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%Y-%m-%d"))) + + # datetime.datetime + elif type(value) is datetime.datetime: + return sqlparse.sql.Token( + sqlparse.tokens.String, + sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%Y-%m-%d %H:%M:%S"))) + + # datetime.time + elif type(value) is datetime.time: + return sqlparse.sql.Token( + sqlparse.tokens.String, + sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%H:%M:%S"))) + + # float + elif type(value) is float: + return sqlparse.sql.Token( + sqlparse.tokens.Number, + sqlalchemy.types.Float().literal_processor(self._engine.dialect)(value)) + + # int + elif type(value) is int: + return sqlparse.sql.Token( + sqlparse.tokens.Number, + sqlalchemy.types.Integer().literal_processor(self._engine.dialect)(value)) + + # str + elif type(value) is str: + return sqlparse.sql.Token( + sqlparse.tokens.String, + sqlalchemy.types.String().literal_processor(self._engine.dialect)(value)) + + # None + elif value is None: + return sqlparse.sql.Token( + sqlparse.tokens.Keyword, + sqlalchemy.types.NullType().literal_processor(self._engine.dialect)(value)) + + # Unsupported value + else: + raise RuntimeError("unsupported value: {}".format(value)) + + # Escape value(s), separating with commas as needed + if type(value) in [list, tuple]: + return sqlparse.sql.TokenList(sqlparse.parse(", ".join([str(__escape(v)) for v in value]))) + else: + return __escape(value) + + +def _parse_exception(e): + """Parses an exception, returns its message.""" + + # Lazily import + import re + + # MySQL + matches = re.search(r"^\(_mysql_exceptions\.OperationalError\) \(\d+, \"(.+)\"\)$", str(e)) + if matches: + return matches.group(1) + + # PostgreSQL + matches = re.search(r"^\(psycopg2\.OperationalError\) (.+)$", str(e)) + if matches: + return matches.group(1) + + # SQLite + matches = re.search(r"^\(sqlite3\.OperationalError\) (.+)$", str(e)) + if matches: + return matches.group(1) + + # Default + return str(e) + + +def _parse_placeholder(token): + """Infers paramstyle, name from sqlparse.tokens.Name.Placeholder.""" + + # Lazily load + import re + import sqlparse + + # Validate token + if not isinstance(token, sqlparse.sql.Token) or token.ttype != sqlparse.tokens.Name.Placeholder: + raise TypeError() + + # qmark + if token.value == "?": + return "qmark", None + + # numeric + matches = re.search(r"^:([1-9]\d*)$", token.value) + if matches: + return "numeric", int(matches.group(1)) - 1 + + # named + matches = re.search(r"^:([a-zA-Z]\w*)$", token.value) + if matches: + return "named", matches.group(1) - @app.teardown_appcontext - def _(_): - self._engine.close_transaction_connection() + # format + if token.value == "%s": + return "format", None + # pyformat + matches = re.search(r"%\((\w+)\)s$", token.value) + if matches: + return "pyformat", matches.group(1) - logging.getLogger("cs50").disabled = False + # Invalid + raise RuntimeError("{}: invalid placeholder".format(token.value)) diff --git a/tests/flask/application.py b/tests/flask/application.py new file mode 100644 index 0000000..939a8f9 --- /dev/null +++ b/tests/flask/application.py @@ -0,0 +1,22 @@ +import requests +import sys +from flask import Flask, render_template + +sys.path.insert(0, "../../src") + +import cs50 +import cs50.flask + +app = Flask(__name__) + +db = cs50.SQL("sqlite:///../sqlite.db") + +@app.route("/") +def index(): + db.execute("SELECT 1") + """ + def f(): + res = requests.get("cs50.harvard.edu") + f() + """ + return render_template("index.html") diff --git a/tests/flask/requirements.txt b/tests/flask/requirements.txt new file mode 100644 index 0000000..7d0c101 --- /dev/null +++ b/tests/flask/requirements.txt @@ -0,0 +1,2 @@ +cs50 +Flask diff --git a/tests/flask/templates/error.html b/tests/flask/templates/error.html new file mode 100644 index 0000000..3302040 --- /dev/null +++ b/tests/flask/templates/error.html @@ -0,0 +1,10 @@ +<!DOCTYPE html> + +<html> + <head> + <title>error</title> + </head> + <body> + error + </body> +</html> diff --git a/tests/flask/templates/index.html b/tests/flask/templates/index.html new file mode 100644 index 0000000..2f6a145 --- /dev/null +++ b/tests/flask/templates/index.html @@ -0,0 +1,10 @@ +<!DOCTYPE html> + +<html> + <head> + <title>flask</title> + </head> + <body> + flask + </body> +</html> diff --git a/tests/foo.py b/tests/foo.py new file mode 100644 index 0000000..7f32a00 --- /dev/null +++ b/tests/foo.py @@ -0,0 +1,48 @@ +import logging +import sys + +sys.path.insert(0, "../src") + +import cs50 + +""" +db = cs50.SQL("sqlite:///foo.db") + +logging.getLogger("cs50").disabled = False + +#db.execute("SELECT ? FROM ? ORDER BY ?", "a", "tbl", "c") +db.execute("CREATE TABLE IF NOT EXISTS bar (firstname STRING)") + +db.execute("INSERT INTO bar VALUES (?)", "baz") +db.execute("INSERT INTO bar VALUES (?)", "qux") +db.execute("SELECT * FROM bar WHERE firstname IN (?)", ("baz", "qux")) +db.execute("DELETE FROM bar") +""" + +db = cs50.SQL("postgresql://postgres@localhost/test") + +""" +print(db.execute("DROP TABLE IF EXISTS cs50")) +print(db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)")) +print(db.execute("INSERT INTO cs50 (val) VALUES('foo')")) +print(db.execute("SELECT * FROM cs50")) + +print(db.execute("DROP TABLE IF EXISTS cs50")) +print(db.execute("CREATE TABLE cs50 (val VARCHAR(16), bin BYTEA)")) +print(db.execute("INSERT INTO cs50 (val) VALUES('foo')")) +print(db.execute("SELECT * FROM cs50")) +""" + +print(db.execute("DROP TABLE IF EXISTS cs50")) +print(db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)")) +print(db.execute("INSERT INTO cs50 (val) VALUES('foo')")) +print(db.execute("INSERT INTO cs50 (val) VALUES('bar')")) +print(db.execute("INSERT INTO cs50 (val) VALUES('baz')")) +print(db.execute("SELECT * FROM cs50")) +try: + print(db.execute("INSERT INTO cs50 (id, val) VALUES(1, 'bar')")) +except Exception as e: + print(e) + pass +print(db.execute("INSERT INTO cs50 (val) VALUES('qux')")) +#print(db.execute("DELETE FROM cs50")) diff --git a/tests/mysql.py b/tests/mysql.py new file mode 100644 index 0000000..2a431c3 --- /dev/null +++ b/tests/mysql.py @@ -0,0 +1,8 @@ +import sys + +sys.path.insert(0, "../src") + +from cs50 import SQL + +db = SQL("mysql://root@localhost/test") +db.execute("SELECT 1") diff --git a/tests/python.py b/tests/python.py new file mode 100644 index 0000000..6a265cb --- /dev/null +++ b/tests/python.py @@ -0,0 +1,8 @@ +import sys + +sys.path.insert(0, "../src") + +import cs50 + +i = cs50.get_int("Input: ") +print(f"Output: {i}") diff --git a/tests/redirect/application.py b/tests/redirect/application.py new file mode 100644 index 0000000..6aff187 --- /dev/null +++ b/tests/redirect/application.py @@ -0,0 +1,12 @@ +import cs50 +from flask import Flask, redirect, render_template + +app = Flask(__name__) + +@app.route("/") +def index(): + return redirect("/foo") + +@app.route("/foo") +def foo(): + return render_template("foo.html") diff --git a/tests/redirect/templates/foo.html b/tests/redirect/templates/foo.html new file mode 100644 index 0000000..257cc56 --- /dev/null +++ b/tests/redirect/templates/foo.html @@ -0,0 +1 @@ +foo diff --git a/tests/sql.py b/tests/sql.py index 7caaf63..e4757c7 100644 --- a/tests/sql.py +++ b/tests/sql.py @@ -9,6 +9,7 @@ class SQLTests(unittest.TestCase): + def test_multiple_statements(self): self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO cs50(val) VALUES('baz'); INSERT INTO cs50(val) VALUES('qux')") @@ -27,7 +28,6 @@ def test_delete_returns_affected_rows(self): def test_insert_returns_last_row_id(self): self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('foo')"), 1) self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('bar')"), 2) - self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('qux')"), 3) def test_select_all(self): self.assertEqual(self.db.execute("SELECT * FROM cs50"), []) @@ -132,13 +132,64 @@ def test_rollback(self): def test_identifier_case(self): self.assertIn("count", self.db.execute("SELECT 1 AS count")[0]) - def test_none(self): - self.db.execute("CREATE TABLE foo (val INTEGER)") - self.db.execute("SELECT * FROM foo WHERE val = ?", None) + def tearDown(self): + self.db.execute("DROP TABLE cs50") + self.db.execute("DROP TABLE IF EXISTS foo") + self.db.execute("DROP TABLE IF EXISTS bar") + + @classmethod + def tearDownClass(self): + try: + self.db.execute("DROP TABLE IF EXISTS cs50") + except Warning as e: + # suppress "unknown table" + if not str(e).startswith("(1051"): + raise e + + +class MySQLTests(SQLTests): + @classmethod + def setUpClass(self): + self.db = SQL("mysql://root@localhost/test") + + def setUp(self): + self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), bin BLOB, PRIMARY KEY (id))") + self.db.execute("DELETE FROM cs50") + + +class PostgresTests(SQLTests): + @classmethod + def setUpClass(self): + self.db = SQL("postgresql://postgres@localhost/test") + + def setUp(self): + self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)") + self.db.execute("DELETE FROM cs50") + + def test_cte(self): + self.assertEqual(self.db.execute("WITH foo AS ( SELECT 1 AS bar ) SELECT bar FROM foo"), [{"bar": 1}]) + + +class SQLiteTests(SQLTests): + + @classmethod + def setUpClass(self): + open("test.db", "w").close() + self.db = SQL("sqlite:///test.db") + + def setUp(self): + self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id INTEGER PRIMARY KEY, val TEXT, bin BLOB)") + self.db.execute("DELETE FROM cs50") + + def test_lastrowid(self): + self.db.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY AUTOINCREMENT, firstname TEXT, lastname TEXT)") + self.assertEqual(self.db.execute("INSERT INTO foo (firstname, lastname) VALUES('firstname', 'lastname')"), 1) + self.assertRaises(ValueError, self.db.execute, "INSERT INTO foo (id, firstname, lastname) VALUES(1, 'firstname', 'lastname')") + self.assertEqual(self.db.execute("INSERT OR IGNORE INTO foo (id, firstname, lastname) VALUES(1, 'firstname', 'lastname')"), None) def test_integrity_constraints(self): self.db.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY)") - self.db.execute("INSERT INTO foo VALUES(1)") + self.assertEqual(self.db.execute("INSERT INTO foo VALUES(1)"), 1) self.assertRaises(ValueError, self.db.execute, "INSERT INTO foo VALUES(1)") def test_foreign_key_support(self): @@ -147,7 +198,7 @@ def test_foreign_key_support(self): self.assertRaises(ValueError, self.db.execute, "INSERT INTO bar VALUES(50)") def test_qmark(self): - self.db.execute("CREATE TABLE foo (firstname VARCHAR(255), lastname VARCHAR(255))") + self.db.execute("CREATE TABLE foo (firstname STRING, lastname STRING)") self.db.execute("INSERT INTO foo VALUES (?, 'bar')", "baz") self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "baz", "lastname": "bar"}]) @@ -177,7 +228,7 @@ def test_qmark(self): self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) self.db.execute("DELETE FROM foo") - self.db.execute("CREATE TABLE bar (firstname VARCHAR(255))") + self.db.execute("CREATE TABLE bar (firstname STRING)") self.db.execute("INSERT INTO bar VALUES (?)", "baz") self.assertEqual(self.db.execute("SELECT * FROM bar"), [{"firstname": "baz"}]) @@ -201,7 +252,7 @@ def test_qmark(self): self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (?, ?)", 'bar', baz='baz') def test_named(self): - self.db.execute("CREATE TABLE foo (firstname VARCHAR(255), lastname VARCHAR(255))") + self.db.execute("CREATE TABLE foo (firstname STRING, lastname STRING)") self.db.execute("INSERT INTO foo VALUES (:baz, 'bar')", baz="baz") self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "baz", "lastname": "bar"}]) @@ -223,11 +274,7 @@ def test_named(self): self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) self.db.execute("DELETE FROM foo") - self.db.execute("INSERT INTO foo VALUES (:baz, :baz)", baz="baz") - self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "baz", "lastname": "baz"}]) - self.db.execute("DELETE FROM foo") - - self.db.execute("CREATE TABLE bar (firstname VARCHAR(255))") + self.db.execute("CREATE TABLE bar (firstname STRING)") self.db.execute("INSERT INTO bar VALUES (:baz)", baz="baz") self.assertEqual(self.db.execute("SELECT * FROM bar"), [{"firstname": "baz"}]) @@ -236,8 +283,9 @@ def test_named(self): self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (:bar, :baz)", bar='bar', baz='baz', qux='qux') self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (:bar, :baz)", 'baz', bar='bar') + def test_numeric(self): - self.db.execute("CREATE TABLE foo (firstname VARCHAR(255), lastname VARCHAR(255))") + self.db.execute("CREATE TABLE foo (firstname STRING, lastname STRING)") self.db.execute("INSERT INTO foo VALUES (:1, 'bar')", "baz") self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "baz", "lastname": "bar"}]) @@ -259,7 +307,7 @@ def test_numeric(self): self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) self.db.execute("DELETE FROM foo") - self.db.execute("CREATE TABLE bar (firstname VARCHAR(255))") + self.db.execute("CREATE TABLE bar (firstname STRING)") self.db.execute("INSERT INTO bar VALUES (:1)", "baz") self.assertEqual(self.db.execute("SELECT * FROM bar"), [{"firstname": "baz"}]) @@ -271,51 +319,6 @@ def test_numeric(self): def test_cte(self): self.assertEqual(self.db.execute("WITH foo AS ( SELECT 1 AS bar ) SELECT bar FROM foo"), [{"bar": 1}]) - def tearDown(self): - self.db.execute("DROP TABLE IF EXISTS cs50") - self.db.execute("DROP TABLE IF EXISTS bar") - self.db.execute("DROP TABLE IF EXISTS foo") - -class MySQLTests(SQLTests): - @classmethod - def setUpClass(self): - self.db = SQL("mysql://root@127.0.0.1/test") - - def setUp(self): - self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), bin BLOB, PRIMARY KEY (id))") - self.db.execute("DELETE FROM cs50") - -class PostgresTests(SQLTests): - @classmethod - def setUpClass(self): - self.db = SQL("postgresql://postgres:postgres@127.0.0.1/test") - - def setUp(self): - self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)") - self.db.execute("DELETE FROM cs50") - - def test_cte(self): - self.assertEqual(self.db.execute("WITH foo AS ( SELECT 1 AS bar ) SELECT bar FROM foo"), [{"bar": 1}]) - - def test_postgres_scheme(self): - db = SQL("postgres://postgres:postgres@127.0.0.1/test") - db.execute("SELECT 1") - -class SQLiteTests(SQLTests): - @classmethod - def setUpClass(self): - open("test.db", "w").close() - self.db = SQL("sqlite:///test.db") - - def setUp(self): - self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id INTEGER PRIMARY KEY, val TEXT, bin BLOB)") - self.db.execute("DELETE FROM cs50") - - def test_lastrowid(self): - self.db.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY AUTOINCREMENT, firstname TEXT, lastname TEXT)") - self.assertEqual(self.db.execute("INSERT INTO foo (firstname, lastname) VALUES('firstname', 'lastname')"), 1) - self.assertRaises(ValueError, self.db.execute, "INSERT INTO foo (id, firstname, lastname) VALUES(1, 'firstname', 'lastname')") - self.assertEqual(self.db.execute("INSERT OR IGNORE INTO foo (id, firstname, lastname) VALUES(1, 'firstname', 'lastname')"), None) if __name__ == "__main__": suite = unittest.TestSuite([ diff --git a/tests/sqlite.py b/tests/sqlite.py new file mode 100644 index 0000000..05c2cea --- /dev/null +++ b/tests/sqlite.py @@ -0,0 +1,44 @@ +import logging +import sys + +sys.path.insert(0, "../src") + +from cs50 import SQL + +logging.getLogger("cs50").disabled = False + +db = SQL("sqlite:///sqlite.db") +db.execute("SELECT 1") + +# TODO +#db.execute("SELECT * FROM Employee WHERE FirstName = ?", b'\x00') + +db.execute("SELECT * FROM Employee WHERE FirstName = ?", "' OR 1 = 1") + +db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", "Andrew") +db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", ["Andrew"]) +db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", ("Andrew",)) +db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", ["Andrew", "Nancy"]) +db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", ("Andrew", "Nancy")) +db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", []) +db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", ()) + +db.execute("SELECT * FROM Employee WHERE FirstName = ? AND LastName = ?", "Andrew", "Adams") +db.execute("SELECT * FROM Employee WHERE FirstName = ? AND LastName = ?", ["Andrew", "Adams"]) +db.execute("SELECT * FROM Employee WHERE FirstName = ? AND LastName = ?", ("Andrew", "Adams")) + +db.execute("SELECT * FROM Employee WHERE FirstName = :1 AND LastName = :2", "Andrew", "Adams") +db.execute("SELECT * FROM Employee WHERE FirstName = :1 AND LastName = :2", ["Andrew", "Adams"]) +db.execute("SELECT * FROM Employee WHERE FirstName = :1 AND LastName = :2", ("Andrew", "Adams")) + +db.execute("SELECT * FROM Employee WHERE FirstName = ':Andrew :Adams'") + +db.execute("SELECT * FROM Employee WHERE FirstName = :first AND LastName = :last", first="Andrew", last="Adams") +db.execute("SELECT * FROM Employee WHERE FirstName = :first AND LastName = :last", {"first": "Andrew", "last": "Adams"}) + +db.execute("SELECT * FROM Employee WHERE FirstName = %s AND LastName = %s", "Andrew", "Adams") +db.execute("SELECT * FROM Employee WHERE FirstName = %s AND LastName = %s", ["Andrew", "Adams"]) +db.execute("SELECT * FROM Employee WHERE FirstName = %s AND LastName = %s", ("Andrew", "Adams")) + +db.execute("SELECT * FROM Employee WHERE FirstName = %(first)s AND LastName = %(last)s", first="Andrew", last="Adams") +db.execute("SELECT * FROM Employee WHERE FirstName = %(first)s AND LastName = %(last)s", {"first": "Andrew", "last": "Adams"}) diff --git a/tests/tb.py b/tests/tb.py new file mode 100644 index 0000000..3ad8175 --- /dev/null +++ b/tests/tb.py @@ -0,0 +1,10 @@ +import sys + +sys.path.insert(0, "../src") + +import cs50 +import requests + +def f(): + res = requests.get("cs50.harvard.edu") +f() diff --git a/tests/test_cs50.py b/tests/test_cs50.py deleted file mode 100644 index 9a0faca..0000000 --- a/tests/test_cs50.py +++ /dev/null @@ -1,141 +0,0 @@ -import sys -import unittest - -from unittest.mock import patch - -from cs50.cs50 import get_string, _get_int, _get_float - - -class TestCS50(unittest.TestCase): - @patch("cs50.cs50._get_input", return_value="") - def test_get_string_empty_input(self, mock_get_input): - """Returns empty string when input is empty""" - self.assertEqual(get_string("Answer: "), "") - mock_get_input.assert_called_with("Answer: ") - - @patch("cs50.cs50._get_input", return_value="test") - def test_get_string_nonempty_input(self, mock_get_input): - """Returns the provided non-empty input""" - self.assertEqual(get_string("Answer: "), "test") - mock_get_input.assert_called_with("Answer: ") - - @patch("cs50.cs50._get_input", side_effect=EOFError) - def test_get_string_eof(self, mock_get_input): - """Returns None on EOF""" - self.assertIs(get_string("Answer: "), None) - mock_get_input.assert_called_with("Answer: ") - - def test_get_string_invalid_prompt(self): - """Raises TypeError when prompt is not str""" - with self.assertRaises(TypeError): - get_string(1) - - @patch("cs50.cs50.get_string", return_value=None) - def test_get_int_eof(self, mock_get_string): - """Returns None on EOF""" - self.assertIs(_get_int("Answer: "), None) - mock_get_string.assert_called_with("Answer: ") - - def test_get_int_valid_input(self): - """Returns the provided integer input""" - - def assert_equal(return_value, expected_value): - with patch("cs50.cs50.get_string", return_value=return_value) as mock_get_string: - self.assertEqual(_get_int("Answer: "), expected_value) - mock_get_string.assert_called_with("Answer: ") - - values = [ - ("0", 0), - ("50", 50), - ("+50", 50), - ("+42", 42), - ("-42", -42), - ("42", 42), - ] - - for return_value, expected_value in values: - assert_equal(return_value, expected_value) - - def test_get_int_invalid_input(self): - """Raises ValueError when input is invalid base-10 int""" - - def assert_raises_valueerror(return_value): - with patch("cs50.cs50.get_string", return_value=return_value) as mock_get_string: - with self.assertRaises(ValueError): - _get_int("Answer: ") - - mock_get_string.assert_called_with("Answer: ") - - return_values = [ - "++50", - "--50", - "50+", - "50-", - " 50", - " +50", - " -50", - "50 ", - "ab50", - "50ab", - "ab50ab", - ] - - for return_value in return_values: - assert_raises_valueerror(return_value) - - @patch("cs50.cs50.get_string", return_value=None) - def test_get_float_eof(self, mock_get_string): - """Returns None on EOF""" - self.assertIs(_get_float("Answer: "), None) - mock_get_string.assert_called_with("Answer: ") - - def test_get_float_valid_input(self): - """Returns the provided integer input""" - def assert_equal(return_value, expected_value): - with patch("cs50.cs50.get_string", return_value=return_value) as mock_get_string: - f = _get_float("Answer: ") - self.assertAlmostEqual(f, expected_value) - mock_get_string.assert_called_with("Answer: ") - - values = [ - (".0", 0.0), - ("0.", 0.0), - (".42", 0.42), - ("42.", 42.0), - ("50", 50.0), - ("+50", 50.0), - ("-50", -50.0), - ("+3.14", 3.14), - ("-3.14", -3.14), - ] - - for return_value, expected_value in values: - assert_equal(return_value, expected_value) - - def test_get_float_invalid_input(self): - """Raises ValueError when input is invalid float""" - - def assert_raises_valueerror(return_value): - with patch("cs50.cs50.get_string", return_value=return_value) as mock_get_string: - with self.assertRaises(ValueError): - _get_float("Answer: ") - - mock_get_string.assert_called_with("Answer: ") - - return_values = [ - ".", - "..5", - "a.5", - ".5a" - "0.5a", - "a0.42", - " .42", - "3.14 ", - "++3.14", - "3.14+", - "--3.14", - "3.14--", - ] - - for return_value in return_values: - assert_raises_valueerror(return_value) From 31136282158f17e58a5ea840a19fc2cc94bee553 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Wed, 8 Dec 2021 18:25:31 -0500 Subject: [PATCH 100/159] adding teardown_appcontext --- src/cs50/sql.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 2188c6e..74078aa 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -290,7 +290,7 @@ def execute(self, sql, *args, **kwargs): # Use this connection connection = getattr(_data, self._name()) - """TODO + # Disconnect if/when a Flask app is torn down try: import flask assert flask.current_app @@ -300,7 +300,6 @@ def teardown_appcontext(exception): flask.current_app.teardown_appcontext(teardown_appcontext) except (ModuleNotFoundError, AssertionError): pass - """ # Catch SQLAlchemy warnings with warnings.catch_warnings(): From 4c6108dbc45046725e3d61ec4d1b2b6ad2eea544 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Wed, 8 Dec 2021 18:33:15 -0500 Subject: [PATCH 101/159] updated tests --- tests/sql.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/sql.py b/tests/sql.py index e4757c7..e02502d 100644 --- a/tests/sql.py +++ b/tests/sql.py @@ -151,6 +151,7 @@ class MySQLTests(SQLTests): @classmethod def setUpClass(self): self.db = SQL("mysql://root@localhost/test") + self.db = SQL("mysql://root@127.0.0.1/test") def setUp(self): self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), bin BLOB, PRIMARY KEY (id))") @@ -160,7 +161,7 @@ def setUp(self): class PostgresTests(SQLTests): @classmethod def setUpClass(self): - self.db = SQL("postgresql://postgres@localhost/test") + self.db = SQL("postgresql://postgres:postgres@127.0.0.1/test") def setUp(self): self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)") @@ -169,6 +170,10 @@ def setUp(self): def test_cte(self): self.assertEqual(self.db.execute("WITH foo AS ( SELECT 1 AS bar ) SELECT bar FROM foo"), [{"bar": 1}]) + def test_postgres_scheme(self): + db = SQL("postgres://postgres:postgres@127.0.0.1/test") + db.execute("SELECT 1") + class SQLiteTests(SQLTests): From e1ed9579f008b7345094fc2480d23b7fff681397 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Wed, 8 Dec 2021 18:35:14 -0500 Subject: [PATCH 102/159] removed old MySQL test --- tests/sql.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/sql.py b/tests/sql.py index e02502d..4fd4c5f 100644 --- a/tests/sql.py +++ b/tests/sql.py @@ -150,7 +150,6 @@ def tearDownClass(self): class MySQLTests(SQLTests): @classmethod def setUpClass(self): - self.db = SQL("mysql://root@localhost/test") self.db = SQL("mysql://root@127.0.0.1/test") def setUp(self): From 7240955409661eb0dad76e44e306f609466a1d38 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Wed, 8 Dec 2021 18:37:56 -0500 Subject: [PATCH 103/159] removed postgres, but not postgresql, support --- setup.py | 2 +- tests/sql.py | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/setup.py b/setup.py index faf7abd..af297a0 100644 --- a/setup.py +++ b/setup.py @@ -16,5 +16,5 @@ package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="7.1.0" + version="8.0.0" ) diff --git a/tests/sql.py b/tests/sql.py index 4fd4c5f..89853a7 100644 --- a/tests/sql.py +++ b/tests/sql.py @@ -169,10 +169,6 @@ def setUp(self): def test_cte(self): self.assertEqual(self.db.execute("WITH foo AS ( SELECT 1 AS bar ) SELECT bar FROM foo"), [{"bar": 1}]) - def test_postgres_scheme(self): - db = SQL("postgres://postgres:postgres@127.0.0.1/test") - db.execute("SELECT 1") - class SQLiteTests(SQLTests): From d827119a10882c9ee588e2f17c14cdaff577475f Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Wed, 8 Dec 2021 19:01:50 -0500 Subject: [PATCH 104/159] removed other postgres mention --- src/cs50/sql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 74078aa..a9759e7 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -353,7 +353,7 @@ def teardown_appcontext(exception): # If INSERT, return primary key value for a newly inserted row (or None if none) elif command == "INSERT": - if self._engine.url.get_backend_name() in ["postgres", "postgresql"]: + if self._engine.url.get_backend_name() == "postgresql": try: result = connection.execute("SELECT LASTVAL()") ret = result.first()[0] From cb031402a16edd1bb19cbd905d8b648427694dac Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Wed, 8 Dec 2021 19:12:42 -0500 Subject: [PATCH 105/159] updated tests --- README.md | 2 +- tests/sql.py | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 3d6eed8..c5830b2 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ s = cs50.get_string(); 1. In `/etc/profile.d/cli.sh`, remove `valgrind` function for now. 1. Run `service mysql start`. 1. Run `mysql -e 'CREATE DATABASE IF NOT EXISTS test;'`. -1. In `/etc/postgresql/10/main/pg_hba.conf, change: +1. In `/etc/postgresql/12/main/pg_hba.conf, change: ``` local all postgres peer host all all 127.0.0.1/32 md5 diff --git a/tests/sql.py b/tests/sql.py index 89853a7..ff61b64 100644 --- a/tests/sql.py +++ b/tests/sql.py @@ -132,6 +132,11 @@ def test_rollback(self): def test_identifier_case(self): self.assertIn("count", self.db.execute("SELECT 1 AS count")[0]) + def test_lastrowid(self): + self.db.execute("CREATE TABLE foo(id SERIAL PRIMARY KEY, firstname TEXT, lastname TEXT)") + self.assertEqual(self.db.execute("INSERT INTO foo (firstname, lastname) VALUES('firstname', 'lastname')"), 1) + self.assertRaises(ValueError, self.db.execute, "INSERT INTO foo (id, firstname, lastname) VALUES(1, 'firstname', 'lastname')") + def tearDown(self): self.db.execute("DROP TABLE cs50") self.db.execute("DROP TABLE IF EXISTS foo") @@ -166,6 +171,7 @@ def setUp(self): self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)") self.db.execute("DELETE FROM cs50") + def test_cte(self): self.assertEqual(self.db.execute("WITH foo AS ( SELECT 1 AS bar ) SELECT bar FROM foo"), [{"bar": 1}]) @@ -323,7 +329,7 @@ def test_cte(self): if __name__ == "__main__": suite = unittest.TestSuite([ unittest.TestLoader().loadTestsFromTestCase(SQLiteTests), - unittest.TestLoader().loadTestsFromTestCase(MySQLTests), + #unittest.TestLoader().loadTestsFromTestCase(MySQLTests), unittest.TestLoader().loadTestsFromTestCase(PostgresTests) ]) From d9a631530acea069f8671700b6f4efe2b00794bb Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Tue, 14 Dec 2021 09:23:32 -0500 Subject: [PATCH 106/159] added wheel, fixes #162 --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index af297a0..5814b49 100644 --- a/setup.py +++ b/setup.py @@ -10,11 +10,11 @@ "Topic :: Software Development :: Libraries :: Python Modules" ], description="CS50 library for Python", - install_requires=["Flask>=1.0", "SQLAlchemy", "sqlparse", "termcolor"], + install_requires=["Flask>=1.0", "SQLAlchemy", "sqlparse", "termcolor", "wheel"], keywords="cs50", name="cs50", package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="8.0.0" + version="8.0.1" ) From 8c08044d1d479baaa9aa68e8f791cd72da3871b6 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Sat, 5 Feb 2022 22:29:22 -0500 Subject: [PATCH 107/159] updated README --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index c5830b2..f415d33 100644 --- a/README.md +++ b/README.md @@ -24,13 +24,14 @@ s = cs50.get_string(); 1. Run `cli50` in `python-cs50`. 1. Run `sudo su -`. +1. Run `apt update`. 1. Run `apt install -y libmysqlclient-dev mysql-server postgresql`. 1. Run `pip3 install mysqlclient psycopg2-binary`. 1. In `/etc/mysql/mysql.conf.d/mysqld.cnf`, add `skip-grant-tables` under `[mysqld]`. 1. In `/etc/profile.d/cli.sh`, remove `valgrind` function for now. 1. Run `service mysql start`. 1. Run `mysql -e 'CREATE DATABASE IF NOT EXISTS test;'`. -1. In `/etc/postgresql/12/main/pg_hba.conf, change: +1. In `/etc/postgresql/12/main/pg_hba.conf`, change: ``` local all postgres peer host all all 127.0.0.1/32 md5 From 9342365395541bf3d816a8ae6c17e0a5e90cc156 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Sun, 6 Feb 2022 12:19:40 -0500 Subject: [PATCH 108/159] Fixes #128 by catch PostgreSQL exceptions that otherwise roll back transactions --- Dockerfile | 5 +++++ docker-compose.yml | 37 +++++++++++++++++++++++++++++++++++++ setup.py | 2 +- src/cs50/sql.py | 38 +++++++++++++++++++++++++++----------- 4 files changed, 70 insertions(+), 12 deletions(-) create mode 100644 Dockerfile create mode 100644 docker-compose.yml diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..2c6b969 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,5 @@ +FROM cs50/cli + +RUN sudo apt update && sudo apt install --yes libmysqlclient-dev +RUN sudo pip3 install mysqlclient psycopg2-binary +WORKDIR /mnt diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..98eb59c --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,37 @@ +services: + cli: + build: . + container_name: python-cs50 + depends_on: + - mysql + - postgres + environment: + MYSQL_DATABASE: code50 + MYSQL_HOST: mysql + MYSQL_PASSWORD: crimson + MYSQL_USERNAME: root + OAUTHLIB_INSECURE_TRANSPORT: 1 + links: + - mysql + - postgres + tty: true + volumes: + - .:/mnt + mysql: + environment: + MYSQL_DATABASE: test + MYSQL_ALLOW_EMPTY_PASSWORD: yes + healthcheck: + test: ["CMD", "mysqladmin", "-uroot", "ping"] + image: cs50/mysql:8 + ports: + - 3306:3306 + postgres: + image: postgres + environment: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: test + ports: + - 5432:5432 +version: "3.6" diff --git a/setup.py b/setup.py index 5814b49..59e5cf2 100644 --- a/setup.py +++ b/setup.py @@ -16,5 +16,5 @@ package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="8.0.1" + version="8.0.2" ) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index a9759e7..b0aa94e 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -61,8 +61,10 @@ def __init__(self, url, **kwargs): if not os.path.isfile(matches.group(1)): raise RuntimeError("not a file: {}".format(matches.group(1))) - # Create engine, disabling SQLAlchemy's own autocommit mode, raising exception if back end's module not installed - self._engine = sqlalchemy.create_engine(url, **kwargs).execution_options(autocommit=False) + # Create engine, disabling SQLAlchemy's own autocommit mode raising exception if back end's module not installed; + # without isolation_level, PostgreSQL warns with "there is already a transaction in progress" for our own BEGIN and + # "there is no transaction in progress" for our own COMMIT + self._engine = sqlalchemy.create_engine(url, **kwargs).execution_options(autocommit=False, isolation_level="AUTOCOMMIT") # Get logger self._logger = logging.getLogger("cs50") @@ -70,10 +72,6 @@ def __init__(self, url, **kwargs): # Listener for connections def connect(dbapi_connection, connection_record): - # Disable underlying API's own emitting of BEGIN and COMMIT so we can ourselves - # https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl - dbapi_connection.isolation_level = None - # Enable foreign key constraints if type(dbapi_connection) is sqlite3.Connection: # If back end is sqlite cursor = dbapi_connection.cursor() @@ -353,12 +351,30 @@ def teardown_appcontext(exception): # If INSERT, return primary key value for a newly inserted row (or None if none) elif command == "INSERT": + + # If PostgreSQL if self._engine.url.get_backend_name() == "postgresql": - try: - result = connection.execute("SELECT LASTVAL()") - ret = result.first()[0] - except sqlalchemy.exc.OperationalError: # If lastval is not yet defined for this connection - ret = None + + # Return LASTVAL() or NULL, avoiding + # "(psycopg2.errors.ObjectNotInPrerequisiteState) lastval is not yet defined in this session", + # a la https://stackoverflow.com/a/24186770/5156190; + # cf. https://www.psycopg.org/docs/errors.html re 55000 + result = connection.execute(""" + CREATE OR REPLACE FUNCTION _LASTVAL() + RETURNS integer LANGUAGE plpgsql + AS $$ + BEGIN + BEGIN + RETURN (SELECT LASTVAL()); + EXCEPTION + WHEN SQLSTATE '55000' THEN RETURN NULL; + END; + END $$; + SELECT _LASTVAL(); + """) + ret = result.first()[0] + + # If not PostgreSQL else: ret = result.lastrowid if result.rowcount == 1 else None From b813f9f71ca2252713b25efd3069ae96c3435cb4 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Sun, 6 Feb 2022 13:21:49 -0500 Subject: [PATCH 109/159] updated tests, Docker, README --- Dockerfile | 3 ++- README.md | 40 +++++++++++++++++++--------------------- docker-compose.yml | 2 +- tests/sql.py | 4 ++-- 4 files changed, 24 insertions(+), 25 deletions(-) diff --git a/Dockerfile b/Dockerfile index 2c6b969..ccc4552 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,6 @@ FROM cs50/cli -RUN sudo apt update && sudo apt install --yes libmysqlclient-dev +RUN sudo apt update && sudo apt install --yes libmysqlclient-dev pgloader postgresql RUN sudo pip3 install mysqlclient psycopg2-binary + WORKDIR /mnt diff --git a/README.md b/README.md index f415d33..cf2c62d 100644 --- a/README.md +++ b/README.md @@ -22,27 +22,25 @@ s = cs50.get_string(); ## Testing -1. Run `cli50` in `python-cs50`. -1. Run `sudo su -`. -1. Run `apt update`. -1. Run `apt install -y libmysqlclient-dev mysql-server postgresql`. -1. Run `pip3 install mysqlclient psycopg2-binary`. -1. In `/etc/mysql/mysql.conf.d/mysqld.cnf`, add `skip-grant-tables` under `[mysqld]`. -1. In `/etc/profile.d/cli.sh`, remove `valgrind` function for now. -1. Run `service mysql start`. -1. Run `mysql -e 'CREATE DATABASE IF NOT EXISTS test;'`. -1. In `/etc/postgresql/12/main/pg_hba.conf`, change: - ``` - local all postgres peer - host all all 127.0.0.1/32 md5 - ``` - to: - ``` - local all postgres trust - host all all 127.0.0.1/32 trust - ``` -1. Run `service postgresql start`. -1. Run `psql -c 'create database test;' -U postgres`. +1. In one terminal, execute: + + ``` + cd python-cs50 + docker compose build + docker compose up + ``` + +1. In another terminal, execute: + + ``` + docker exec -it python-cs50 bash -l + ``` + + And then execute, e.g.: + + ``` + python tests/sql.py + ``` ### Sample Tests diff --git a/docker-compose.yml b/docker-compose.yml index 98eb59c..ce92de3 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -27,7 +27,7 @@ services: ports: - 3306:3306 postgres: - image: postgres + image: postgres:12 environment: POSTGRES_USER: postgres POSTGRES_PASSWORD: postgres diff --git a/tests/sql.py b/tests/sql.py index ff61b64..0f2f3a3 100644 --- a/tests/sql.py +++ b/tests/sql.py @@ -155,7 +155,7 @@ def tearDownClass(self): class MySQLTests(SQLTests): @classmethod def setUpClass(self): - self.db = SQL("mysql://root@127.0.0.1/test") + self.db = SQL("mysql://root@mysql/test") def setUp(self): self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), bin BLOB, PRIMARY KEY (id))") @@ -165,7 +165,7 @@ def setUp(self): class PostgresTests(SQLTests): @classmethod def setUpClass(self): - self.db = SQL("postgresql://postgres:postgres@127.0.0.1/test") + self.db = SQL("postgresql://postgres:postgres@postgres/test") def setUp(self): self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)") From be80c55a5fcc974ac0a61668ffb96061a80327ad Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Sun, 6 Feb 2022 13:33:03 -0500 Subject: [PATCH 110/159] added env var for MySQL, PostgreSQL hosts --- .github/workflows/main.yml | 3 +++ docker-compose.yml | 5 +---- tests/sql.py | 5 +++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 30d894b..e32f995 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -30,6 +30,9 @@ jobs: pip install mysqlclient psycopg2-binary - name: Run tests run: python tests/sql.py + env: + MYSQL_HOST: 127.0.0.1 + POSTGRESQL_HOST: 127.0.0.1 - name: Install pypa/build run: python -m pip install build --user - name: Build a binary wheel and a source tarball diff --git a/docker-compose.yml b/docker-compose.yml index ce92de3..f795750 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -6,11 +6,8 @@ services: - mysql - postgres environment: - MYSQL_DATABASE: code50 MYSQL_HOST: mysql - MYSQL_PASSWORD: crimson - MYSQL_USERNAME: root - OAUTHLIB_INSECURE_TRANSPORT: 1 + POSTGRESQL_HOST: postgresql links: - mysql - postgres diff --git a/tests/sql.py b/tests/sql.py index 0f2f3a3..968f98b 100644 --- a/tests/sql.py +++ b/tests/sql.py @@ -1,4 +1,5 @@ import logging +import os import sys import unittest import warnings @@ -155,7 +156,7 @@ def tearDownClass(self): class MySQLTests(SQLTests): @classmethod def setUpClass(self): - self.db = SQL("mysql://root@mysql/test") + self.db = SQL(f"mysql://root@{os.getenv('MYSQL_HOST')}/test") def setUp(self): self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), bin BLOB, PRIMARY KEY (id))") @@ -165,7 +166,7 @@ def setUp(self): class PostgresTests(SQLTests): @classmethod def setUpClass(self): - self.db = SQL("postgresql://postgres:postgres@postgres/test") + self.db = SQL(f"postgresql://postgres:postgres@{os.getenv('POSTGRESQL_HOST')}/test") def setUp(self): self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)") From c908273af448d016ffeebf197fd59bbb01394131 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Sun, 6 Feb 2022 14:50:49 -0500 Subject: [PATCH 111/159] Update LICENSE --- LICENSE | 675 +++++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 671 insertions(+), 4 deletions(-) diff --git a/LICENSE b/LICENSE index 61ab664..f288702 100644 --- a/LICENSE +++ b/LICENSE @@ -1,7 +1,674 @@ -Copyright 2012-2018 CS50 + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 -Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/> + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. -The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + Preamble -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + The GNU General Public License is a free, copyleft license for +software and other kinds of works. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +the GNU General Public License is intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. We, the Free Software Foundation, use the +GNU General Public License for most of our software; it applies also to +any other work released this way by its authors. You can apply it to +your programs, too. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + To protect your rights, we need to prevent others from denying you +these rights or asking you to surrender the rights. Therefore, you have +certain responsibilities if you distribute copies of the software, or if +you modify it: responsibilities to respect the freedom of others. + + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must pass on to the recipients the same +freedoms that you received. You must make sure that they, too, receive +or can get the source code. And you must show them these terms so they +know their rights. + + Developers that use the GNU GPL protect your rights with two steps: +(1) assert copyright on the software, and (2) offer you this License +giving you legal permission to copy, distribute and/or modify it. + + For the developers' and authors' protection, the GPL clearly explains +that there is no warranty for this free software. For both users' and +authors' sake, the GPL requires that modified versions be marked as +changed, so that their problems will not be attributed erroneously to +authors of previous versions. + + Some devices are designed to deny users access to install or run +modified versions of the software inside them, although the manufacturer +can do so. This is fundamentally incompatible with the aim of +protecting users' freedom to change the software. The systematic +pattern of such abuse occurs in the area of products for individuals to +use, which is precisely where it is most unacceptable. Therefore, we +have designed this version of the GPL to prohibit the practice for those +products. If such problems arise substantially in other domains, we +stand ready to extend this provision to those domains in future versions +of the GPL, as needed to protect the freedom of users. + + Finally, every program is threatened constantly by software patents. +States should not allow patents to restrict development and use of +software on general-purpose computers, but in those that do, we wish to +avoid the special danger that patents applied to a free program could +make it effectively proprietary. To prevent this, the GPL assures that +patents cannot be used to render the program non-free. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Use with the GNU Affero General Public License. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + <one line to give the program's name and a brief idea of what it does.> + Copyright (C) <year> <name of author> + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see <https://www.gnu.org/licenses/>. + +Also add information on how to contact you by electronic and paper mail. + + If the program does terminal interaction, make it output a short +notice like this when it starts in an interactive mode: + + <program> Copyright (C) <year> <name of author> + This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the appropriate +parts of the General Public License. Of course, your program's commands +might be different; for a GUI interface, you would use an "about box". + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU GPL, see +<https://www.gnu.org/licenses/>. + + The GNU General Public License does not permit incorporating your program +into proprietary programs. If your program is a subroutine library, you +may consider it more useful to permit linking proprietary applications with +the library. If this is what you want to do, use the GNU Lesser General +Public License instead of this License. But first, please read +<https://www.gnu.org/licenses/why-not-lgpl.html>. From 351fcb37ce8510f436c846262dd7cc3d7d4250a6 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Sun, 6 Feb 2022 14:52:00 -0500 Subject: [PATCH 112/159] Revert "Update LICENSE" This reverts commit c908273af448d016ffeebf197fd59bbb01394131. --- LICENSE | 675 +------------------------------------------------------- 1 file changed, 4 insertions(+), 671 deletions(-) diff --git a/LICENSE b/LICENSE index f288702..61ab664 100644 --- a/LICENSE +++ b/LICENSE @@ -1,674 +1,7 @@ - GNU GENERAL PUBLIC LICENSE - Version 3, 29 June 2007 +Copyright 2012-2018 CS50 - Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/> - Everyone is permitted to copy and distribute verbatim copies - of this license document, but changing it is not allowed. +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - Preamble +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. - The GNU General Public License is a free, copyleft license for -software and other kinds of works. - - The licenses for most software and other practical works are designed -to take away your freedom to share and change the works. By contrast, -the GNU General Public License is intended to guarantee your freedom to -share and change all versions of a program--to make sure it remains free -software for all its users. We, the Free Software Foundation, use the -GNU General Public License for most of our software; it applies also to -any other work released this way by its authors. You can apply it to -your programs, too. - - When we speak of free software, we are referring to freedom, not -price. Our General Public Licenses are designed to make sure that you -have the freedom to distribute copies of free software (and charge for -them if you wish), that you receive source code or can get it if you -want it, that you can change the software or use pieces of it in new -free programs, and that you know you can do these things. - - To protect your rights, we need to prevent others from denying you -these rights or asking you to surrender the rights. Therefore, you have -certain responsibilities if you distribute copies of the software, or if -you modify it: responsibilities to respect the freedom of others. - - For example, if you distribute copies of such a program, whether -gratis or for a fee, you must pass on to the recipients the same -freedoms that you received. You must make sure that they, too, receive -or can get the source code. And you must show them these terms so they -know their rights. - - Developers that use the GNU GPL protect your rights with two steps: -(1) assert copyright on the software, and (2) offer you this License -giving you legal permission to copy, distribute and/or modify it. - - For the developers' and authors' protection, the GPL clearly explains -that there is no warranty for this free software. For both users' and -authors' sake, the GPL requires that modified versions be marked as -changed, so that their problems will not be attributed erroneously to -authors of previous versions. - - Some devices are designed to deny users access to install or run -modified versions of the software inside them, although the manufacturer -can do so. This is fundamentally incompatible with the aim of -protecting users' freedom to change the software. The systematic -pattern of such abuse occurs in the area of products for individuals to -use, which is precisely where it is most unacceptable. Therefore, we -have designed this version of the GPL to prohibit the practice for those -products. If such problems arise substantially in other domains, we -stand ready to extend this provision to those domains in future versions -of the GPL, as needed to protect the freedom of users. - - Finally, every program is threatened constantly by software patents. -States should not allow patents to restrict development and use of -software on general-purpose computers, but in those that do, we wish to -avoid the special danger that patents applied to a free program could -make it effectively proprietary. To prevent this, the GPL assures that -patents cannot be used to render the program non-free. - - The precise terms and conditions for copying, distribution and -modification follow. - - TERMS AND CONDITIONS - - 0. Definitions. - - "This License" refers to version 3 of the GNU General Public License. - - "Copyright" also means copyright-like laws that apply to other kinds of -works, such as semiconductor masks. - - "The Program" refers to any copyrightable work licensed under this -License. Each licensee is addressed as "you". "Licensees" and -"recipients" may be individuals or organizations. - - To "modify" a work means to copy from or adapt all or part of the work -in a fashion requiring copyright permission, other than the making of an -exact copy. The resulting work is called a "modified version" of the -earlier work or a work "based on" the earlier work. - - A "covered work" means either the unmodified Program or a work based -on the Program. - - To "propagate" a work means to do anything with it that, without -permission, would make you directly or secondarily liable for -infringement under applicable copyright law, except executing it on a -computer or modifying a private copy. Propagation includes copying, -distribution (with or without modification), making available to the -public, and in some countries other activities as well. - - To "convey" a work means any kind of propagation that enables other -parties to make or receive copies. Mere interaction with a user through -a computer network, with no transfer of a copy, is not conveying. - - An interactive user interface displays "Appropriate Legal Notices" -to the extent that it includes a convenient and prominently visible -feature that (1) displays an appropriate copyright notice, and (2) -tells the user that there is no warranty for the work (except to the -extent that warranties are provided), that licensees may convey the -work under this License, and how to view a copy of this License. If -the interface presents a list of user commands or options, such as a -menu, a prominent item in the list meets this criterion. - - 1. Source Code. - - The "source code" for a work means the preferred form of the work -for making modifications to it. "Object code" means any non-source -form of a work. - - A "Standard Interface" means an interface that either is an official -standard defined by a recognized standards body, or, in the case of -interfaces specified for a particular programming language, one that -is widely used among developers working in that language. - - The "System Libraries" of an executable work include anything, other -than the work as a whole, that (a) is included in the normal form of -packaging a Major Component, but which is not part of that Major -Component, and (b) serves only to enable use of the work with that -Major Component, or to implement a Standard Interface for which an -implementation is available to the public in source code form. A -"Major Component", in this context, means a major essential component -(kernel, window system, and so on) of the specific operating system -(if any) on which the executable work runs, or a compiler used to -produce the work, or an object code interpreter used to run it. - - The "Corresponding Source" for a work in object code form means all -the source code needed to generate, install, and (for an executable -work) run the object code and to modify the work, including scripts to -control those activities. However, it does not include the work's -System Libraries, or general-purpose tools or generally available free -programs which are used unmodified in performing those activities but -which are not part of the work. For example, Corresponding Source -includes interface definition files associated with source files for -the work, and the source code for shared libraries and dynamically -linked subprograms that the work is specifically designed to require, -such as by intimate data communication or control flow between those -subprograms and other parts of the work. - - The Corresponding Source need not include anything that users -can regenerate automatically from other parts of the Corresponding -Source. - - The Corresponding Source for a work in source code form is that -same work. - - 2. Basic Permissions. - - All rights granted under this License are granted for the term of -copyright on the Program, and are irrevocable provided the stated -conditions are met. This License explicitly affirms your unlimited -permission to run the unmodified Program. The output from running a -covered work is covered by this License only if the output, given its -content, constitutes a covered work. This License acknowledges your -rights of fair use or other equivalent, as provided by copyright law. - - You may make, run and propagate covered works that you do not -convey, without conditions so long as your license otherwise remains -in force. You may convey covered works to others for the sole purpose -of having them make modifications exclusively for you, or provide you -with facilities for running those works, provided that you comply with -the terms of this License in conveying all material for which you do -not control copyright. Those thus making or running the covered works -for you must do so exclusively on your behalf, under your direction -and control, on terms that prohibit them from making any copies of -your copyrighted material outside their relationship with you. - - Conveying under any other circumstances is permitted solely under -the conditions stated below. Sublicensing is not allowed; section 10 -makes it unnecessary. - - 3. Protecting Users' Legal Rights From Anti-Circumvention Law. - - No covered work shall be deemed part of an effective technological -measure under any applicable law fulfilling obligations under article -11 of the WIPO copyright treaty adopted on 20 December 1996, or -similar laws prohibiting or restricting circumvention of such -measures. - - When you convey a covered work, you waive any legal power to forbid -circumvention of technological measures to the extent such circumvention -is effected by exercising rights under this License with respect to -the covered work, and you disclaim any intention to limit operation or -modification of the work as a means of enforcing, against the work's -users, your or third parties' legal rights to forbid circumvention of -technological measures. - - 4. Conveying Verbatim Copies. - - You may convey verbatim copies of the Program's source code as you -receive it, in any medium, provided that you conspicuously and -appropriately publish on each copy an appropriate copyright notice; -keep intact all notices stating that this License and any -non-permissive terms added in accord with section 7 apply to the code; -keep intact all notices of the absence of any warranty; and give all -recipients a copy of this License along with the Program. - - You may charge any price or no price for each copy that you convey, -and you may offer support or warranty protection for a fee. - - 5. Conveying Modified Source Versions. - - You may convey a work based on the Program, or the modifications to -produce it from the Program, in the form of source code under the -terms of section 4, provided that you also meet all of these conditions: - - a) The work must carry prominent notices stating that you modified - it, and giving a relevant date. - - b) The work must carry prominent notices stating that it is - released under this License and any conditions added under section - 7. This requirement modifies the requirement in section 4 to - "keep intact all notices". - - c) You must license the entire work, as a whole, under this - License to anyone who comes into possession of a copy. This - License will therefore apply, along with any applicable section 7 - additional terms, to the whole of the work, and all its parts, - regardless of how they are packaged. This License gives no - permission to license the work in any other way, but it does not - invalidate such permission if you have separately received it. - - d) If the work has interactive user interfaces, each must display - Appropriate Legal Notices; however, if the Program has interactive - interfaces that do not display Appropriate Legal Notices, your - work need not make them do so. - - A compilation of a covered work with other separate and independent -works, which are not by their nature extensions of the covered work, -and which are not combined with it such as to form a larger program, -in or on a volume of a storage or distribution medium, is called an -"aggregate" if the compilation and its resulting copyright are not -used to limit the access or legal rights of the compilation's users -beyond what the individual works permit. Inclusion of a covered work -in an aggregate does not cause this License to apply to the other -parts of the aggregate. - - 6. Conveying Non-Source Forms. - - You may convey a covered work in object code form under the terms -of sections 4 and 5, provided that you also convey the -machine-readable Corresponding Source under the terms of this License, -in one of these ways: - - a) Convey the object code in, or embodied in, a physical product - (including a physical distribution medium), accompanied by the - Corresponding Source fixed on a durable physical medium - customarily used for software interchange. - - b) Convey the object code in, or embodied in, a physical product - (including a physical distribution medium), accompanied by a - written offer, valid for at least three years and valid for as - long as you offer spare parts or customer support for that product - model, to give anyone who possesses the object code either (1) a - copy of the Corresponding Source for all the software in the - product that is covered by this License, on a durable physical - medium customarily used for software interchange, for a price no - more than your reasonable cost of physically performing this - conveying of source, or (2) access to copy the - Corresponding Source from a network server at no charge. - - c) Convey individual copies of the object code with a copy of the - written offer to provide the Corresponding Source. This - alternative is allowed only occasionally and noncommercially, and - only if you received the object code with such an offer, in accord - with subsection 6b. - - d) Convey the object code by offering access from a designated - place (gratis or for a charge), and offer equivalent access to the - Corresponding Source in the same way through the same place at no - further charge. You need not require recipients to copy the - Corresponding Source along with the object code. If the place to - copy the object code is a network server, the Corresponding Source - may be on a different server (operated by you or a third party) - that supports equivalent copying facilities, provided you maintain - clear directions next to the object code saying where to find the - Corresponding Source. Regardless of what server hosts the - Corresponding Source, you remain obligated to ensure that it is - available for as long as needed to satisfy these requirements. - - e) Convey the object code using peer-to-peer transmission, provided - you inform other peers where the object code and Corresponding - Source of the work are being offered to the general public at no - charge under subsection 6d. - - A separable portion of the object code, whose source code is excluded -from the Corresponding Source as a System Library, need not be -included in conveying the object code work. - - A "User Product" is either (1) a "consumer product", which means any -tangible personal property which is normally used for personal, family, -or household purposes, or (2) anything designed or sold for incorporation -into a dwelling. In determining whether a product is a consumer product, -doubtful cases shall be resolved in favor of coverage. For a particular -product received by a particular user, "normally used" refers to a -typical or common use of that class of product, regardless of the status -of the particular user or of the way in which the particular user -actually uses, or expects or is expected to use, the product. A product -is a consumer product regardless of whether the product has substantial -commercial, industrial or non-consumer uses, unless such uses represent -the only significant mode of use of the product. - - "Installation Information" for a User Product means any methods, -procedures, authorization keys, or other information required to install -and execute modified versions of a covered work in that User Product from -a modified version of its Corresponding Source. The information must -suffice to ensure that the continued functioning of the modified object -code is in no case prevented or interfered with solely because -modification has been made. - - If you convey an object code work under this section in, or with, or -specifically for use in, a User Product, and the conveying occurs as -part of a transaction in which the right of possession and use of the -User Product is transferred to the recipient in perpetuity or for a -fixed term (regardless of how the transaction is characterized), the -Corresponding Source conveyed under this section must be accompanied -by the Installation Information. But this requirement does not apply -if neither you nor any third party retains the ability to install -modified object code on the User Product (for example, the work has -been installed in ROM). - - The requirement to provide Installation Information does not include a -requirement to continue to provide support service, warranty, or updates -for a work that has been modified or installed by the recipient, or for -the User Product in which it has been modified or installed. Access to a -network may be denied when the modification itself materially and -adversely affects the operation of the network or violates the rules and -protocols for communication across the network. - - Corresponding Source conveyed, and Installation Information provided, -in accord with this section must be in a format that is publicly -documented (and with an implementation available to the public in -source code form), and must require no special password or key for -unpacking, reading or copying. - - 7. Additional Terms. - - "Additional permissions" are terms that supplement the terms of this -License by making exceptions from one or more of its conditions. -Additional permissions that are applicable to the entire Program shall -be treated as though they were included in this License, to the extent -that they are valid under applicable law. If additional permissions -apply only to part of the Program, that part may be used separately -under those permissions, but the entire Program remains governed by -this License without regard to the additional permissions. - - When you convey a copy of a covered work, you may at your option -remove any additional permissions from that copy, or from any part of -it. (Additional permissions may be written to require their own -removal in certain cases when you modify the work.) You may place -additional permissions on material, added by you to a covered work, -for which you have or can give appropriate copyright permission. - - Notwithstanding any other provision of this License, for material you -add to a covered work, you may (if authorized by the copyright holders of -that material) supplement the terms of this License with terms: - - a) Disclaiming warranty or limiting liability differently from the - terms of sections 15 and 16 of this License; or - - b) Requiring preservation of specified reasonable legal notices or - author attributions in that material or in the Appropriate Legal - Notices displayed by works containing it; or - - c) Prohibiting misrepresentation of the origin of that material, or - requiring that modified versions of such material be marked in - reasonable ways as different from the original version; or - - d) Limiting the use for publicity purposes of names of licensors or - authors of the material; or - - e) Declining to grant rights under trademark law for use of some - trade names, trademarks, or service marks; or - - f) Requiring indemnification of licensors and authors of that - material by anyone who conveys the material (or modified versions of - it) with contractual assumptions of liability to the recipient, for - any liability that these contractual assumptions directly impose on - those licensors and authors. - - All other non-permissive additional terms are considered "further -restrictions" within the meaning of section 10. If the Program as you -received it, or any part of it, contains a notice stating that it is -governed by this License along with a term that is a further -restriction, you may remove that term. If a license document contains -a further restriction but permits relicensing or conveying under this -License, you may add to a covered work material governed by the terms -of that license document, provided that the further restriction does -not survive such relicensing or conveying. - - If you add terms to a covered work in accord with this section, you -must place, in the relevant source files, a statement of the -additional terms that apply to those files, or a notice indicating -where to find the applicable terms. - - Additional terms, permissive or non-permissive, may be stated in the -form of a separately written license, or stated as exceptions; -the above requirements apply either way. - - 8. Termination. - - You may not propagate or modify a covered work except as expressly -provided under this License. Any attempt otherwise to propagate or -modify it is void, and will automatically terminate your rights under -this License (including any patent licenses granted under the third -paragraph of section 11). - - However, if you cease all violation of this License, then your -license from a particular copyright holder is reinstated (a) -provisionally, unless and until the copyright holder explicitly and -finally terminates your license, and (b) permanently, if the copyright -holder fails to notify you of the violation by some reasonable means -prior to 60 days after the cessation. - - Moreover, your license from a particular copyright holder is -reinstated permanently if the copyright holder notifies you of the -violation by some reasonable means, this is the first time you have -received notice of violation of this License (for any work) from that -copyright holder, and you cure the violation prior to 30 days after -your receipt of the notice. - - Termination of your rights under this section does not terminate the -licenses of parties who have received copies or rights from you under -this License. If your rights have been terminated and not permanently -reinstated, you do not qualify to receive new licenses for the same -material under section 10. - - 9. Acceptance Not Required for Having Copies. - - You are not required to accept this License in order to receive or -run a copy of the Program. Ancillary propagation of a covered work -occurring solely as a consequence of using peer-to-peer transmission -to receive a copy likewise does not require acceptance. However, -nothing other than this License grants you permission to propagate or -modify any covered work. These actions infringe copyright if you do -not accept this License. Therefore, by modifying or propagating a -covered work, you indicate your acceptance of this License to do so. - - 10. Automatic Licensing of Downstream Recipients. - - Each time you convey a covered work, the recipient automatically -receives a license from the original licensors, to run, modify and -propagate that work, subject to this License. You are not responsible -for enforcing compliance by third parties with this License. - - An "entity transaction" is a transaction transferring control of an -organization, or substantially all assets of one, or subdividing an -organization, or merging organizations. If propagation of a covered -work results from an entity transaction, each party to that -transaction who receives a copy of the work also receives whatever -licenses to the work the party's predecessor in interest had or could -give under the previous paragraph, plus a right to possession of the -Corresponding Source of the work from the predecessor in interest, if -the predecessor has it or can get it with reasonable efforts. - - You may not impose any further restrictions on the exercise of the -rights granted or affirmed under this License. For example, you may -not impose a license fee, royalty, or other charge for exercise of -rights granted under this License, and you may not initiate litigation -(including a cross-claim or counterclaim in a lawsuit) alleging that -any patent claim is infringed by making, using, selling, offering for -sale, or importing the Program or any portion of it. - - 11. Patents. - - A "contributor" is a copyright holder who authorizes use under this -License of the Program or a work on which the Program is based. The -work thus licensed is called the contributor's "contributor version". - - A contributor's "essential patent claims" are all patent claims -owned or controlled by the contributor, whether already acquired or -hereafter acquired, that would be infringed by some manner, permitted -by this License, of making, using, or selling its contributor version, -but do not include claims that would be infringed only as a -consequence of further modification of the contributor version. For -purposes of this definition, "control" includes the right to grant -patent sublicenses in a manner consistent with the requirements of -this License. - - Each contributor grants you a non-exclusive, worldwide, royalty-free -patent license under the contributor's essential patent claims, to -make, use, sell, offer for sale, import and otherwise run, modify and -propagate the contents of its contributor version. - - In the following three paragraphs, a "patent license" is any express -agreement or commitment, however denominated, not to enforce a patent -(such as an express permission to practice a patent or covenant not to -sue for patent infringement). To "grant" such a patent license to a -party means to make such an agreement or commitment not to enforce a -patent against the party. - - If you convey a covered work, knowingly relying on a patent license, -and the Corresponding Source of the work is not available for anyone -to copy, free of charge and under the terms of this License, through a -publicly available network server or other readily accessible means, -then you must either (1) cause the Corresponding Source to be so -available, or (2) arrange to deprive yourself of the benefit of the -patent license for this particular work, or (3) arrange, in a manner -consistent with the requirements of this License, to extend the patent -license to downstream recipients. "Knowingly relying" means you have -actual knowledge that, but for the patent license, your conveying the -covered work in a country, or your recipient's use of the covered work -in a country, would infringe one or more identifiable patents in that -country that you have reason to believe are valid. - - If, pursuant to or in connection with a single transaction or -arrangement, you convey, or propagate by procuring conveyance of, a -covered work, and grant a patent license to some of the parties -receiving the covered work authorizing them to use, propagate, modify -or convey a specific copy of the covered work, then the patent license -you grant is automatically extended to all recipients of the covered -work and works based on it. - - A patent license is "discriminatory" if it does not include within -the scope of its coverage, prohibits the exercise of, or is -conditioned on the non-exercise of one or more of the rights that are -specifically granted under this License. You may not convey a covered -work if you are a party to an arrangement with a third party that is -in the business of distributing software, under which you make payment -to the third party based on the extent of your activity of conveying -the work, and under which the third party grants, to any of the -parties who would receive the covered work from you, a discriminatory -patent license (a) in connection with copies of the covered work -conveyed by you (or copies made from those copies), or (b) primarily -for and in connection with specific products or compilations that -contain the covered work, unless you entered into that arrangement, -or that patent license was granted, prior to 28 March 2007. - - Nothing in this License shall be construed as excluding or limiting -any implied license or other defenses to infringement that may -otherwise be available to you under applicable patent law. - - 12. No Surrender of Others' Freedom. - - If conditions are imposed on you (whether by court order, agreement or -otherwise) that contradict the conditions of this License, they do not -excuse you from the conditions of this License. If you cannot convey a -covered work so as to satisfy simultaneously your obligations under this -License and any other pertinent obligations, then as a consequence you may -not convey it at all. For example, if you agree to terms that obligate you -to collect a royalty for further conveying from those to whom you convey -the Program, the only way you could satisfy both those terms and this -License would be to refrain entirely from conveying the Program. - - 13. Use with the GNU Affero General Public License. - - Notwithstanding any other provision of this License, you have -permission to link or combine any covered work with a work licensed -under version 3 of the GNU Affero General Public License into a single -combined work, and to convey the resulting work. The terms of this -License will continue to apply to the part which is the covered work, -but the special requirements of the GNU Affero General Public License, -section 13, concerning interaction through a network will apply to the -combination as such. - - 14. Revised Versions of this License. - - The Free Software Foundation may publish revised and/or new versions of -the GNU General Public License from time to time. Such new versions will -be similar in spirit to the present version, but may differ in detail to -address new problems or concerns. - - Each version is given a distinguishing version number. If the -Program specifies that a certain numbered version of the GNU General -Public License "or any later version" applies to it, you have the -option of following the terms and conditions either of that numbered -version or of any later version published by the Free Software -Foundation. If the Program does not specify a version number of the -GNU General Public License, you may choose any version ever published -by the Free Software Foundation. - - If the Program specifies that a proxy can decide which future -versions of the GNU General Public License can be used, that proxy's -public statement of acceptance of a version permanently authorizes you -to choose that version for the Program. - - Later license versions may give you additional or different -permissions. However, no additional obligations are imposed on any -author or copyright holder as a result of your choosing to follow a -later version. - - 15. Disclaimer of Warranty. - - THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY -APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT -HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY -OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, -THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM -IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF -ALL NECESSARY SERVICING, REPAIR OR CORRECTION. - - 16. Limitation of Liability. - - IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING -WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS -THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY -GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE -USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF -DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD -PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), -EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF -SUCH DAMAGES. - - 17. Interpretation of Sections 15 and 16. - - If the disclaimer of warranty and limitation of liability provided -above cannot be given local legal effect according to their terms, -reviewing courts shall apply local law that most closely approximates -an absolute waiver of all civil liability in connection with the -Program, unless a warranty or assumption of liability accompanies a -copy of the Program in return for a fee. - - END OF TERMS AND CONDITIONS - - How to Apply These Terms to Your New Programs - - If you develop a new program, and you want it to be of the greatest -possible use to the public, the best way to achieve this is to make it -free software which everyone can redistribute and change under these terms. - - To do so, attach the following notices to the program. It is safest -to attach them to the start of each source file to most effectively -state the exclusion of warranty; and each file should have at least -the "copyright" line and a pointer to where the full notice is found. - - <one line to give the program's name and a brief idea of what it does.> - Copyright (C) <year> <name of author> - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see <https://www.gnu.org/licenses/>. - -Also add information on how to contact you by electronic and paper mail. - - If the program does terminal interaction, make it output a short -notice like this when it starts in an interactive mode: - - <program> Copyright (C) <year> <name of author> - This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. - This is free software, and you are welcome to redistribute it - under certain conditions; type `show c' for details. - -The hypothetical commands `show w' and `show c' should show the appropriate -parts of the General Public License. Of course, your program's commands -might be different; for a GUI interface, you would use an "about box". - - You should also get your employer (if you work as a programmer) or school, -if any, to sign a "copyright disclaimer" for the program, if necessary. -For more information on this, and how to apply and follow the GNU GPL, see -<https://www.gnu.org/licenses/>. - - The GNU General Public License does not permit incorporating your program -into proprietary programs. If your program is a subroutine library, you -may consider it more useful to permit linking proprietary applications with -the library. If this is what you want to do, use the GNU Lesser General -Public License instead of this License. But first, please read -<https://www.gnu.org/licenses/why-not-lgpl.html>. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. From 90fab66516386c355726e5ef776237a038e89a2d Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Sun, 6 Feb 2022 14:52:55 -0500 Subject: [PATCH 113/159] Update LICENSE --- LICENSE | 675 ++++++++++++++++++++++++++++++++++++++++++++++++++++++- setup.py | 2 +- 2 files changed, 672 insertions(+), 5 deletions(-) diff --git a/LICENSE b/LICENSE index 61ab664..f288702 100644 --- a/LICENSE +++ b/LICENSE @@ -1,7 +1,674 @@ -Copyright 2012-2018 CS50 + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 -Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/> + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. -The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + Preamble -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + The GNU General Public License is a free, copyleft license for +software and other kinds of works. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +the GNU General Public License is intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. We, the Free Software Foundation, use the +GNU General Public License for most of our software; it applies also to +any other work released this way by its authors. You can apply it to +your programs, too. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + To protect your rights, we need to prevent others from denying you +these rights or asking you to surrender the rights. Therefore, you have +certain responsibilities if you distribute copies of the software, or if +you modify it: responsibilities to respect the freedom of others. + + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must pass on to the recipients the same +freedoms that you received. You must make sure that they, too, receive +or can get the source code. And you must show them these terms so they +know their rights. + + Developers that use the GNU GPL protect your rights with two steps: +(1) assert copyright on the software, and (2) offer you this License +giving you legal permission to copy, distribute and/or modify it. + + For the developers' and authors' protection, the GPL clearly explains +that there is no warranty for this free software. For both users' and +authors' sake, the GPL requires that modified versions be marked as +changed, so that their problems will not be attributed erroneously to +authors of previous versions. + + Some devices are designed to deny users access to install or run +modified versions of the software inside them, although the manufacturer +can do so. This is fundamentally incompatible with the aim of +protecting users' freedom to change the software. The systematic +pattern of such abuse occurs in the area of products for individuals to +use, which is precisely where it is most unacceptable. Therefore, we +have designed this version of the GPL to prohibit the practice for those +products. If such problems arise substantially in other domains, we +stand ready to extend this provision to those domains in future versions +of the GPL, as needed to protect the freedom of users. + + Finally, every program is threatened constantly by software patents. +States should not allow patents to restrict development and use of +software on general-purpose computers, but in those that do, we wish to +avoid the special danger that patents applied to a free program could +make it effectively proprietary. To prevent this, the GPL assures that +patents cannot be used to render the program non-free. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Use with the GNU Affero General Public License. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + <one line to give the program's name and a brief idea of what it does.> + Copyright (C) <year> <name of author> + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see <https://www.gnu.org/licenses/>. + +Also add information on how to contact you by electronic and paper mail. + + If the program does terminal interaction, make it output a short +notice like this when it starts in an interactive mode: + + <program> Copyright (C) <year> <name of author> + This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the appropriate +parts of the General Public License. Of course, your program's commands +might be different; for a GUI interface, you would use an "about box". + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU GPL, see +<https://www.gnu.org/licenses/>. + + The GNU General Public License does not permit incorporating your program +into proprietary programs. If your program is a subroutine library, you +may consider it more useful to permit linking proprietary applications with +the library. If this is what you want to do, use the GNU Lesser General +Public License instead of this License. But first, please read +<https://www.gnu.org/licenses/why-not-lgpl.html>. diff --git a/setup.py b/setup.py index 59e5cf2..d1c4f9b 100644 --- a/setup.py +++ b/setup.py @@ -16,5 +16,5 @@ package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="8.0.2" + version="9.0.0" ) From d06017f86bd544f31bd40fc3b6f946d3c7af7c65 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Sun, 5 Jun 2022 21:12:03 -0400 Subject: [PATCH 114/159] using INFO, WARNING, and ERROR instead of DEBUG --- setup.py | 2 +- src/cs50/sql.py | 6 +++--- tests/foo.py | 11 ++++++----- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/setup.py b/setup.py index d1c4f9b..c363ff1 100644 --- a/setup.py +++ b/setup.py @@ -16,5 +16,5 @@ package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="9.0.0" + version="10.0.0" ) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index b0aa94e..8f0a1be 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -384,7 +384,7 @@ def teardown_appcontext(exception): # If constraint violated, return None except sqlalchemy.exc.IntegrityError as e: - self._logger.debug(termcolor.colored(statement, "yellow")) + self._logger.warning(termcolor.colored(statement, "yellow")) e = ValueError(e.orig) e.__cause__ = None raise e @@ -392,14 +392,14 @@ def teardown_appcontext(exception): # If user error except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as e: self._disconnect() - self._logger.debug(termcolor.colored(statement, "red")) + self._logger.error(termcolor.colored(statement, "red")) e = RuntimeError(e.orig) e.__cause__ = None raise e # Return value else: - self._logger.debug(termcolor.colored(_statement, "green")) + self._logger.info(termcolor.colored(_statement, "green")) if self._autocommit: # Don't stay connected unnecessarily self._disconnect() return ret diff --git a/tests/foo.py b/tests/foo.py index 7f32a00..f3955fc 100644 --- a/tests/foo.py +++ b/tests/foo.py @@ -5,23 +5,23 @@ import cs50 -""" db = cs50.SQL("sqlite:///foo.db") logging.getLogger("cs50").disabled = False +logging.getLogger("cs50").setLevel(logging.ERROR) -#db.execute("SELECT ? FROM ? ORDER BY ?", "a", "tbl", "c") -db.execute("CREATE TABLE IF NOT EXISTS bar (firstname STRING)") +db.execute("CREATE TABLE IF NOT EXISTS bar (firstname STRING UNIQUE)") +db.execute("INSERT INTO bar VALUES (?)", "baz") db.execute("INSERT INTO bar VALUES (?)", "baz") db.execute("INSERT INTO bar VALUES (?)", "qux") db.execute("SELECT * FROM bar WHERE firstname IN (?)", ("baz", "qux")) db.execute("DELETE FROM bar") + """ db = cs50.SQL("postgresql://postgres@localhost/test") -""" print(db.execute("DROP TABLE IF EXISTS cs50")) print(db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)")) print(db.execute("INSERT INTO cs50 (val) VALUES('foo')")) @@ -31,7 +31,6 @@ print(db.execute("CREATE TABLE cs50 (val VARCHAR(16), bin BYTEA)")) print(db.execute("INSERT INTO cs50 (val) VALUES('foo')")) print(db.execute("SELECT * FROM cs50")) -""" print(db.execute("DROP TABLE IF EXISTS cs50")) print(db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)")) @@ -46,3 +45,5 @@ pass print(db.execute("INSERT INTO cs50 (val) VALUES('qux')")) #print(db.execute("DELETE FROM cs50")) + +""" From afb51bc6163e9e82653e1aa251cd0e46afc10515 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Mon, 6 Jun 2022 09:18:25 -0400 Subject: [PATCH 115/159] changing minor version instead --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index c363ff1..0eb27bf 100644 --- a/setup.py +++ b/setup.py @@ -16,5 +16,5 @@ package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="10.0.0" + version="9.1.0" ) From 6a5bee036c16be15576c815c1918651f00fc8cca Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Mon, 6 Jun 2022 09:21:27 -0400 Subject: [PATCH 116/159] changed warning to error --- src/cs50/sql.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 8f0a1be..cd631fc 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -382,9 +382,9 @@ def teardown_appcontext(exception): elif command in ["DELETE", "UPDATE"]: ret = result.rowcount - # If constraint violated, return None + # If constraint violated except sqlalchemy.exc.IntegrityError as e: - self._logger.warning(termcolor.colored(statement, "yellow")) + self._logger.error(termcolor.colored(statement, "red")) e = ValueError(e.orig) e.__cause__ = None raise e From 42482f9f9290501f45a6ff7552895d4fdef24e16 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Mon, 6 Jun 2022 09:31:01 -0400 Subject: [PATCH 117/159] logging bytes-abbreviated _statement instead --- src/cs50/sql.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index cd631fc..c6cdc4b 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -384,7 +384,7 @@ def teardown_appcontext(exception): # If constraint violated except sqlalchemy.exc.IntegrityError as e: - self._logger.error(termcolor.colored(statement, "red")) + self._logger.error(termcolor.colored(_statement, "red")) e = ValueError(e.orig) e.__cause__ = None raise e @@ -392,7 +392,7 @@ def teardown_appcontext(exception): # If user error except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as e: self._disconnect() - self._logger.error(termcolor.colored(statement, "red")) + self._logger.error(termcolor.colored(_statement, "red")) e = RuntimeError(e.orig) e.__cause__ = None raise e From a43b565b6ae9199d062336edbc670fbb7b184fbc Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Mon, 6 Jun 2022 21:34:00 -0400 Subject: [PATCH 118/159] scoping INFO and ERROR to FLASK_ENV=development --- setup.py | 2 +- src/cs50/sql.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 0eb27bf..765213a 100644 --- a/setup.py +++ b/setup.py @@ -16,5 +16,5 @@ package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="9.1.0" + version="9.2.0" ) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index c6cdc4b..35c180e 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -9,6 +9,7 @@ def _enable_logging(f): import logging import functools + import os @functools.wraps(f) def decorator(*args, **kwargs): @@ -19,9 +20,9 @@ def decorator(*args, **kwargs): except ModuleNotFoundError: return f(*args, **kwargs) - # Enable logging + # Enable logging in development mode disabled = logging.getLogger("cs50").disabled - if flask.current_app: + if flask.current_app and os.getenv("FLASK_ENV") == "development": logging.getLogger("cs50").disabled = False try: return f(*args, **kwargs) From a6ceac6f47e5c77fc814b35ebcf66e1de9ce9cc8 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Tue, 14 Jun 2022 20:18:26 -0400 Subject: [PATCH 119/159] fixed support for None as NULL --- setup.py | 2 +- src/cs50/sql.py | 2 +- tests/foo.py | 8 +++++--- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/setup.py b/setup.py index 765213a..f0edaaf 100644 --- a/setup.py +++ b/setup.py @@ -16,5 +16,5 @@ package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="9.2.0" + version="9.2.1" ) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 35c180e..f008b6f 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -476,7 +476,7 @@ def __escape(value): elif value is None: return sqlparse.sql.Token( sqlparse.tokens.Keyword, - sqlalchemy.types.NullType().literal_processor(self._engine.dialect)(value)) + sqlalchemy.null()) # Unsupported value else: diff --git a/tests/foo.py b/tests/foo.py index f3955fc..2cf74e9 100644 --- a/tests/foo.py +++ b/tests/foo.py @@ -10,13 +10,15 @@ logging.getLogger("cs50").disabled = False logging.getLogger("cs50").setLevel(logging.ERROR) -db.execute("CREATE TABLE IF NOT EXISTS bar (firstname STRING UNIQUE)") +db.execute("DROP TABLE IF EXISTS bar") +db.execute("CREATE TABLE bar (firstname STRING UNIQUE)") -db.execute("INSERT INTO bar VALUES (?)", "baz") +db.execute("INSERT INTO bar VALUES (?)", None) db.execute("INSERT INTO bar VALUES (?)", "baz") db.execute("INSERT INTO bar VALUES (?)", "qux") db.execute("SELECT * FROM bar WHERE firstname IN (?)", ("baz", "qux")) -db.execute("DELETE FROM bar") +print(db.execute("SELECT * FROM bar")) +#db.execute("DELETE FROM bar") """ From 0039f7e2ba76f561ac5eaf93d4f63d160c660464 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Tue, 14 Jun 2022 20:26:16 -0400 Subject: [PATCH 120/159] fixed license --- setup.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index f0edaaf..0c3f650 100644 --- a/setup.py +++ b/setup.py @@ -12,9 +12,11 @@ description="CS50 library for Python", install_requires=["Flask>=1.0", "SQLAlchemy", "sqlparse", "termcolor", "wheel"], keywords="cs50", + license="GPLv3", + long_description_content_type="text/markdown", name="cs50", package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="9.2.1" + version="9.2.2" ) From 7448bc073000c1e3a9f0cd1e690eb45e4ad9537e Mon Sep 17 00:00:00 2001 From: Rongxin Liu <rongxinliu.dev@gmail.com> Date: Sat, 16 Jul 2022 15:50:55 -0400 Subject: [PATCH 121/159] added github release automation to workflow --- .github/workflows/main.yml | 40 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index e32f995..b0f91a9 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -28,18 +28,58 @@ jobs: run: | pip install . pip install mysqlclient psycopg2-binary + - name: Run tests run: python tests/sql.py env: MYSQL_HOST: 127.0.0.1 POSTGRESQL_HOST: 127.0.0.1 + - name: Install pypa/build run: python -m pip install build --user + - name: Build a binary wheel and a source tarball run: python -m build --sdist --wheel --outdir dist/ . + - name: Deploy to PyPI if: ${{ github.ref == 'refs/heads/main' }} uses: pypa/gh-action-pypi-publish@release/v1 with: user: __token__ password: ${{ secrets.PYPI_API_TOKEN }} + + - name: Get Version + id: py_version + run: | + echo ::set-output name=version::$(python3 setup.py --version) + + - name: Create Tag + uses: actions/github-script@v6 + with: + github-token: ${{ github.token }} + script: | + try { + await github.rest.git.updateRef({ + owner: context.repo.owner, + repo: context.repo.repo, + ref: "tags/v${{ steps.py_version.outputs.version }}", + sha: context.sha, + force: true + }) + } catch (e) { + await github.rest.git.createRef({ + owner: context.repo.owner, + repo: context.repo.repo, + ref: "refs/tags/v${{ steps.py_version.outputs.version }}", + sha: context.sha + }) + } + + - name: Create Release + run: | + curl \ + -X POST \ + -H "Accept: application/vnd.github+json" \ + -H "Authorization: token ${{ secrets.GH_RELEASE_TOKEN }}" \ + https://api.github.com/repos/${GITHUB_REPOSITORY}/releases \ + -d '{"tag_name":"v${{ steps.py_version.outputs.version }}","target_commitish":"${{ github.sha }}","name":"v${{ steps.py_version.outputs.version }}","body":"${{ github.event.head_commit.message }}","draft":false,"prerelease":false,"generate_release_notes":false}' From 23531820daf415732e07f39ce546f173e1da311a Mon Sep 17 00:00:00 2001 From: Rongxin Liu <rongxinliu.dev@gmail.com> Date: Thu, 1 Sep 2022 13:19:58 -0400 Subject: [PATCH 122/159] use v3 actions for checkout and setup-python --- .github/workflows/main.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index b0f91a9..dd919ef 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -20,8 +20,8 @@ jobs: ports: - 5432:5432 steps: - - uses: actions/checkout@v2 - - uses: actions/setup-python@v2 + - uses: actions/checkout@v3 + - uses: actions/setup-python@v3 with: python-version: '3.6' - name: Setup databases From 07ad96b72376ece9a6abeb7172ea0650b89b9580 Mon Sep 17 00:00:00 2001 From: Rongxin Liu <rongxinliu.dev@gmail.com> Date: Sat, 22 Oct 2022 09:15:54 -0400 Subject: [PATCH 123/159] remove travis badge --- README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/README.md b/README.md index cf2c62d..c94bd1f 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,5 @@ # CS50 Library for Python -[](https://travis-ci.org/cs50/python-cs50) - ## Installation ``` From fcc68a16e7eaed3ec6899a487d18d01410ce8a0a Mon Sep 17 00:00:00 2001 From: Rongxin Liu <rongxinliu.dev@gmail.com> Date: Sat, 19 Nov 2022 13:52:16 -0500 Subject: [PATCH 124/159] simplified release process --- .github/workflows/main.yml | 33 +++++++-------------------------- 1 file changed, 7 insertions(+), 26 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index dd919ef..2059c50 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -53,33 +53,14 @@ jobs: run: | echo ::set-output name=version::$(python3 setup.py --version) - - name: Create Tag + - name: Create Release uses: actions/github-script@v6 with: github-token: ${{ github.token }} script: | - try { - await github.rest.git.updateRef({ - owner: context.repo.owner, - repo: context.repo.repo, - ref: "tags/v${{ steps.py_version.outputs.version }}", - sha: context.sha, - force: true - }) - } catch (e) { - await github.rest.git.createRef({ - owner: context.repo.owner, - repo: context.repo.repo, - ref: "refs/tags/v${{ steps.py_version.outputs.version }}", - sha: context.sha - }) - } - - - name: Create Release - run: | - curl \ - -X POST \ - -H "Accept: application/vnd.github+json" \ - -H "Authorization: token ${{ secrets.GH_RELEASE_TOKEN }}" \ - https://api.github.com/repos/${GITHUB_REPOSITORY}/releases \ - -d '{"tag_name":"v${{ steps.py_version.outputs.version }}","target_commitish":"${{ github.sha }}","name":"v${{ steps.py_version.outputs.version }}","body":"${{ github.event.head_commit.message }}","draft":false,"prerelease":false,"generate_release_notes":false}' + github.rest.repos.createRelease({ + owner: context.repo.owner, + repo: context.repo.repo, + tag_name: "v${{ steps.py_version.outputs.version }}", + tag_commitish: "${{ github.sha }}" + }) From b048ba21a2ad22058bac00f42be43b61f7fc744b Mon Sep 17 00:00:00 2001 From: Matthias Wenz <matthiaswenz@github.com> Date: Fri, 2 Dec 2022 00:25:51 +0100 Subject: [PATCH 125/159] Load sqlite module if sqlite connection --- src/cs50/sql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index f008b6f..c5b2d94 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -51,12 +51,12 @@ def __init__(self, url, **kwargs): import re import sqlalchemy import sqlalchemy.orm - import sqlite3 import threading # Require that file already exist for SQLite matches = re.search(r"^sqlite:///(.+)$", url) if matches: + import sqlite3 if not os.path.exists(matches.group(1)): raise RuntimeError("does not exist: {}".format(matches.group(1))) if not os.path.isfile(matches.group(1)): From bd25298787cb55772f0a615eab5da8bbfc035683 Mon Sep 17 00:00:00 2001 From: Matthias Wenz <matthiaswenz@github.com> Date: Fri, 2 Dec 2022 16:57:01 +0100 Subject: [PATCH 126/159] Verify module loaded when accessing contents --- src/cs50/sql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index c5b2d94..6611e49 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -74,7 +74,7 @@ def __init__(self, url, **kwargs): def connect(dbapi_connection, connection_record): # Enable foreign key constraints - if type(dbapi_connection) is sqlite3.Connection: # If back end is sqlite + if 'sqlite3' in sys.modules and type(dbapi_connection) is sqlite3.Connection: # If back end is sqlite cursor = dbapi_connection.cursor() cursor.execute("PRAGMA foreign_keys=ON") cursor.close() From 01236bcbc85191a7ed623f57d4c287e5ffb831e7 Mon Sep 17 00:00:00 2001 From: Rongxin Liu <rongxinliu.dev@gmail.com> Date: Fri, 2 Dec 2022 14:41:20 -0500 Subject: [PATCH 127/159] import sys module, tweaked styles --- src/cs50/sql.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 6611e49..1d33edb 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -1,3 +1,4 @@ +import sys import threading # Thread-local data @@ -74,7 +75,7 @@ def __init__(self, url, **kwargs): def connect(dbapi_connection, connection_record): # Enable foreign key constraints - if 'sqlite3' in sys.modules and type(dbapi_connection) is sqlite3.Connection: # If back end is sqlite + if "sqlite3" in sys.modules and type(dbapi_connection) is sqlite3.Connection: # If back end is sqlite cursor = dbapi_connection.cursor() cursor.execute("PRAGMA foreign_keys=ON") cursor.close() From a3784cc827f2fe616538131b8ce5f00357d95625 Mon Sep 17 00:00:00 2001 From: Rongxin Liu <rongxinliu.dev@gmail.com> Date: Fri, 2 Dec 2022 14:46:06 -0500 Subject: [PATCH 128/159] bump version number --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 0c3f650..2d90788 100644 --- a/setup.py +++ b/setup.py @@ -18,5 +18,5 @@ package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="9.2.2" + version="9.2.3" ) From 328a76c0b5af368ecba32005c6bfba2afd23daa4 Mon Sep 17 00:00:00 2001 From: Rongxin Liu <rongxinliu.dev@gmail.com> Date: Fri, 2 Dec 2022 14:51:48 -0500 Subject: [PATCH 129/159] update setup-python action --- .github/workflows/main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 2059c50..26e072d 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -21,7 +21,7 @@ jobs: - 5432:5432 steps: - uses: actions/checkout@v3 - - uses: actions/setup-python@v3 + - uses: actions/setup-python@v4 with: python-version: '3.6' - name: Setup databases From 0c552d363b2b6122484a056615800244f0026824 Mon Sep 17 00:00:00 2001 From: Rongxin Liu <rongxinliu.dev@gmail.com> Date: Fri, 2 Dec 2022 15:22:59 -0500 Subject: [PATCH 130/159] use try-catch import sqlite3, update action --- .github/workflows/main.yml | 4 +++- src/cs50/sql.py | 9 +++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 26e072d..f438d7b 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -23,7 +23,8 @@ jobs: - uses: actions/checkout@v3 - uses: actions/setup-python@v4 with: - python-version: '3.6' + python-version: '3.7' + check-latest: true - name: Setup databases run: | pip install . @@ -54,6 +55,7 @@ jobs: echo ::set-output name=version::$(python3 setup.py --version) - name: Create Release + if: ${{ github.ref == 'refs/heads/main' }} uses: actions/github-script@v6 with: github-token: ${{ github.token }} diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 1d33edb..8087657 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -53,11 +53,16 @@ def __init__(self, url, **kwargs): import sqlalchemy import sqlalchemy.orm import threading + + # Temporary fix for missing sqlite3 module on the buildpack stack + try: + import sqlite3 + except: + pass # Require that file already exist for SQLite matches = re.search(r"^sqlite:///(.+)$", url) if matches: - import sqlite3 if not os.path.exists(matches.group(1)): raise RuntimeError("does not exist: {}".format(matches.group(1))) if not os.path.isfile(matches.group(1)): @@ -75,7 +80,7 @@ def __init__(self, url, **kwargs): def connect(dbapi_connection, connection_record): # Enable foreign key constraints - if "sqlite3" in sys.modules and type(dbapi_connection) is sqlite3.Connection: # If back end is sqlite + if type(dbapi_connection) is sqlite3.Connection: # If back end is sqlite cursor = dbapi_connection.cursor() cursor.execute("PRAGMA foreign_keys=ON") cursor.close() From 3ccbd99c58d18df7e9385d421e85e8b7631bbe9d Mon Sep 17 00:00:00 2001 From: Rongxin Liu <rongxinliu.dev@gmail.com> Date: Fri, 2 Dec 2022 15:32:18 -0500 Subject: [PATCH 131/159] added try-catch workaround for SQL.connect when sqlite3 is not available --- setup.py | 2 +- src/cs50/sql.py | 12 ++++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/setup.py b/setup.py index 2d90788..62a7abe 100644 --- a/setup.py +++ b/setup.py @@ -18,5 +18,5 @@ package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="9.2.3" + version="9.2.4" ) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 8087657..f2c090d 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -80,10 +80,14 @@ def __init__(self, url, **kwargs): def connect(dbapi_connection, connection_record): # Enable foreign key constraints - if type(dbapi_connection) is sqlite3.Connection: # If back end is sqlite - cursor = dbapi_connection.cursor() - cursor.execute("PRAGMA foreign_keys=ON") - cursor.close() + try: + if type(dbapi_connection) is sqlite3.Connection: # If back end is sqlite + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + except: + # Temporary fix for missing sqlite3 module on the buildpack stack + pass # Register listener sqlalchemy.event.listen(self._engine, "connect", connect) From 777b4da5a1d2117b3959f72fbe74e4b50c2885de Mon Sep 17 00:00:00 2001 From: up-n-atom <adam.jaremko@gmail.com> Date: Sat, 28 Jan 2023 23:26:06 -0500 Subject: [PATCH 132/159] Respect pep8 and revert 659c8f4 As described in pep8: "Object type comparisons should always use isinstance() instead of comparing types directly:" Ref. https://peps.python.org/pep-0008/ --- src/cs50/cs50.py | 2 +- src/cs50/sql.py | 24 ++++++++++++------------ 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/cs50/cs50.py b/src/cs50/cs50.py index 1d7b6ea..16bfd0b 100644 --- a/src/cs50/cs50.py +++ b/src/cs50/cs50.py @@ -135,7 +135,7 @@ def get_string(prompt): as line endings. If user inputs only a line ending, returns "", not None. Returns None upon error or no input whatsoever (i.e., just EOF). """ - if type(prompt) is not str: + if not isinstance(prompt, str): raise TypeError("prompt must be of type str") try: return input(prompt) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index f2c090d..4f9457b 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -81,7 +81,7 @@ def connect(dbapi_connection, connection_record): # Enable foreign key constraints try: - if type(dbapi_connection) is sqlite3.Connection: # If back end is sqlite + if isinstance(dbapi_connection, sqlite3.Connection): # If back end is sqlite cursor = dbapi_connection.cursor() cursor.execute("PRAGMA foreign_keys=ON") cursor.close() @@ -350,11 +350,11 @@ def teardown_appcontext(exception): # Coerce decimal.Decimal objects to float objects # https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ - if type(row[column]) is decimal.Decimal: + if isinstance(row[column], decimal.Decimal): row[column] = float(row[column]) # Coerce memoryview objects (as from PostgreSQL's bytea columns) to bytes - elif type(row[column]) is memoryview: + elif isinstance(row[column], memoryview): row[column] = bytes(row[column]) # Rows to be returned @@ -432,13 +432,13 @@ def __escape(value): import sqlalchemy # bool - if type(value) is bool: + if isinstance(value, bool): return sqlparse.sql.Token( sqlparse.tokens.Number, sqlalchemy.types.Boolean().literal_processor(self._engine.dialect)(value)) # bytes - elif type(value) is bytes: + elif isinstance(value, bytes): if self._engine.url.get_backend_name() in ["mysql", "sqlite"]: return sqlparse.sql.Token(sqlparse.tokens.Other, f"x'{value.hex()}'") # https://dev.mysql.com/doc/refman/8.0/en/hexadecimal-literals.html elif self._engine.url.get_backend_name() == "postgresql": @@ -447,37 +447,37 @@ def __escape(value): raise RuntimeError("unsupported value: {}".format(value)) # datetime.date - elif type(value) is datetime.date: + elif isinstance(value, datetime.date): return sqlparse.sql.Token( sqlparse.tokens.String, sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%Y-%m-%d"))) # datetime.datetime - elif type(value) is datetime.datetime: + elif isinstance(value, datetime.datetime): return sqlparse.sql.Token( sqlparse.tokens.String, sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%Y-%m-%d %H:%M:%S"))) # datetime.time - elif type(value) is datetime.time: + elif isinstance(value, datetime.time): return sqlparse.sql.Token( sqlparse.tokens.String, sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%H:%M:%S"))) # float - elif type(value) is float: + elif isinstance(value, float): return sqlparse.sql.Token( sqlparse.tokens.Number, sqlalchemy.types.Float().literal_processor(self._engine.dialect)(value)) # int - elif type(value) is int: + elif isinstance(value, int): return sqlparse.sql.Token( sqlparse.tokens.Number, sqlalchemy.types.Integer().literal_processor(self._engine.dialect)(value)) # str - elif type(value) is str: + elif isinstance(value, str): return sqlparse.sql.Token( sqlparse.tokens.String, sqlalchemy.types.String().literal_processor(self._engine.dialect)(value)) @@ -493,7 +493,7 @@ def __escape(value): raise RuntimeError("unsupported value: {}".format(value)) # Escape value(s), separating with commas as needed - if type(value) in [list, tuple]: + if isinstance(value, (list, tuple)): return sqlparse.sql.TokenList(sqlparse.parse(", ".join([str(__escape(v)) for v in value]))) else: return __escape(value) From 6096c7e47aa75f18d9aa1f728538e33e37b763ba Mon Sep 17 00:00:00 2001 From: up-n-atom <adam.jaremko@gmail.com> Date: Sun, 29 Jan 2023 03:00:43 -0500 Subject: [PATCH 133/159] Fix order of datetime type checks datetime.datetime inherits datetime.date and will prematurely evaluate as an instance of datetime.date. --- src/cs50/sql.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 4f9457b..24690e3 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -446,18 +446,18 @@ def __escape(value): else: raise RuntimeError("unsupported value: {}".format(value)) - # datetime.date - elif isinstance(value, datetime.date): - return sqlparse.sql.Token( - sqlparse.tokens.String, - sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%Y-%m-%d"))) - # datetime.datetime elif isinstance(value, datetime.datetime): return sqlparse.sql.Token( sqlparse.tokens.String, sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%Y-%m-%d %H:%M:%S"))) + # datetime.date + elif isinstance(value, datetime.date): + return sqlparse.sql.Token( + sqlparse.tokens.String, + sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%Y-%m-%d"))) + # datetime.time elif isinstance(value, datetime.time): return sqlparse.sql.Token( From 53cf4d204f21d6c66f1b9ba6fe2055f0e6037feb Mon Sep 17 00:00:00 2001 From: Rongxin Liu <rongxinliu.dev@gmail.com> Date: Sun, 29 Jan 2023 16:58:03 +0800 Subject: [PATCH 134/159] bumped version number --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 62a7abe..73b37a1 100644 --- a/setup.py +++ b/setup.py @@ -18,5 +18,5 @@ package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="9.2.4" + version="9.2.5" ) From 464f237e32bc2f87affd34722a498054dc0d57f6 Mon Sep 17 00:00:00 2001 From: Rongxin Liu <rongxinliu.dev@gmail.com> Date: Sun, 29 Jan 2023 18:39:42 +0800 Subject: [PATCH 135/159] fixated SQLAlchemy to version 1.4.46 --- .github/workflows/main.yml | 2 +- .gitignore | 1 + setup.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index f438d7b..964db68 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -28,7 +28,7 @@ jobs: - name: Setup databases run: | pip install . - pip install mysqlclient psycopg2-binary + pip install mysqlclient psycopg2-binary SQLAlchemy==1.4.46 - name: Run tests run: python tests/sql.py diff --git a/.gitignore b/.gitignore index 4286ed6..dd3ffcc 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ .* !/.github/ !.gitignore +build/ *.db *.egg-info/ *.pyc diff --git a/setup.py b/setup.py index 73b37a1..1a8ef3a 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ "Topic :: Software Development :: Libraries :: Python Modules" ], description="CS50 library for Python", - install_requires=["Flask>=1.0", "SQLAlchemy", "sqlparse", "termcolor", "wheel"], + install_requires=["Flask>=1.0", "SQLAlchemy==1.4.46", "sqlparse", "termcolor", "wheel"], keywords="cs50", license="GPLv3", long_description_content_type="text/markdown", From f807f1ef8b50b72d9307c0130828af41b99f517d Mon Sep 17 00:00:00 2001 From: Rongxin Liu <rongxinliu.dev@gmail.com> Date: Wed, 13 Sep 2023 11:10:54 -0400 Subject: [PATCH 136/159] added support for SQLAlchemy 2.0 --- setup.py | 4 ++-- src/cs50/sql.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/setup.py b/setup.py index 1a8ef3a..2c25e53 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ "Topic :: Software Development :: Libraries :: Python Modules" ], description="CS50 library for Python", - install_requires=["Flask>=1.0", "SQLAlchemy==1.4.46", "sqlparse", "termcolor", "wheel"], + install_requires=["Flask>=1.0", "SQLAlchemy<3", "sqlparse", "termcolor", "wheel"], keywords="cs50", license="GPLv3", long_description_content_type="text/markdown", @@ -18,5 +18,5 @@ package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="9.2.5" + version="9.2.6" ) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 24690e3..8110cba 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -53,7 +53,7 @@ def __init__(self, url, **kwargs): import sqlalchemy import sqlalchemy.orm import threading - + # Temporary fix for missing sqlite3 module on the buildpack stack try: import sqlite3 @@ -100,7 +100,7 @@ def connect(dbapi_connection, connection_record): self._logger.disabled = True try: connection = self._engine.connect() - connection.execute("SELECT 1") + connection.execute(sqlalchemy.text("SELECT 1")) connection.close() except sqlalchemy.exc.OperationalError as e: e = RuntimeError(_parse_exception(e)) @@ -344,7 +344,7 @@ def teardown_appcontext(exception): if command == "SELECT": # Coerce types - rows = [dict(row) for row in result.fetchall()] + rows = [dict(row) for row in result.mappings().all()] for row in rows: for column in row: From ca09441d54ab52805baa38d50fe47c33bf4dcc65 Mon Sep 17 00:00:00 2001 From: Rongxin Liu <rongxinliu.dev@gmail.com> Date: Wed, 13 Sep 2023 11:17:39 -0400 Subject: [PATCH 137/159] convert string to sqlalchemy text --- src/cs50/sql.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 8110cba..527a98d 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -370,7 +370,7 @@ def teardown_appcontext(exception): # "(psycopg2.errors.ObjectNotInPrerequisiteState) lastval is not yet defined in this session", # a la https://stackoverflow.com/a/24186770/5156190; # cf. https://www.psycopg.org/docs/errors.html re 55000 - result = connection.execute(""" + result = connection.execute(sqlalchemy.text(""" CREATE OR REPLACE FUNCTION _LASTVAL() RETURNS integer LANGUAGE plpgsql AS $$ @@ -382,7 +382,7 @@ def teardown_appcontext(exception): END; END $$; SELECT _LASTVAL(); - """) + """)) ret = result.first()[0] # If not PostgreSQL From 2ed4803dbe6f29f0b10ddb1ef5b173d991d43805 Mon Sep 17 00:00:00 2001 From: Rongxin Liu <rongxinliu.dev@gmail.com> Date: Thu, 14 Sep 2023 23:28:20 -0400 Subject: [PATCH 138/159] bump SQLAlchemy version to 1.4.49 in workflow --- .github/workflows/main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 964db68..6547cf2 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -28,7 +28,7 @@ jobs: - name: Setup databases run: | pip install . - pip install mysqlclient psycopg2-binary SQLAlchemy==1.4.46 + pip install mysqlclient psycopg2-binary SQLAlchemy==1.4.49 - name: Run tests run: python tests/sql.py From b8581fe410f95174edd555e428f32af8acb7bfd7 Mon Sep 17 00:00:00 2001 From: Aivar Annamaa <aivarannamaa@users.noreply.github.com> Date: Sat, 23 Sep 2023 12:55:48 +0300 Subject: [PATCH 139/159] Fix method delegation in _flushfile Required when the faked stream is already faked and the original fake also uses method delegation. --- src/cs50/cs50.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cs50/cs50.py b/src/cs50/cs50.py index 16bfd0b..425173c 100644 --- a/src/cs50/cs50.py +++ b/src/cs50/cs50.py @@ -49,7 +49,7 @@ def __init__(self, f): self.f = f def __getattr__(self, name): - return object.__getattribute__(self.f, name) + return getattr(self.f, name) def write(self, x): self.f.write(x) From 4424e657431475114fc46cb86889ac08d2ce62a9 Mon Sep 17 00:00:00 2001 From: Rongxin Liu <rongxinliu.dev@gmail.com> Date: Fri, 29 Sep 2023 16:07:48 -0400 Subject: [PATCH 140/159] added support for 'CREATE VIEW' statement --- setup.py | 2 +- src/cs50/sql.py | 24 ++++++++++++++---------- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/setup.py b/setup.py index 1a8ef3a..bd33a73 100644 --- a/setup.py +++ b/setup.py @@ -18,5 +18,5 @@ package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="9.2.5" + version="9.2.6" ) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 24690e3..8d07327 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -53,7 +53,7 @@ def __init__(self, url, **kwargs): import sqlalchemy import sqlalchemy.orm import threading - + # Temporary fix for missing sqlite3 module on the buildpack stack try: import sqlite3 @@ -149,15 +149,15 @@ def execute(self, sql, *args, **kwargs): if len(args) > 0 and len(kwargs) > 0: raise RuntimeError("cannot pass both positional and named parameters") - # Infer command from (unflattened) statement - for token in statements[0]: - if token.ttype in [sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML]: - token_value = token.value.upper() - if token_value in ["BEGIN", "DELETE", "INSERT", "SELECT", "START", "UPDATE"]: - command = token_value - break - else: - command = None + # Infer command from flattened statement to a single string separated by spaces + full_statement = ' '.join(str(token) for token in statements[0].tokens if token.ttype in [sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML]) + full_statement = full_statement.upper() + + # set of possible commands + commands = {"BEGIN", "CREATE VIEW", "DELETE", "INSERT", "SELECT", "START", "UPDATE"} + + # check if the full_statement starts with any command + command = next((cmd for cmd in commands if full_statement.startswith(cmd)), None) # Flatten statement tokens = list(statements[0].flatten()) @@ -393,6 +393,10 @@ def teardown_appcontext(exception): elif command in ["DELETE", "UPDATE"]: ret = result.rowcount + # If CREATE VIEW, return True + elif command == "CREATE VIEW": + ret = True + # If constraint violated except sqlalchemy.exc.IntegrityError as e: self._logger.error(termcolor.colored(_statement, "red")) From 0608340a51d5fc60742b4a94ab3f9b45e7d45f2a Mon Sep 17 00:00:00 2001 From: Rongxin Liu <rongxinliu.dev@gmail.com> Date: Wed, 4 Oct 2023 11:29:20 -0700 Subject: [PATCH 141/159] unpin SQLAlchemy version in workflow --- .github/workflows/main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 6547cf2..b8165f7 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -28,7 +28,7 @@ jobs: - name: Setup databases run: | pip install . - pip install mysqlclient psycopg2-binary SQLAlchemy==1.4.49 + pip install mysqlclient psycopg2-binary SQLAlchemy - name: Run tests run: python tests/sql.py From b3f0a0c5d0d5324595791f02bc3faae49cf1f15d Mon Sep 17 00:00:00 2001 From: Rongxin Liu <rongxinliu.dev@gmail.com> Date: Tue, 17 Oct 2023 16:00:07 -0400 Subject: [PATCH 142/159] replaced distutils with packaging --- setup.py | 4 ++-- src/cs50/cs50.py | 1 - src/cs50/flask.py | 7 +++++-- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/setup.py b/setup.py index bd33a73..d9a2ee4 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ "Topic :: Software Development :: Libraries :: Python Modules" ], description="CS50 library for Python", - install_requires=["Flask>=1.0", "SQLAlchemy==1.4.46", "sqlparse", "termcolor", "wheel"], + install_requires=["Flask>=1.0", "packaging", "SQLAlchemy==1.4.46", "sqlparse", "termcolor", "wheel"], keywords="cs50", license="GPLv3", long_description_content_type="text/markdown", @@ -18,5 +18,5 @@ package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="9.2.6" + version="9.2.7" ) diff --git a/src/cs50/cs50.py b/src/cs50/cs50.py index 16bfd0b..31313f8 100644 --- a/src/cs50/cs50.py +++ b/src/cs50/cs50.py @@ -6,7 +6,6 @@ import re import sys -from distutils.sysconfig import get_python_lib from os.path import abspath, join from termcolor import colored from traceback import format_exception diff --git a/src/cs50/flask.py b/src/cs50/flask.py index 324ec30..3668007 100644 --- a/src/cs50/flask.py +++ b/src/cs50/flask.py @@ -6,10 +6,13 @@ def _wrap_flask(f): if f is None: return - from distutils.version import StrictVersion + from packaging.version import Version, InvalidVersion from .cs50 import _formatException - if f.__version__ < StrictVersion("1.0"): + try: + if Version(f.__version__) < Version("1.0"): + return + except InvalidVersion: return if os.getenv("CS50_IDE_TYPE") == "online": From bb7263454280a73077ee1657956d030f6fd60f31 Mon Sep 17 00:00:00 2001 From: Rongxin Liu <rongxinliu.dev@gmail.com> Date: Tue, 24 Oct 2023 09:36:58 -0400 Subject: [PATCH 143/159] bumped version number, fixed comment styles --- setup.py | 2 +- src/cs50/sql.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index d9a2ee4..8c1abfb 100644 --- a/setup.py +++ b/setup.py @@ -18,5 +18,5 @@ package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="9.2.7" + version="9.3.0" ) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index fee0d68..de3ad56 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -153,10 +153,10 @@ def execute(self, sql, *args, **kwargs): full_statement = ' '.join(str(token) for token in statements[0].tokens if token.ttype in [sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML]) full_statement = full_statement.upper() - # set of possible commands + # Set of possible commands commands = {"BEGIN", "CREATE VIEW", "DELETE", "INSERT", "SELECT", "START", "UPDATE"} - # check if the full_statement starts with any command + # Check if the full_statement starts with any command command = next((cmd for cmd in commands if full_statement.startswith(cmd)), None) # Flatten statement From 3ddef3132a2485d48ea487bb9401317921ebfc82 Mon Sep 17 00:00:00 2001 From: Rongxin Liu <rongxinliu.dev@gmail.com> Date: Tue, 24 Oct 2023 09:39:36 -0400 Subject: [PATCH 144/159] updated SQLAlchemy version constraint --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 8c1abfb..23f6b01 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ "Topic :: Software Development :: Libraries :: Python Modules" ], description="CS50 library for Python", - install_requires=["Flask>=1.0", "packaging", "SQLAlchemy==1.4.46", "sqlparse", "termcolor", "wheel"], + install_requires=["Flask>=1.0", "packaging", "SQLAlchemy<3", "sqlparse", "termcolor", "wheel"], keywords="cs50", license="GPLv3", long_description_content_type="text/markdown", From 1b644dd0c4a4fd252306933ed2a1247ffda8de83 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Sun, 17 Dec 2023 14:20:44 -0500 Subject: [PATCH 145/159] updated IO wrapper, style, version --- setup.py | 2 +- src/cs50/__init__.py | 1 + src/cs50/cs50.py | 49 +++++++---- src/cs50/flask.py | 10 ++- src/cs50/sql.py | 199 ++++++++++++++++++++++++++++++------------- 5 files changed, 184 insertions(+), 77 deletions(-) diff --git a/setup.py b/setup.py index 23f6b01..10ceb30 100644 --- a/setup.py +++ b/setup.py @@ -18,5 +18,5 @@ package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="9.3.0" + version="9.3.1" ) diff --git a/src/cs50/__init__.py b/src/cs50/__init__.py index aaec161..7dd4e17 100644 --- a/src/cs50/__init__.py +++ b/src/cs50/__init__.py @@ -8,6 +8,7 @@ # Import cs50_* from .cs50 import get_char, get_float, get_int, get_string + try: from .cs50 import get_long except ImportError: diff --git a/src/cs50/cs50.py b/src/cs50/cs50.py index 07f13e9..f331a88 100644 --- a/src/cs50/cs50.py +++ b/src/cs50/cs50.py @@ -17,7 +17,9 @@ try: # Patch formatException - logging.root.handlers[0].formatter.formatException = lambda exc_info: _formatException(*exc_info) + logging.root.handlers[ + 0 + ].formatter.formatException = lambda exc_info: _formatException(*exc_info) except IndexError: pass @@ -37,26 +39,31 @@ _logger.addHandler(handler) -class _flushfile(): +class _Unbuffered: """ Disable buffering for standard output and standard error. - http://stackoverflow.com/a/231216 + https://stackoverflow.com/a/107717 + https://docs.python.org/3/library/io.html """ - def __init__(self, f): - self.f = f + def __init__(self, stream): + self.stream = stream - def __getattr__(self, name): - return getattr(self.f, name) + def __getattr__(self, attr): + return getattr(self.stream, attr) - def write(self, x): - self.f.write(x) - self.f.flush() + def write(self, b): + self.stream.write(b) + self.stream.flush() + def writelines(self, lines): + self.stream.writelines(lines) + self.stream.flush() -sys.stderr = _flushfile(sys.stderr) -sys.stdout = _flushfile(sys.stdout) + +sys.stderr = _Unbuffered(sys.stderr) +sys.stdout = _Unbuffered(sys.stdout) def _formatException(type, value, tb): @@ -78,19 +85,29 @@ def _formatException(type, value, tb): lines += line else: matches = re.search(r"^(\s*)(.*?)(\s*)$", line, re.DOTALL) - lines.append(matches.group(1) + colored(matches.group(2), "yellow") + matches.group(3)) + lines.append( + matches.group(1) + + colored(matches.group(2), "yellow") + + matches.group(3) + ) return "".join(lines).rstrip() -sys.excepthook = lambda type, value, tb: print(_formatException(type, value, tb), file=sys.stderr) +sys.excepthook = lambda type, value, tb: print( + _formatException(type, value, tb), file=sys.stderr +) def eprint(*args, **kwargs): - raise RuntimeError("The CS50 Library for Python no longer supports eprint, but you can use print instead!") + raise RuntimeError( + "The CS50 Library for Python no longer supports eprint, but you can use print instead!" + ) def get_char(prompt): - raise RuntimeError("The CS50 Library for Python no longer supports get_char, but you can use get_string instead!") + raise RuntimeError( + "The CS50 Library for Python no longer supports get_char, but you can use get_string instead!" + ) def get_float(prompt): diff --git a/src/cs50/flask.py b/src/cs50/flask.py index 3668007..6e38971 100644 --- a/src/cs50/flask.py +++ b/src/cs50/flask.py @@ -2,6 +2,7 @@ import pkgutil import sys + def _wrap_flask(f): if f is None: return @@ -17,10 +18,15 @@ def _wrap_flask(f): if os.getenv("CS50_IDE_TYPE") == "online": from werkzeug.middleware.proxy_fix import ProxyFix + _flask_init_before = f.Flask.__init__ + def _flask_init_after(self, *args, **kwargs): _flask_init_before(self, *args, **kwargs) - self.wsgi_app = ProxyFix(self.wsgi_app, x_proto=1) # For HTTPS-to-HTTP proxy + self.wsgi_app = ProxyFix( + self.wsgi_app, x_proto=1 + ) # For HTTPS-to-HTTP proxy + f.Flask.__init__ = _flask_init_after @@ -30,7 +36,7 @@ def _flask_init_after(self, *args, **kwargs): # If Flask wasn't imported else: - flask_loader = pkgutil.get_loader('flask') + flask_loader = pkgutil.get_loader("flask") if flask_loader: _exec_module_before = flask_loader.exec_module diff --git a/src/cs50/sql.py b/src/cs50/sql.py index de3ad56..a0b93eb 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -14,7 +14,6 @@ def _enable_logging(f): @functools.wraps(f) def decorator(*args, **kwargs): - # Infer whether Flask is installed try: import flask @@ -71,17 +70,20 @@ def __init__(self, url, **kwargs): # Create engine, disabling SQLAlchemy's own autocommit mode raising exception if back end's module not installed; # without isolation_level, PostgreSQL warns with "there is already a transaction in progress" for our own BEGIN and # "there is no transaction in progress" for our own COMMIT - self._engine = sqlalchemy.create_engine(url, **kwargs).execution_options(autocommit=False, isolation_level="AUTOCOMMIT") + self._engine = sqlalchemy.create_engine(url, **kwargs).execution_options( + autocommit=False, isolation_level="AUTOCOMMIT" + ) # Get logger self._logger = logging.getLogger("cs50") # Listener for connections def connect(dbapi_connection, connection_record): - # Enable foreign key constraints try: - if isinstance(dbapi_connection, sqlite3.Connection): # If back end is sqlite + if isinstance( + dbapi_connection, sqlite3.Connection + ): # If back end is sqlite cursor = dbapi_connection.cursor() cursor.execute("PRAGMA foreign_keys=ON") cursor.close() @@ -150,14 +152,33 @@ def execute(self, sql, *args, **kwargs): raise RuntimeError("cannot pass both positional and named parameters") # Infer command from flattened statement to a single string separated by spaces - full_statement = ' '.join(str(token) for token in statements[0].tokens if token.ttype in [sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML]) + full_statement = " ".join( + str(token) + for token in statements[0].tokens + if token.ttype + in [ + sqlparse.tokens.Keyword, + sqlparse.tokens.Keyword.DDL, + sqlparse.tokens.Keyword.DML, + ] + ) full_statement = full_statement.upper() # Set of possible commands - commands = {"BEGIN", "CREATE VIEW", "DELETE", "INSERT", "SELECT", "START", "UPDATE"} + commands = { + "BEGIN", + "CREATE VIEW", + "DELETE", + "INSERT", + "SELECT", + "START", + "UPDATE", + } # Check if the full_statement starts with any command - command = next((cmd for cmd in commands if full_statement.startswith(cmd)), None) + command = next( + (cmd for cmd in commands if full_statement.startswith(cmd)), None + ) # Flatten statement tokens = list(statements[0].flatten()) @@ -166,10 +187,8 @@ def execute(self, sql, *args, **kwargs): placeholders = {} paramstyle = None for index, token in enumerate(tokens): - # If token is a placeholder if token.ttype == sqlparse.tokens.Name.Placeholder: - # Determine paramstyle, name _paramstyle, name = _parse_placeholder(token) @@ -186,7 +205,6 @@ def execute(self, sql, *args, **kwargs): # If no placeholders if not paramstyle: - # Error-check like qmark if args if args: paramstyle = "qmark" @@ -201,13 +219,20 @@ def execute(self, sql, *args, **kwargs): # qmark if paramstyle == "qmark": - # Validate number of placeholders if len(placeholders) != len(args): if len(placeholders) < len(args): - raise RuntimeError("fewer placeholders ({}) than values ({})".format(_placeholders, _args)) + raise RuntimeError( + "fewer placeholders ({}) than values ({})".format( + _placeholders, _args + ) + ) else: - raise RuntimeError("more placeholders ({}) than values ({})".format(_placeholders, _args)) + raise RuntimeError( + "more placeholders ({}) than values ({})".format( + _placeholders, _args + ) + ) # Escape values for i, index in enumerate(placeholders.keys()): @@ -215,27 +240,34 @@ def execute(self, sql, *args, **kwargs): # numeric elif paramstyle == "numeric": - # Escape values for index, i in placeholders.items(): if i >= len(args): - raise RuntimeError("missing value for placeholder (:{})".format(i + 1, len(args))) + raise RuntimeError( + "missing value for placeholder (:{})".format(i + 1, len(args)) + ) tokens[index] = self._escape(args[i]) # Check if any values unused indices = set(range(len(args))) - set(placeholders.values()) if indices: - raise RuntimeError("unused {} ({})".format( - "value" if len(indices) == 1 else "values", - ", ".join([str(self._escape(args[index])) for index in indices]))) + raise RuntimeError( + "unused {} ({})".format( + "value" if len(indices) == 1 else "values", + ", ".join( + [str(self._escape(args[index])) for index in indices] + ), + ) + ) # named elif paramstyle == "named": - # Escape values for index, name in placeholders.items(): if name not in kwargs: - raise RuntimeError("missing value for placeholder (:{})".format(name)) + raise RuntimeError( + "missing value for placeholder (:{})".format(name) + ) tokens[index] = self._escape(kwargs[name]) # Check if any keys unused @@ -245,13 +277,20 @@ def execute(self, sql, *args, **kwargs): # format elif paramstyle == "format": - # Validate number of placeholders if len(placeholders) != len(args): if len(placeholders) < len(args): - raise RuntimeError("fewer placeholders ({}) than values ({})".format(_placeholders, _args)) + raise RuntimeError( + "fewer placeholders ({}) than values ({})".format( + _placeholders, _args + ) + ) else: - raise RuntimeError("more placeholders ({}) than values ({})".format(_placeholders, _args)) + raise RuntimeError( + "more placeholders ({}) than values ({})".format( + _placeholders, _args + ) + ) # Escape values for i, index in enumerate(placeholders.keys()): @@ -259,40 +298,44 @@ def execute(self, sql, *args, **kwargs): # pyformat elif paramstyle == "pyformat": - # Escape values for index, name in placeholders.items(): if name not in kwargs: - raise RuntimeError("missing value for placeholder (%{}s)".format(name)) + raise RuntimeError( + "missing value for placeholder (%{}s)".format(name) + ) tokens[index] = self._escape(kwargs[name]) # Check if any keys unused keys = kwargs.keys() - placeholders.values() if keys: - raise RuntimeError("unused {} ({})".format( - "value" if len(keys) == 1 else "values", - ", ".join(keys))) + raise RuntimeError( + "unused {} ({})".format( + "value" if len(keys) == 1 else "values", ", ".join(keys) + ) + ) # For SQL statements where a colon is required verbatim, as within an inline string, use a backslash to escape # https://docs.sqlalchemy.org/en/13/core/sqlelement.html?highlight=text#sqlalchemy.sql.expression.text for index, token in enumerate(tokens): - # In string literal # https://www.sqlite.org/lang_keywords.html - if token.ttype in [sqlparse.tokens.Literal.String, sqlparse.tokens.Literal.String.Single]: + if token.ttype in [ + sqlparse.tokens.Literal.String, + sqlparse.tokens.Literal.String.Single, + ]: token.value = re.sub("(^'|\s+):", r"\1\:", token.value) # In identifier # https://www.sqlite.org/lang_keywords.html elif token.ttype == sqlparse.tokens.Literal.String.Symbol: - token.value = re.sub("(^\"|\s+):", r"\1\:", token.value) + token.value = re.sub('(^"|\s+):', r"\1\:", token.value) # Join tokens into statement statement = "".join([str(token) for token in tokens]) # If no connection yet if not hasattr(_data, self._name()): - # Connect to database setattr(_data, self._name(), self._engine.connect()) @@ -302,9 +345,12 @@ def execute(self, sql, *args, **kwargs): # Disconnect if/when a Flask app is torn down try: import flask + assert flask.current_app + def teardown_appcontext(exception): self._disconnect() + if teardown_appcontext not in flask.current_app.teardown_appcontext_funcs: flask.current_app.teardown_appcontext(teardown_appcontext) except (ModuleNotFoundError, AssertionError): @@ -312,15 +358,20 @@ def teardown_appcontext(exception): # Catch SQLAlchemy warnings with warnings.catch_warnings(): - # Raise exceptions for warnings warnings.simplefilter("error") # Prepare, execute statement try: - # Join tokens into statement, abbreviating binary data as <class 'bytes'> - _statement = "".join([str(bytes) if token.ttype == sqlparse.tokens.Other else str(token) for token in tokens]) + _statement = "".join( + [ + str(bytes) + if token.ttype == sqlparse.tokens.Other + else str(token) + for token in tokens + ] + ) # Check for start of transaction if command in ["BEGIN", "START"]: @@ -342,12 +393,10 @@ def teardown_appcontext(exception): # If SELECT, return result set as list of dict objects if command == "SELECT": - # Coerce types rows = [dict(row) for row in result.mappings().all()] for row in rows: for column in row: - # Coerce decimal.Decimal objects to float objects # https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ if isinstance(row[column], decimal.Decimal): @@ -362,15 +411,15 @@ def teardown_appcontext(exception): # If INSERT, return primary key value for a newly inserted row (or None if none) elif command == "INSERT": - # If PostgreSQL if self._engine.url.get_backend_name() == "postgresql": - # Return LASTVAL() or NULL, avoiding # "(psycopg2.errors.ObjectNotInPrerequisiteState) lastval is not yet defined in this session", # a la https://stackoverflow.com/a/24186770/5156190; # cf. https://www.psycopg.org/docs/errors.html re 55000 - result = connection.execute(sqlalchemy.text(""" + result = connection.execute( + sqlalchemy.text( + """ CREATE OR REPLACE FUNCTION _LASTVAL() RETURNS integer LANGUAGE plpgsql AS $$ @@ -382,7 +431,9 @@ def teardown_appcontext(exception): END; END $$; SELECT _LASTVAL(); - """)) + """ + ) + ) ret = result.first()[0] # If not PostgreSQL @@ -405,7 +456,10 @@ def teardown_appcontext(exception): raise e # If user error - except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as e: + except ( + sqlalchemy.exc.OperationalError, + sqlalchemy.exc.ProgrammingError, + ) as e: self._disconnect() self._logger.error(termcolor.colored(_statement, "red")) e = RuntimeError(e.orig) @@ -430,7 +484,6 @@ def _escape(self, value): import sqlparse def __escape(value): - # Lazily import import datetime import sqlalchemy @@ -439,14 +492,21 @@ def __escape(value): if isinstance(value, bool): return sqlparse.sql.Token( sqlparse.tokens.Number, - sqlalchemy.types.Boolean().literal_processor(self._engine.dialect)(value)) + sqlalchemy.types.Boolean().literal_processor(self._engine.dialect)( + value + ), + ) # bytes elif isinstance(value, bytes): if self._engine.url.get_backend_name() in ["mysql", "sqlite"]: - return sqlparse.sql.Token(sqlparse.tokens.Other, f"x'{value.hex()}'") # https://dev.mysql.com/doc/refman/8.0/en/hexadecimal-literals.html + return sqlparse.sql.Token( + sqlparse.tokens.Other, f"x'{value.hex()}'" + ) # https://dev.mysql.com/doc/refman/8.0/en/hexadecimal-literals.html elif self._engine.url.get_backend_name() == "postgresql": - return sqlparse.sql.Token(sqlparse.tokens.Other, f"'\\x{value.hex()}'") # https://dba.stackexchange.com/a/203359 + return sqlparse.sql.Token( + sqlparse.tokens.Other, f"'\\x{value.hex()}'" + ) # https://dba.stackexchange.com/a/203359 else: raise RuntimeError("unsupported value: {}".format(value)) @@ -454,43 +514,59 @@ def __escape(value): elif isinstance(value, datetime.datetime): return sqlparse.sql.Token( sqlparse.tokens.String, - sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%Y-%m-%d %H:%M:%S"))) + sqlalchemy.types.String().literal_processor(self._engine.dialect)( + value.strftime("%Y-%m-%d %H:%M:%S") + ), + ) # datetime.date elif isinstance(value, datetime.date): return sqlparse.sql.Token( sqlparse.tokens.String, - sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%Y-%m-%d"))) + sqlalchemy.types.String().literal_processor(self._engine.dialect)( + value.strftime("%Y-%m-%d") + ), + ) # datetime.time elif isinstance(value, datetime.time): return sqlparse.sql.Token( sqlparse.tokens.String, - sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%H:%M:%S"))) + sqlalchemy.types.String().literal_processor(self._engine.dialect)( + value.strftime("%H:%M:%S") + ), + ) # float elif isinstance(value, float): return sqlparse.sql.Token( sqlparse.tokens.Number, - sqlalchemy.types.Float().literal_processor(self._engine.dialect)(value)) + sqlalchemy.types.Float().literal_processor(self._engine.dialect)( + value + ), + ) # int elif isinstance(value, int): return sqlparse.sql.Token( sqlparse.tokens.Number, - sqlalchemy.types.Integer().literal_processor(self._engine.dialect)(value)) + sqlalchemy.types.Integer().literal_processor(self._engine.dialect)( + value + ), + ) # str elif isinstance(value, str): return sqlparse.sql.Token( sqlparse.tokens.String, - sqlalchemy.types.String().literal_processor(self._engine.dialect)(value)) + sqlalchemy.types.String().literal_processor(self._engine.dialect)( + value + ), + ) # None elif value is None: - return sqlparse.sql.Token( - sqlparse.tokens.Keyword, - sqlalchemy.null()) + return sqlparse.sql.Token(sqlparse.tokens.Keyword, sqlalchemy.null()) # Unsupported value else: @@ -498,7 +574,9 @@ def __escape(value): # Escape value(s), separating with commas as needed if isinstance(value, (list, tuple)): - return sqlparse.sql.TokenList(sqlparse.parse(", ".join([str(__escape(v)) for v in value]))) + return sqlparse.sql.TokenList( + sqlparse.parse(", ".join([str(__escape(v)) for v in value])) + ) else: return __escape(value) @@ -510,7 +588,9 @@ def _parse_exception(e): import re # MySQL - matches = re.search(r"^\(_mysql_exceptions\.OperationalError\) \(\d+, \"(.+)\"\)$", str(e)) + matches = re.search( + r"^\(_mysql_exceptions\.OperationalError\) \(\d+, \"(.+)\"\)$", str(e) + ) if matches: return matches.group(1) @@ -536,7 +616,10 @@ def _parse_placeholder(token): import sqlparse # Validate token - if not isinstance(token, sqlparse.sql.Token) or token.ttype != sqlparse.tokens.Name.Placeholder: + if ( + not isinstance(token, sqlparse.sql.Token) + or token.ttype != sqlparse.tokens.Name.Placeholder + ): raise TypeError() # qmark From e18486cc36f879c83eefc74a408ba411eeb35060 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Sun, 17 Dec 2023 14:40:13 -0500 Subject: [PATCH 146/159] fixes #178 --- src/cs50/sql.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index de3ad56..ef011da 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -399,6 +399,8 @@ def teardown_appcontext(exception): # If constraint violated except sqlalchemy.exc.IntegrityError as e: + if self._autocommit: + connection.execute(sqlalchemy.text("ROLLBACK")) self._logger.error(termcolor.colored(_statement, "red")) e = ValueError(e.orig) e.__cause__ = None From 781c1c201090493a65793f586174e14644163f36 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Sun, 17 Dec 2023 14:40:26 -0500 Subject: [PATCH 147/159] increments version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 23f6b01..10ceb30 100644 --- a/setup.py +++ b/setup.py @@ -18,5 +18,5 @@ package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="9.3.0" + version="9.3.1" ) From 14a9741b904a437af015e0aa9d14d8531bc18de2 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Sun, 17 Dec 2023 14:40:45 -0500 Subject: [PATCH 148/159] increments version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 10ceb30..8f4b1be 100644 --- a/setup.py +++ b/setup.py @@ -18,5 +18,5 @@ package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="9.3.1" + version="9.3.2" ) From fc6647adbf1bb67cbae8e29e01095d550ed04d2e Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Sun, 17 Dec 2023 19:16:30 -0500 Subject: [PATCH 149/159] updated style --- setup.py | 2 +- src/cs50/sql.py | 13 ++++++++++++- tests/sql.py | 6 ++++++ 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 8f4b1be..7976109 100644 --- a/setup.py +++ b/setup.py @@ -18,5 +18,5 @@ package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="9.3.2" + version="9.3.3" ) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 3be30fe..356ed62 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -71,9 +71,13 @@ def __init__(self, url, **kwargs): # without isolation_level, PostgreSQL warns with "there is already a transaction in progress" for our own BEGIN and # "there is no transaction in progress" for our own COMMIT self._engine = sqlalchemy.create_engine(url, **kwargs).execution_options( - autocommit=False, isolation_level="AUTOCOMMIT" + autocommit=False, isolation_level="AUTOCOMMIT", no_parameters=True ) + # Avoid doubly escaping percent signs, since no_parameters=True anyway + # https://github.com/cs50/python-cs50/issues/171 + self._engine.dialect.identifier_preparer._double_percents = False + # Get logger self._logger = logging.getLogger("cs50") @@ -559,12 +563,19 @@ def __escape(value): # str elif isinstance(value, str): +<<<<<<< HEAD return sqlparse.sql.Token( sqlparse.tokens.String, sqlalchemy.types.String().literal_processor(self._engine.dialect)( value ), ) +======= + literal = sqlalchemy.types.String().literal_processor(self._engine.dialect)(value) + #if self._engine.dialect.identifier_preparer._double_percents: + # literal = literal.replace("%%", "%") + return sqlparse.sql.Token(sqlparse.tokens.String, literal) +>>>>>>> 3863555 (fixes #171) # None elif value is None: diff --git a/tests/sql.py b/tests/sql.py index 968f98b..b5d5406 100644 --- a/tests/sql.py +++ b/tests/sql.py @@ -138,6 +138,12 @@ def test_lastrowid(self): self.assertEqual(self.db.execute("INSERT INTO foo (firstname, lastname) VALUES('firstname', 'lastname')"), 1) self.assertRaises(ValueError, self.db.execute, "INSERT INTO foo (id, firstname, lastname) VALUES(1, 'firstname', 'lastname')") + def test_url(self): + url = "https://www.amazon.es/Desesperaci%C3%B3n-BEST-SELLER-Stephen-King/dp/8497595890" + self.db.execute("CREATE TABLE foo(id SERIAL PRIMARY KEY, url TEXT)") + self.db.execute("INSERT INTO foo (url) VALUES(?)", url) + self.assertEqual(self.db.execute("SELECT url FROM foo")[0]["url"], url) + def tearDown(self): self.db.execute("DROP TABLE cs50") self.db.execute("DROP TABLE IF EXISTS foo") From 128f498e8e34f2db10c06633b2f7d2b742ecc2c9 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Sun, 17 Dec 2023 19:15:37 -0500 Subject: [PATCH 150/159] fixed YAML --- docker-compose.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker-compose.yml b/docker-compose.yml index f795750..91f5a7d 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -7,7 +7,7 @@ services: - postgres environment: MYSQL_HOST: mysql - POSTGRESQL_HOST: postgresql + POSTGRESQL_HOST: postgres links: - mysql - postgres From 3b9036933da29b0265f7ddae0d709ed2784cfb1d Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Sun, 17 Dec 2023 19:18:14 -0500 Subject: [PATCH 151/159] fixed merge --- src/cs50/sql.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 356ed62..cbbd3d2 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -563,19 +563,12 @@ def __escape(value): # str elif isinstance(value, str): -<<<<<<< HEAD return sqlparse.sql.Token( sqlparse.tokens.String, sqlalchemy.types.String().literal_processor(self._engine.dialect)( value ), ) -======= - literal = sqlalchemy.types.String().literal_processor(self._engine.dialect)(value) - #if self._engine.dialect.identifier_preparer._double_percents: - # literal = literal.replace("%%", "%") - return sqlparse.sql.Token(sqlparse.tokens.String, literal) ->>>>>>> 3863555 (fixes #171) # None elif value is None: From e47cc3563cb015159ab897dd8053af2f887f7907 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Sun, 17 Dec 2023 19:39:15 -0500 Subject: [PATCH 152/159] updated for MySQL tests --- docker-compose.yml | 2 +- tests/sql.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docker-compose.yml b/docker-compose.yml index 91f5a7d..8608080 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -20,7 +20,7 @@ services: MYSQL_ALLOW_EMPTY_PASSWORD: yes healthcheck: test: ["CMD", "mysqladmin", "-uroot", "ping"] - image: cs50/mysql:8 + image: cs50/mysql ports: - 3306:3306 postgres: diff --git a/tests/sql.py b/tests/sql.py index b5d5406..bb37fd9 100644 --- a/tests/sql.py +++ b/tests/sql.py @@ -336,7 +336,7 @@ def test_cte(self): if __name__ == "__main__": suite = unittest.TestSuite([ unittest.TestLoader().loadTestsFromTestCase(SQLiteTests), - #unittest.TestLoader().loadTestsFromTestCase(MySQLTests), + unittest.TestLoader().loadTestsFromTestCase(MySQLTests), unittest.TestLoader().loadTestsFromTestCase(PostgresTests) ]) From dc6a43f5f93cd22457873ee858a60cbac46237f6 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Sun, 17 Dec 2023 23:04:50 -0500 Subject: [PATCH 153/159] updated SQLAlchemy versioning --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 7976109..6fd6cfa 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ "Topic :: Software Development :: Libraries :: Python Modules" ], description="CS50 library for Python", - install_requires=["Flask>=1.0", "packaging", "SQLAlchemy<3", "sqlparse", "termcolor", "wheel"], + install_requires=["Flask>=1.0", "packaging", "SQLAlchemy>=2,<3", "sqlparse", "termcolor", "wheel"], keywords="cs50", license="GPLv3", long_description_content_type="text/markdown", @@ -18,5 +18,5 @@ package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="9.3.3" + version="9.3.4" ) From 688af684d61a01e3bec5008acf2293a6d505196f Mon Sep 17 00:00:00 2001 From: Rongxin Liu <rongxinliu.dev@gmail.com> Date: Tue, 5 Mar 2024 00:23:11 +0700 Subject: [PATCH 154/159] updated workflow actions --- .github/workflows/main.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index b8165f7..0f14704 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -20,8 +20,8 @@ jobs: ports: - 5432:5432 steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 with: python-version: '3.7' check-latest: true From d67a26b3891a8d093aaa3a53e0f9a7bfeb0a599e Mon Sep 17 00:00:00 2001 From: "yulai.linda@gmail.com" <rongxinliu.dev@gmail.com> Date: Thu, 2 May 2024 12:22:33 -0400 Subject: [PATCH 155/159] updated actions/github-script to version v7 --- .github/workflows/main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 0f14704..0b0ee1a 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -56,7 +56,7 @@ jobs: - name: Create Release if: ${{ github.ref == 'refs/heads/main' }} - uses: actions/github-script@v6 + uses: actions/github-script@v7 with: github-token: ${{ github.token }} script: | From 2d5fd94132accb2733833eb273d755803a99794c Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Mon, 14 Oct 2024 20:39:22 -0400 Subject: [PATCH 156/159] adds support for VACUUM --- setup.py | 2 +- src/cs50/sql.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 6fd6cfa..1817b95 100644 --- a/setup.py +++ b/setup.py @@ -18,5 +18,5 @@ package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="9.3.4" + version="9.4.0" ) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index cbbd3d2..a133293 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -177,6 +177,7 @@ def execute(self, sql, *args, **kwargs): "SELECT", "START", "UPDATE", + "VACUUM", } # Check if the full_statement starts with any command @@ -378,7 +379,7 @@ def teardown_appcontext(exception): ) # Check for start of transaction - if command in ["BEGIN", "START"]: + if command in ["BEGIN", "START", "VACUUM"]: # cannot VACUUM from within a transaction self._autocommit = False # Execute statement @@ -389,7 +390,7 @@ def teardown_appcontext(exception): connection.execute(sqlalchemy.text("COMMIT")) # Check for end of transaction - if command in ["COMMIT", "ROLLBACK"]: + if command in ["COMMIT", "ROLLBACK", "VACUUM"]: # cannot VACUUM from within a transaction self._autocommit = True # Return value From f81a0a3920d90296537843d45683ec0b5d70171b Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Mon, 14 Oct 2024 20:40:31 -0400 Subject: [PATCH 157/159] fixes raw strings --- src/cs50/sql.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index a133293..81905bc 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -329,12 +329,12 @@ def execute(self, sql, *args, **kwargs): sqlparse.tokens.Literal.String, sqlparse.tokens.Literal.String.Single, ]: - token.value = re.sub("(^'|\s+):", r"\1\:", token.value) + token.value = re.sub(r"(^'|\s+):", r"\1\:", token.value) # In identifier # https://www.sqlite.org/lang_keywords.html elif token.ttype == sqlparse.tokens.Literal.String.Symbol: - token.value = re.sub('(^"|\s+):', r"\1\:", token.value) + token.value = re.sub(r'(^"|\s+):', r"\1\:", token.value) # Join tokens into statement statement = "".join([str(token) for token in tokens]) From b629b6dcdf3041982117801a35c01be2b72cb3d3 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Sat, 5 Apr 2025 11:36:15 -0400 Subject: [PATCH 158/159] Update README.md --- README.md | 47 ----------------------------------------------- 1 file changed, 47 deletions(-) diff --git a/README.md b/README.md index c94bd1f..a9033c6 100644 --- a/README.md +++ b/README.md @@ -39,50 +39,3 @@ s = cs50.get_string(); ``` python tests/sql.py ``` - -### Sample Tests - -``` -import cs50 -db = cs50.SQL("sqlite:///foo.db") -db.execute("CREATE TABLE IF NOT EXISTS cs50 (id INTEGER PRIMARY KEY, val TEXT, bin BLOB)") -db.execute("INSERT INTO cs50 (val) VALUES('a')") -db.execute("INSERT INTO cs50 (val) VALUES('b')") -db.execute("BEGIN") -db.execute("INSERT INTO cs50 (val) VALUES('c')") -db.execute("INSERT INTO cs50 (val) VALUES('x')") -db.execute("INSERT INTO cs50 (val) VALUES('y')") -db.execute("ROLLBACK") -db.execute("INSERT INTO cs50 (val) VALUES('z')") -db.execute("COMMIT") - ---- - -import cs50 -db = cs50.SQL("mysql://root@localhost/test") -db.execute("CREATE TABLE IF NOT EXISTS cs50 (id INTEGER PRIMARY KEY, val TEXT, bin BLOB)") -db.execute("INSERT INTO cs50 (val) VALUES('a')") -db.execute("INSERT INTO cs50 (val) VALUES('b')") -db.execute("BEGIN") -db.execute("INSERT INTO cs50 (val) VALUES('c')") -db.execute("INSERT INTO cs50 (val) VALUES('x')") -db.execute("INSERT INTO cs50 (val) VALUES('y')") -db.execute("ROLLBACK") -db.execute("INSERT INTO cs50 (val) VALUES('z')") -db.execute("COMMIT") - ---- - -import cs50 -db = cs50.SQL("postgresql://postgres@localhost/test") -db.execute("CREATE TABLE IF NOT EXISTS cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)") -db.execute("INSERT INTO cs50 (val) VALUES('a')") -db.execute("INSERT INTO cs50 (val) VALUES('b')") -db.execute("BEGIN") -db.execute("INSERT INTO cs50 (val) VALUES('c')") -db.execute("INSERT INTO cs50 (val) VALUES('x')") -db.execute("INSERT INTO cs50 (val) VALUES('y')") -db.execute("ROLLBACK") -db.execute("INSERT INTO cs50 (val) VALUES('z')") -db.execute("COMMIT") -``` From 68f4380757d8adb92d027c42c2099e798bff9488 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Sat, 5 Apr 2025 11:38:49 -0400 Subject: [PATCH 159/159] updated Python for Action --- .github/workflows/main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 0b0ee1a..7fcb507 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -23,7 +23,7 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: - python-version: '3.7' + python-version: '3.12' check-latest: true - name: Setup databases run: |