diff --git a/.travis.yml b/.travis.yml index b4da059..b13ffd6 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,41 +1,30 @@ language: python -python: -- '2.7' -- '3.6' +python: '3.6' branches: except: "/^v\\d/" services: -- mysql -- postgresql + - mysql + - postgresql install: -- python setup.py install -- pip install mysqlclient -- pip install psycopg2-binary + - python setup.py install + - pip install mysqlclient + - pip install psycopg2-binary before_script: -- mysql -e 'CREATE DATABASE IF NOT EXISTS test;' -- psql -c 'create database test;' -U postgres -- touch test.db test1.db + - mysql -e 'CREATE DATABASE IF NOT EXISTS test;' + - psql -c 'create database test;' -U postgres + - touch test.db test1.db script: python tests/sql.py -after_script: rm -f test.db -jobs: - include: - - stage: deploy - python: '3.6' - install: skip - before_script: skip - script: skip - deploy: - - provider: script - script: 'curl --fail --data "{ \"tag_name\": \"v$(python setup.py --version)\", - \"target_commitish\": \"$TRAVIS_COMMIT\", \"name\": \"v$(python setup.py --version)\" - }" --user bot50:$GITHUB_TOKEN https://api.github.com/repos/$TRAVIS_REPO_SLUG/releases' - on: - branch: master - - provider: pypi - user: "$PYPI_USERNAME" - password: "$PYPI_PASSWORD" - on: - branch: master +deploy: + - provider: script + script: 'curl --fail --data "{ \"tag_name\": \"v$(python setup.py --version)\", + \"target_commitish\": \"$TRAVIS_COMMIT\", \"name\": \"v$(python setup.py --version)\" + }" --user bot50:$GITHUB_TOKEN https://api.github.com/repos/$TRAVIS_REPO_SLUG/releases' + on: + branch: master + - provider: pypi + user: "$PYPI_USERNAME" + password: "$PYPI_PASSWORD" + on: master notifications: slack: secure: lJklhcBVjDT6KzUNa3RFHXdXSeH7ytuuGrkZ5ZcR72CXMoTf2pMJTzPwRLWOp6lCSdDC9Y8MWLrcg/e33dJga4Jlp9alOmWqeqesaFjfee4st8vAsgNbv8/RajPH1gD2bnkt8oIwUzdHItdb5AucKFYjbH2g0d8ndoqYqUeBLrnsT1AP5G/Vi9OHC9OWNpR0FKaZIJE0Wt52vkPMH3sV2mFeIskByPB+56U5y547mualKxn61IVR/dhYBEtZQJuSvnwKHPOn9Pkk7cCa+SSSeTJ4w5LboY8T17otaYNauXo46i1bKIoGiBcCcrJyQHHiPQmcq/YU540MC5Wzt9YXUycmJzRi347oyQeDee27wV3XJlWMXuuhbtJiKCFny7BTQ160VATlj/dbwIzN99Ra6/BtTumv/6LyTdKIuVjdAkcN8dtdDW1nlrQ29zuPNCcXXzJ7zX7kQaOCUV1c2OrsbiH/0fE9nknUORn97txqhlYVi0QMS7764wFo6kg0vpmFQRkkQySsJl+TmgcZ01AlsJc2EMMWVuaj9Af9JU4/4yalqDiXIh1fOYYUZnLfOfWS+MsnI+/oLfqJFyMbrsQQTIjs+kTzbiEdhd2R4EZgusU/xRFWokS2NAvahexrRhRQ6tpAI+LezPrkNOR3aHiykBf+P9BkUa0wPp6V2Ayc6q0= diff --git a/setup.py b/setup.py index afce734..3ca103a 100644 --- a/setup.py +++ b/setup.py @@ -16,5 +16,5 @@ package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="3.1.0" + version="3.2.0" ) diff --git a/src/cs50/__init__.py b/src/cs50/__init__.py index 94e9fc8..54706cb 100644 --- a/src/cs50/__init__.py +++ b/src/cs50/__init__.py @@ -4,7 +4,7 @@ try: # Save student's sys.path - path = sys.path[:] + _path = sys.path[:] # In case student has files that shadow packages sys.path = [p for p in sys.path if p not in ("", os.getcwd())] @@ -25,4 +25,4 @@ finally: # Restore student's sys.path (just in case library raised an exception that caller caught) - sys.path = path + sys.path = _path diff --git a/src/cs50/flask.py b/src/cs50/flask.py index c7ed023..d4a063c 100644 --- a/src/cs50/flask.py +++ b/src/cs50/flask.py @@ -1,3 +1,5 @@ +import logging + from distutils.version import StrictVersion from os import getenv from pkg_resources import get_distribution @@ -8,8 +10,8 @@ try: # Only patch >= 1.0 - version = StrictVersion(get_distribution("flask").version) - assert version >= StrictVersion("1.0") + _version = StrictVersion(get_distribution("flask").version) + assert _version >= StrictVersion("1.0") # Reformat logger's exceptions # http://flask.pocoo.org/docs/1.0/logging/ @@ -17,8 +19,28 @@ try: import flask.logging flask.logging.default_handler.formatter.formatException = lambda exc_info: formatException(*exc_info) - except: + except Exception: + pass + + # Enable logging when Flask is in use, + # monkey-patching own SQL module, which shouldn't need to know about Flask + logging.getLogger("cs50").disabled = True + try: + import flask + from .sql import SQL + except ImportError: pass + else: + _before = SQL.execute + def _after(*args, **kwargs): + disabled = logging.getLogger("cs50").disabled + if flask.current_app: + logging.getLogger("cs50").disabled = False + try: + return _before(*args, **kwargs) + finally: + logging.getLogger("cs50").disabled = disabled + SQL.execute = _after # Add support for Cloud9 proxy so that flask.redirect doesn't redirect from HTTPS to HTTP # http://stackoverflow.com/a/23504684/5156190 @@ -26,13 +48,13 @@ try: import flask from werkzeug.contrib.fixers import ProxyFix - before = flask.Flask.__init__ - def after(self, *args, **kwargs): - before(self, *args, **kwargs) + _before = flask.Flask.__init__ + def _after(*args, **kwargs): + _before(*args, **kwargs) self.wsgi_app = ProxyFix(self.wsgi_app) - flask.Flask.__init__ = after + flask.Flask.__init__ = _after except: pass -except: +except Exception: pass diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 113e2ae..9536197 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -25,6 +25,9 @@ def __init__(self, url, **kwargs): http://docs.sqlalchemy.org/en/latest/dialects/index.html """ + # Get logger + self._logger = logging.getLogger("cs50") + # Require that file already exist for SQLite matches = re.search(r"^sqlite:///(.+)$", url) if matches: @@ -41,116 +44,176 @@ def __init__(self, url, **kwargs): # Enable foreign key constraints if foreign_keys: - sqlalchemy.event.listen(self.engine, "connect", _connect) + def connect(dbapi_connection, connection_record): + if type(dbapi_connection) is sqlite3.Connection: # If back end is sqlite + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + sqlalchemy.event.listen(self.engine, "connect", connect) else: # Create engine, raising exception if back end's module not installed self.engine = sqlalchemy.create_engine(url, **kwargs) - # Log statements to standard error logging.basicConfig(level=logging.DEBUG) - self.logger = logging.getLogger("cs50") - disabled = self.logger.disabled # Test database try: - self.logger.disabled = True + disabled = self._logger.disabled + self._logger.disabled = True self.execute("SELECT 1") except sqlalchemy.exc.OperationalError as e: - e = RuntimeError(self._parse(e)) + e = RuntimeError(_parse_exception(e)) e.__cause__ = None raise e - else: - self.logger.disabled = disabled + finally: + self._logger.disabled = disabled - def _parse(self, e): - """Parses an exception, returns its message.""" + def execute(self, sql, *args, **kwargs): + """Execute a SQL statement.""" - # MySQL - matches = re.search(r"^\(_mysql_exceptions\.OperationalError\) \(\d+, \"(.+)\"\)$", str(e)) - if matches: - return matches.group(1) + # Allow only one statement at a time, since SQLite doesn't support multiple + # https://docs.python.org/3/library/sqlite3.html#sqlite3.Cursor.execute + statements = sqlparse.parse(sql) + if len(statements) > 1: + raise RuntimeError("too many statements at once") + elif len(statements) == 0: + raise RuntimeError("missing statement") - # PostgreSQL - matches = re.search(r"^\(psycopg2\.OperationalError\) (.+)$", str(e)) - if matches: - return matches.group(1) + # Ensure named and positional parameters are mutually exclusive + if len(args) > 0 and len(kwargs) > 0: + raise RuntimeError("cannot pass both named and positional parameters") - # SQLite - matches = re.search(r"^\(sqlite3\.OperationalError\) (.+)$", str(e)) - if matches: - return matches.group(1) + # Flatten statement + tokens = list(statements[0].flatten()) - # Default - return str(e) + # Validate paramstyle + placeholders = {} + paramstyle = None + for index, token in enumerate(tokens): - def execute(self, text, **params): - """Execute a SQL statement.""" + # If token is a placeholder + if token.ttype == sqlparse.tokens.Name.Placeholder: + + # Determine paramstyle, name + _paramstyle, name = _parse_placeholder(token) - class UserDefinedType(sqlalchemy.TypeDecorator): - """Add support for expandable values, a la https://github.com/sqlalchemy/sqlalchemy/issues/3953.""" + # Remember paramstyle + if not paramstyle: + paramstyle = _paramstyle - # Required class-level attribute - # https://docs.sqlalchemy.org/en/latest/core/custom_types.html#sqlalchemy.types.TypeDecorator - impl = sqlalchemy.types.UserDefinedType + # Ensure paramstyle is consistent + elif _paramstyle != paramstyle: + raise RuntimeError("inconsistent paramstyle") - def process_literal_param(self, value, dialect): - """Receive a literal parameter value to be rendered inline within a statement.""" + # Remember placeholder's index, name + placeholders[index] = name - def process(value): - """Render a literal value, escaping as needed.""" + # If more placeholders than arguments + if len(args) == 1 and len(placeholders) > 1: - # bool - if type(value) is bool: - return sqlalchemy.types.Boolean().literal_processor(dialect)(value) + # If user passed args as list or tuple, explode values into args + if isinstance(args[0], (list, tuple)): + args = args[0] - # datetime.date - elif type(value) is datetime.date: - return sqlalchemy.types.String().literal_processor(dialect)(value.strftime("%Y-%m-%d")) + # If user passed kwargs as dict, migrate values from args to kwargs + elif len(kwargs) == 0 and isinstance(args[0], dict): + kwargs = args[0] + args = [] - # datetime.datetime - elif type(value) is datetime.datetime: - return sqlalchemy.types.String().literal_processor(dialect)(value.strftime("%Y-%m-%d %H:%M:%S")) + # If no placeholders + if not paramstyle: - # datetime.time - elif type(value) is datetime.time: - return sqlalchemy.types.String().literal_processor(dialect)(value.strftime("%H:%M:%S")) + # Error-check like qmark if args + if args: + paramstyle = "qmark" - # float - elif type(value) is float: - return sqlalchemy.types.Float().literal_processor(dialect)(value) + # Error-check like named if kwargs + elif kwargs: + paramstyle = "named" - # int - elif type(value) is int: - return sqlalchemy.types.Integer().literal_processor(dialect)(value) + # In case of errors + _placeholders = ", ".join([str(tokens[index]) for index in placeholders]) + _args = ", ".join([str(self._escape(arg)) for arg in args]) - # long - elif sys.version_info.major != 3 and type(value) is long: - return sqlalchemy.types.Integer().literal_processor(dialect)(value) + # qmark + if paramstyle == "qmark": - # str - elif type(value) is str: - return sqlalchemy.types.String().literal_processor(dialect)(value) + # Validate number of placeholders + if len(placeholders) != len(args): + if len(placeholders) < len(args): + raise RuntimeError("fewer placeholders ({}) than values ({})".format(_placeholders, _args)) + else: + raise RuntimeError("more placeholders ({}) than values ({})".format(_placeholders, _args)) + + # Escape values + for i, index in enumerate(placeholders.keys()): + tokens[index] = self._escape(args[i]) + + # numeric + elif paramstyle == "numeric": + + # Escape values + for index, i in placeholders.items(): + if i >= len(args): + raise RuntimeError("missing value for placeholder (:{})".format(i + 1, len(args))) + tokens[index] = self._escape(args[i]) + + # Check if any values unused + indices = set(range(len(args))) - set(placeholders.values()) + if indices: + raise RuntimeError("unused {} ({})".format( + "value" if len(indices) == 1 else "values", + ", ".join([str(self._escape(args[index])) for index in indices]))) + + # named + elif paramstyle == "named": + + # Escape values + for index, name in placeholders.items(): + if name not in kwargs: + raise RuntimeError("missing value for placeholder (:{})".format(name)) + tokens[index] = self._escape(kwargs[name]) + + # Check if any keys unused + keys = kwargs.keys() - placeholders.values() + if keys: + raise RuntimeError("unused values ({})".format(", ".join(keys))) + + # format + elif paramstyle == "format": + + # Validate number of placeholders + if len(placeholders) != len(args): + if len(placeholders) < len(args): + raise RuntimeError("fewer placeholders ({}) than values ({})".format(_placeholders, _args)) + else: + raise RuntimeError("more placeholders ({}) than values ({})".format(_placeholders, _args)) - # None - elif type(value) is sqlalchemy.sql.elements.Null: - return sqlalchemy.types.NullType().literal_processor(dialect)(value) + # Escape values + for i, index in enumerate(placeholders.keys()): + tokens[index] = self._escape(args[i]) - # Unsupported value - raise RuntimeError("unsupported value") + # pyformat + elif paramstyle == "pyformat": - # Process value(s), separating with commas as needed - if type(value) is list: - return ", ".join([process(v) for v in value]) - else: - return process(value) + # Escape values + for index, name in placeholders.items(): + if name not in kwargs: + raise RuntimeError("missing value for placeholder (%{}s)".format(name)) + tokens[index] = self._escape(kwargs[name]) - # Allow only one statement at a time, since SQLite doesn't support multiple - # https://docs.python.org/3/library/sqlite3.html#sqlite3.Cursor.execute - if len(sqlparse.split(text)) > 1: - raise RuntimeError("too many statements at once") + # Check if any keys unused + keys = kwargs.keys() - placeholders.values() + if keys: + raise RuntimeError("unused {} ({})".format( + "value" if len(keys) == 1 else "values", + ", ".join(keys))) + + # Join tokens into statement + statement = "".join([str(token) for token in tokens]) # Raise exceptions for warnings warnings.filterwarnings("error") @@ -158,85 +221,182 @@ def process(value): # Prepare, execute statement try: - # Construct a new TextClause clause - statement = sqlalchemy.text(text) - - # Iterate over parameters - for key, value in params.items(): - - # Translate None to NULL - if value is None: - value = sqlalchemy.sql.null() - - # Bind parameters before statement reaches database, so that bound parameters appear in exceptions - # http://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.text - statement = statement.bindparams(sqlalchemy.bindparam( - key, value=value, type_=UserDefinedType())) - - # Stringify bound parameters - # http://docs.sqlalchemy.org/en/latest/faq/sqlexpressions.html#how-do-i-render-sql-expressions-as-strings-possibly-with-bound-parameters-inlined - statement = str(statement.compile(compile_kwargs={"literal_binds": True})) - - # Statement for logging - log = re.sub(r"\n\s*", " ", sqlparse.format(statement, reindent=True)) - # Execute statement - result = self.engine.execute(statement) - - # If SELECT (or INSERT with RETURNING), return result set as list of dict objects - if re.search(r"^\s*SELECT", statement, re.I): - - # Coerce any decimal.Decimal objects to float objects - # https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ - rows = [dict(row) for row in result.fetchall()] - for row in rows: - for column in row: - if type(row[column]) is decimal.Decimal: - row[column] = float(row[column]) - ret = rows - - # If INSERT, return primary key value for a newly inserted row - elif re.search(r"^\s*INSERT", statement, re.I): - if self.engine.url.get_backend_name() in ["postgres", "postgresql"]: - result = self.engine.execute(sqlalchemy.text("SELECT LASTVAL()")) - ret = result.first()[0] - else: - ret = result.lastrowid - - # If DELETE or UPDATE, return number of rows matched - elif re.search(r"^\s*(?:DELETE|UPDATE)", statement, re.I): - ret = result.rowcount - - # If some other statement, return True unless exception - else: - ret = True + result = self.engine.execute(sqlalchemy.text(statement)) + + # Return value + ret = True + if tokens[0].ttype == sqlparse.tokens.Keyword.DML: + + # Uppercase token's value + value = tokens[0].value.upper() + + # If SELECT, return result set as list of dict objects + if value == "SELECT": + + # Coerce any decimal.Decimal objects to float objects + # https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ + rows = [dict(row) for row in result.fetchall()] + for row in rows: + for column in row: + if type(row[column]) is decimal.Decimal: + row[column] = float(row[column]) + ret = rows + + # If INSERT, return primary key value for a newly inserted row + elif value == "INSERT": + if self.engine.url.get_backend_name() in ["postgres", "postgresql"]: + result = self.engine.execute("SELECT LASTVAL()") + ret = result.first()[0] + else: + ret = result.lastrowid + + # If DELETE or UPDATE, return number of rows matched + elif value in ["DELETE", "UPDATE"]: + ret = result.rowcount # If constraint violated, return None except sqlalchemy.exc.IntegrityError: - self.logger.debug(termcolor.colored(log, "yellow")) + self._logger.debug(termcolor.colored(statement, "yellow")) return None # If user errror except sqlalchemy.exc.OperationalError as e: - self.logger.debug(termcolor.colored(log, "red")) - e = RuntimeError(self._parse(e)) + self._logger.debug(termcolor.colored(statement, "red")) + e = RuntimeError(_parse_exception(e)) e.__cause__ = None raise e # Return value else: - self.logger.debug(termcolor.colored(log, "green")) + self._logger.debug(termcolor.colored(statement, "green")) return ret + def _escape(self, value): + """ + Escapes value using engine's conversion function. + + https://docs.sqlalchemy.org/en/latest/core/type_api.html#sqlalchemy.types.TypeEngine.literal_processor + """ + + def __escape(value): + + # bool + if type(value) is bool: + return sqlparse.sql.Token( + sqlparse.tokens.Number, + sqlalchemy.types.Boolean().literal_processor(self.engine.dialect)(value)) + + # bytearray, bytes + elif type(value) in [bytearray, bytes]: + raise RuntimeError("unsupported value") # TODO + + # datetime.date + elif type(value) is datetime.date: + return sqlparse.sql.Token( + sqlparse.tokens.String, + sqlalchemy.types.String().literal_processor(self.engine.dialect)(value.strftime("%Y-%m-%d"))) + + # datetime.datetime + elif type(value) is datetime.datetime: + return sqlparse.sql.Token( + sqlparse.tokens.String, + sqlalchemy.types.String().literal_processor(self.engine.dialect)(value.strftime("%Y-%m-%d %H:%M:%S"))) + + # datetime.time + elif type(value) is datetime.time: + return sqlparse.sql.Token( + sqlparse.tokens.String, + sqlalchemy.types.String().literal_processor(self.engine.dialect)(value.strftime("%H:%M:%S"))) + + # float + elif type(value) is float: + return sqlparse.sql.Token( + sqlparse.tokens.Number, + sqlalchemy.types.Float().literal_processor(self.engine.dialect)(value)) + + # int + elif type(value) is int: + return sqlparse.sql.Token( + sqlparse.tokens.Number, + sqlalchemy.types.Integer().literal_processor(self.engine.dialect)(value)) + + # str + elif type(value) is str: + return sqlparse.sql.Token( + sqlparse.tokens.String, + sqlalchemy.types.String().literal_processor(self.engine.dialect)(value)) + + # None + elif value is None: + return sqlparse.sql.Token( + sqlparse.tokens.Keyword, + sqlalchemy.types.NullType().literal_processor(self.engine.dialect)(value)) + + # Unsupported value + else: + raise RuntimeError("unsupported value: {}".format(value)) + + # Escape value(s), separating with commas as needed + if type(value) in [list, tuple]: + return sqlparse.sql.TokenList(sqlparse.parse(", ".join([str(__escape(v)) for v in value]))) + else: + return sqlparse.sql.Token( + sqlparse.tokens.String, + __escape(value)) + + +def _parse_exception(e): + """Parses an exception, returns its message.""" + + # MySQL + matches = re.search(r"^\(_mysql_exceptions\.OperationalError\) \(\d+, \"(.+)\"\)$", str(e)) + if matches: + return matches.group(1) + + # PostgreSQL + matches = re.search(r"^\(psycopg2\.OperationalError\) (.+)$", str(e)) + if matches: + return matches.group(1) + + # SQLite + matches = re.search(r"^\(sqlite3\.OperationalError\) (.+)$", str(e)) + if matches: + return matches.group(1) + + # Default + return str(e) + + +def _parse_placeholder(token): + """Infers paramstyle, name from sqlparse.tokens.Name.Placeholder.""" + + # Validate token + if not isinstance(token, sqlparse.sql.Token) or token.ttype != sqlparse.tokens.Name.Placeholder: + raise TypeError() + + # qmark + if token.value == "?": + return "qmark", None + + # numeric + matches = re.search(r"^:([1-9]\d*)$", token.value) + if matches: + return "numeric", int(matches.group(1)) - 1 + + # named + matches = re.search(r"^:([a-zA-Z]\w*)$", token.value) + if matches: + return "named", matches.group(1) -# http://docs.sqlalchemy.org/en/latest/dialects/sqlite.html#foreign-key-support -def _connect(dbapi_connection, connection_record): - """Enables foreign key support.""" + # format + if token.value == "%s": + return "format", None - # If back end is sqlite - if type(dbapi_connection) is sqlite3.Connection: + # pyformat + matches = re.search(r"%\((\w+)\)s$", token.value) + if matches: + return "pyformat", matches.group(1) - # Respect foreign key constraints by default - cursor = dbapi_connection.cursor() - cursor.execute("PRAGMA foreign_keys=ON") - cursor.close() + # Invalid + raise RuntimeError("{}: invalid placeholder".format(token.value)) diff --git a/tests/flask/application.py b/tests/flask/application.py index db7724c..939a8f9 100644 --- a/tests/flask/application.py +++ b/tests/flask/application.py @@ -5,12 +5,18 @@ sys.path.insert(0, "../../src") import cs50 +import cs50.flask app = Flask(__name__) +db = cs50.SQL("sqlite:///../sqlite.db") + @app.route("/") def index(): + db.execute("SELECT 1") + """ def f(): res = requests.get("cs50.harvard.edu") f() + """ return render_template("index.html") diff --git a/tests/python3.py b/tests/python.py similarity index 100% rename from tests/python3.py rename to tests/python.py diff --git a/tests/python2.py b/tests/python2.py deleted file mode 100644 index c132c20..0000000 --- a/tests/python2.py +++ /dev/null @@ -1,8 +0,0 @@ -import sys - -sys.path.insert(0, "../src") - -import cs50 - -l = cs50.get_long() -print(l) diff --git a/tests/sql.py b/tests/sql.py index 67b4e94..64f8e76 100644 --- a/tests/sql.py +++ b/tests/sql.py @@ -106,21 +106,159 @@ def setUp(self): class SQLiteTests(SQLTests): @classmethod def setUpClass(self): + open("test.db", "w").close() self.db = SQL("sqlite:///test.db") + open("test1.db", "w").close() self.db1 = SQL("sqlite:///test1.db", foreign_keys=True) def setUp(self): + self.db.execute("DROP TABLE IF EXISTS cs50") self.db.execute("CREATE TABLE cs50(id INTEGER PRIMARY KEY, val TEXT)") def test_foreign_key_support(self): + self.db.execute("DROP TABLE IF EXISTS foo") self.db.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY)") + self.db.execute("DROP TABLE IF EXISTS bar") self.db.execute("CREATE TABLE bar(foo_id INTEGER, FOREIGN KEY (foo_id) REFERENCES foo(id))") self.assertEqual(self.db.execute("INSERT INTO bar VALUES(50)"), 1) + self.db1.execute("DROP TABLE IF EXISTS foo") self.db1.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY)") + self.db1.execute("DROP TABLE IF EXISTS bar") self.db1.execute("CREATE TABLE bar(foo_id INTEGER, FOREIGN KEY (foo_id) REFERENCES foo(id))") self.assertEqual(self.db1.execute("INSERT INTO bar VALUES(50)"), None) + + def test_qmark(self): + self.db.execute("DROP TABLE IF EXISTS foo") + self.db.execute("CREATE TABLE foo (firstname STRING, lastname STRING)") + + self.db.execute("INSERT INTO foo VALUES (?, 'bar')", "baz") + self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "baz", "lastname": "bar"}]) + self.db.execute("DELETE FROM foo") + + self.db.execute("INSERT INTO foo VALUES ('bar', ?)", "baz") + self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) + self.db.execute("DELETE FROM foo") + + self.db.execute("INSERT INTO foo VALUES (?, ?)", "bar", "baz") + self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) + + self.db.execute("INSERT INTO foo VALUES ('qux', 'quux')") + self.assertEqual(self.db.execute("SELECT * FROM foo WHERE firstname = ?", 'qux'), [{"firstname": "qux", "lastname": "quux"}]) + self.assertEqual(self.db.execute("SELECT * FROM foo WHERE firstname = ? AND lastname = ?", "qux", "quux"), [{"firstname": "qux", "lastname": "quux"}]) + self.assertEqual(self.db.execute("SELECT * FROM foo WHERE firstname = ? AND lastname = ?", ("qux", "quux")), [{"firstname": "qux", "lastname": "quux"}]) + self.assertEqual(self.db.execute("SELECT * FROM foo WHERE firstname = ? AND lastname = ?", ["qux", "quux"]), [{"firstname": "qux", "lastname": "quux"}]) + self.db.execute("DELETE FROM foo") + + self.db.execute("INSERT INTO foo VALUES (?, ?)", ("bar", "baz")) + self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) + self.db.execute("DELETE FROM foo") + + self.db.execute("INSERT INTO foo VALUES (?, ?)", ["bar", "baz"]) + self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) + self.db.execute("DELETE FROM foo") + + + self.db.execute("INSERT INTO foo VALUES (?,?)", "bar", "baz") + self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) + self.db.execute("DELETE FROM foo") + + self.db.execute("DROP TABLE IF EXISTS bar") + self.db.execute("CREATE TABLE bar (firstname STRING)") + self.db.execute("INSERT INTO bar VALUES (?)", "baz") + self.assertEqual(self.db.execute("SELECT * FROM bar"), [{"firstname": "baz"}]) + + self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (?)") + self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (?, ?)") + # self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (?)", ('bar', 'baz')) + # self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (?)", ['bar', 'baz']) + self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (?, ?)", 'bar', 'baz', 'qux') + self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (?, ?)", ('bar', 'baz', 'qux')) + self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (?, ?)", ['bar', 'baz', 'qux']) + self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (?, ?)", 'bar', baz='baz') + + def test_named(self): + self.db.execute("DROP TABLE IF EXISTS foo") + self.db.execute("CREATE TABLE foo (firstname STRING, lastname STRING)") + + self.db.execute("INSERT INTO foo VALUES (:baz, 'bar')", baz="baz") + self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "baz", "lastname": "bar"}]) + self.db.execute("DELETE FROM foo") + + self.db.execute("INSERT INTO foo VALUES ('bar', :baz)", baz="baz") + self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) + self.db.execute("DELETE FROM foo") + + self.db.execute("INSERT INTO foo VALUES (:bar, :baz)", bar="bar", baz="baz") + self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) + + self.db.execute("INSERT INTO foo VALUES ('qux', 'quux')") + self.assertEqual(self.db.execute("SELECT * FROM foo WHERE firstname = :qux", qux='qux'), [{"firstname": "qux", "lastname": "quux"}]) + self.assertEqual(self.db.execute("SELECT * FROM foo WHERE firstname = :qux AND lastname = :quux", qux="qux", quux="quux"), [{"firstname": "qux", "lastname": "quux"}]) + self.db.execute("DELETE FROM foo") + + self.db.execute("INSERT INTO foo VALUES (:bar,:baz)", bar="bar", baz="baz") + self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) + self.db.execute("DELETE FROM foo") + + self.db.execute("DROP TABLE IF EXISTS bar") + self.db.execute("CREATE TABLE bar (firstname STRING)") + self.db.execute("INSERT INTO bar VALUES (:baz)", baz="baz") + self.assertEqual(self.db.execute("SELECT * FROM bar"), [{"firstname": "baz"}]) + + self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (:bar)") + self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (:bar, :baz)") + self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (:bar, :baz)", bar='bar', baz='baz', qux='qux') + self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (:bar, :baz)", 'baz', bar='bar') + + + def test_numeric(self): + self.db.execute("DROP TABLE IF EXISTS foo") + self.db.execute("CREATE TABLE foo (firstname STRING, lastname STRING)") + + self.db.execute("INSERT INTO foo VALUES (:1, 'bar')", "baz") + self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "baz", "lastname": "bar"}]) + self.db.execute("DELETE FROM foo") + + self.db.execute("INSERT INTO foo VALUES ('bar', :1)", "baz") + self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) + self.db.execute("DELETE FROM foo") + + self.db.execute("INSERT INTO foo VALUES (:1, :2)", "bar", "baz") + self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) + + self.db.execute("INSERT INTO foo VALUES ('qux', 'quux')") + self.assertEqual(self.db.execute("SELECT * FROM foo WHERE firstname = :1", 'qux'), [{"firstname": "qux", "lastname": "quux"}]) + self.assertEqual(self.db.execute("SELECT * FROM foo WHERE firstname = :1 AND lastname = :2", "qux", "quux"), [{"firstname": "qux", "lastname": "quux"}]) + self.assertEqual(self.db.execute("SELECT * FROM foo WHERE firstname = :1 AND lastname = :2", ("qux", "quux")), [{"firstname": "qux", "lastname": "quux"}]) + self.assertEqual(self.db.execute("SELECT * FROM foo WHERE firstname = :1 AND lastname = :2", ["qux", "quux"]), [{"firstname": "qux", "lastname": "quux"}]) + self.db.execute("DELETE FROM foo") + + self.db.execute("INSERT INTO foo VALUES (:1, :2)", ("bar", "baz")) + self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) + self.db.execute("DELETE FROM foo") + + self.db.execute("INSERT INTO foo VALUES (:1, :2)", ["bar", "baz"]) + self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) + self.db.execute("DELETE FROM foo") + + + self.db.execute("INSERT INTO foo VALUES (:1,:2)", "bar", "baz") + self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}]) + self.db.execute("DELETE FROM foo") + + self.db.execute("DROP TABLE IF EXISTS bar") + self.db.execute("CREATE TABLE bar (firstname STRING)") + self.db.execute("INSERT INTO bar VALUES (:1)", "baz") + self.assertEqual(self.db.execute("SELECT * FROM bar"), [{"firstname": "baz"}]) + + self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (:1)") + self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (:1, :2)") + self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (:1, :2)", 'bar', 'baz', 'qux') + self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (:1, :2)", 'bar', baz='baz') + + if __name__ == "__main__": suite = unittest.TestSuite([ unittest.TestLoader().loadTestsFromTestCase(SQLiteTests), diff --git a/tests/sqlite.py b/tests/sqlite.py index 50fc980..7a64278 100644 --- a/tests/sqlite.py +++ b/tests/sqlite.py @@ -1,8 +1,42 @@ +import logging import sys sys.path.insert(0, "../src") from cs50 import SQL +logging.getLogger("cs50").disabled = False + db = SQL("sqlite:///sqlite.db") db.execute("SELECT 1") + +# TODO +#db.execute("SELECT * FROM Employee WHERE FirstName = ?", b'\x00') + +db.execute("SELECT * FROM Employee WHERE FirstName = ?", "' OR 1 = 1") + +db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", "Andrew") +db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", ["Andrew"]) +db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", ("Andrew",)) +db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", ["Andrew", "Nancy"]) +db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", ("Andrew", "Nancy")) +db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", []) +db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", ()) + +db.execute("SELECT * FROM Employee WHERE FirstName = ? AND LastName = ?", "Andrew", "Adams") +db.execute("SELECT * FROM Employee WHERE FirstName = ? AND LastName = ?", ["Andrew", "Adams"]) +db.execute("SELECT * FROM Employee WHERE FirstName = ? AND LastName = ?", ("Andrew", "Adams")) + +db.execute("SELECT * FROM Employee WHERE FirstName = :1 AND LastName = :2", "Andrew", "Adams") +db.execute("SELECT * FROM Employee WHERE FirstName = :1 AND LastName = :2", ["Andrew", "Adams"]) +db.execute("SELECT * FROM Employee WHERE FirstName = :1 AND LastName = :2", ("Andrew", "Adams")) + +db.execute("SELECT * FROM Employee WHERE FirstName = :first AND LastName = :last", first="Andrew", last="Adams") +db.execute("SELECT * FROM Employee WHERE FirstName = :first AND LastName = :last", {"first": "Andrew", "last": "Adams"}) + +db.execute("SELECT * FROM Employee WHERE FirstName = %s AND LastName = %s", "Andrew", "Adams") +db.execute("SELECT * FROM Employee WHERE FirstName = %s AND LastName = %s", ["Andrew", "Adams"]) +db.execute("SELECT * FROM Employee WHERE FirstName = %s AND LastName = %s", ("Andrew", "Adams")) + +db.execute("SELECT * FROM Employee WHERE FirstName = %(first)s AND LastName = %(last)s", first="Andrew", last="Adams") +db.execute("SELECT * FROM Employee WHERE FirstName = %(first)s AND LastName = %(last)s", {"first": "Andrew", "last": "Adams"})