diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index f438d7b..7fcb507 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -20,15 +20,15 @@ jobs: ports: - 5432:5432 steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 with: - python-version: '3.7' + python-version: '3.12' check-latest: true - name: Setup databases run: | pip install . - pip install mysqlclient psycopg2-binary + pip install mysqlclient psycopg2-binary SQLAlchemy - name: Run tests run: python tests/sql.py @@ -56,7 +56,7 @@ jobs: - name: Create Release if: ${{ github.ref == 'refs/heads/main' }} - uses: actions/github-script@v6 + uses: actions/github-script@v7 with: github-token: ${{ github.token }} script: | diff --git a/.gitignore b/.gitignore index 4286ed6..dd3ffcc 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ .* !/.github/ !.gitignore +build/ *.db *.egg-info/ *.pyc diff --git a/README.md b/README.md index c94bd1f..a9033c6 100644 --- a/README.md +++ b/README.md @@ -39,50 +39,3 @@ s = cs50.get_string(); ``` python tests/sql.py ``` - -### Sample Tests - -``` -import cs50 -db = cs50.SQL("sqlite:///foo.db") -db.execute("CREATE TABLE IF NOT EXISTS cs50 (id INTEGER PRIMARY KEY, val TEXT, bin BLOB)") -db.execute("INSERT INTO cs50 (val) VALUES('a')") -db.execute("INSERT INTO cs50 (val) VALUES('b')") -db.execute("BEGIN") -db.execute("INSERT INTO cs50 (val) VALUES('c')") -db.execute("INSERT INTO cs50 (val) VALUES('x')") -db.execute("INSERT INTO cs50 (val) VALUES('y')") -db.execute("ROLLBACK") -db.execute("INSERT INTO cs50 (val) VALUES('z')") -db.execute("COMMIT") - ---- - -import cs50 -db = cs50.SQL("mysql://root@localhost/test") -db.execute("CREATE TABLE IF NOT EXISTS cs50 (id INTEGER PRIMARY KEY, val TEXT, bin BLOB)") -db.execute("INSERT INTO cs50 (val) VALUES('a')") -db.execute("INSERT INTO cs50 (val) VALUES('b')") -db.execute("BEGIN") -db.execute("INSERT INTO cs50 (val) VALUES('c')") -db.execute("INSERT INTO cs50 (val) VALUES('x')") -db.execute("INSERT INTO cs50 (val) VALUES('y')") -db.execute("ROLLBACK") -db.execute("INSERT INTO cs50 (val) VALUES('z')") -db.execute("COMMIT") - ---- - -import cs50 -db = cs50.SQL("postgresql://postgres@localhost/test") -db.execute("CREATE TABLE IF NOT EXISTS cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)") -db.execute("INSERT INTO cs50 (val) VALUES('a')") -db.execute("INSERT INTO cs50 (val) VALUES('b')") -db.execute("BEGIN") -db.execute("INSERT INTO cs50 (val) VALUES('c')") -db.execute("INSERT INTO cs50 (val) VALUES('x')") -db.execute("INSERT INTO cs50 (val) VALUES('y')") -db.execute("ROLLBACK") -db.execute("INSERT INTO cs50 (val) VALUES('z')") -db.execute("COMMIT") -``` diff --git a/docker-compose.yml b/docker-compose.yml index f795750..8608080 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -7,7 +7,7 @@ services: - postgres environment: MYSQL_HOST: mysql - POSTGRESQL_HOST: postgresql + POSTGRESQL_HOST: postgres links: - mysql - postgres @@ -20,7 +20,7 @@ services: MYSQL_ALLOW_EMPTY_PASSWORD: yes healthcheck: test: ["CMD", "mysqladmin", "-uroot", "ping"] - image: cs50/mysql:8 + image: cs50/mysql ports: - 3306:3306 postgres: diff --git a/setup.py b/setup.py index 2d90788..1817b95 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ "Topic :: Software Development :: Libraries :: Python Modules" ], description="CS50 library for Python", - install_requires=["Flask>=1.0", "SQLAlchemy", "sqlparse", "termcolor", "wheel"], + install_requires=["Flask>=1.0", "packaging", "SQLAlchemy>=2,<3", "sqlparse", "termcolor", "wheel"], keywords="cs50", license="GPLv3", long_description_content_type="text/markdown", @@ -18,5 +18,5 @@ package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="9.2.3" + version="9.4.0" ) diff --git a/src/cs50/__init__.py b/src/cs50/__init__.py index aaec161..7dd4e17 100644 --- a/src/cs50/__init__.py +++ b/src/cs50/__init__.py @@ -8,6 +8,7 @@ # Import cs50_* from .cs50 import get_char, get_float, get_int, get_string + try: from .cs50 import get_long except ImportError: diff --git a/src/cs50/cs50.py b/src/cs50/cs50.py index 1d7b6ea..f331a88 100644 --- a/src/cs50/cs50.py +++ b/src/cs50/cs50.py @@ -6,7 +6,6 @@ import re import sys -from distutils.sysconfig import get_python_lib from os.path import abspath, join from termcolor import colored from traceback import format_exception @@ -18,7 +17,9 @@ try: # Patch formatException - logging.root.handlers[0].formatter.formatException = lambda exc_info: _formatException(*exc_info) + logging.root.handlers[ + 0 + ].formatter.formatException = lambda exc_info: _formatException(*exc_info) except IndexError: pass @@ -38,26 +39,31 @@ _logger.addHandler(handler) -class _flushfile(): +class _Unbuffered: """ Disable buffering for standard output and standard error. - http://stackoverflow.com/a/231216 + https://stackoverflow.com/a/107717 + https://docs.python.org/3/library/io.html """ - def __init__(self, f): - self.f = f + def __init__(self, stream): + self.stream = stream - def __getattr__(self, name): - return object.__getattribute__(self.f, name) + def __getattr__(self, attr): + return getattr(self.stream, attr) - def write(self, x): - self.f.write(x) - self.f.flush() + def write(self, b): + self.stream.write(b) + self.stream.flush() + def writelines(self, lines): + self.stream.writelines(lines) + self.stream.flush() -sys.stderr = _flushfile(sys.stderr) -sys.stdout = _flushfile(sys.stdout) + +sys.stderr = _Unbuffered(sys.stderr) +sys.stdout = _Unbuffered(sys.stdout) def _formatException(type, value, tb): @@ -79,19 +85,29 @@ def _formatException(type, value, tb): lines += line else: matches = re.search(r"^(\s*)(.*?)(\s*)$", line, re.DOTALL) - lines.append(matches.group(1) + colored(matches.group(2), "yellow") + matches.group(3)) + lines.append( + matches.group(1) + + colored(matches.group(2), "yellow") + + matches.group(3) + ) return "".join(lines).rstrip() -sys.excepthook = lambda type, value, tb: print(_formatException(type, value, tb), file=sys.stderr) +sys.excepthook = lambda type, value, tb: print( + _formatException(type, value, tb), file=sys.stderr +) def eprint(*args, **kwargs): - raise RuntimeError("The CS50 Library for Python no longer supports eprint, but you can use print instead!") + raise RuntimeError( + "The CS50 Library for Python no longer supports eprint, but you can use print instead!" + ) def get_char(prompt): - raise RuntimeError("The CS50 Library for Python no longer supports get_char, but you can use get_string instead!") + raise RuntimeError( + "The CS50 Library for Python no longer supports get_char, but you can use get_string instead!" + ) def get_float(prompt): @@ -135,7 +151,7 @@ def get_string(prompt): as line endings. If user inputs only a line ending, returns "", not None. Returns None upon error or no input whatsoever (i.e., just EOF). """ - if type(prompt) is not str: + if not isinstance(prompt, str): raise TypeError("prompt must be of type str") try: return input(prompt) diff --git a/src/cs50/flask.py b/src/cs50/flask.py index 324ec30..6e38971 100644 --- a/src/cs50/flask.py +++ b/src/cs50/flask.py @@ -2,22 +2,31 @@ import pkgutil import sys + def _wrap_flask(f): if f is None: return - from distutils.version import StrictVersion + from packaging.version import Version, InvalidVersion from .cs50 import _formatException - if f.__version__ < StrictVersion("1.0"): + try: + if Version(f.__version__) < Version("1.0"): + return + except InvalidVersion: return if os.getenv("CS50_IDE_TYPE") == "online": from werkzeug.middleware.proxy_fix import ProxyFix + _flask_init_before = f.Flask.__init__ + def _flask_init_after(self, *args, **kwargs): _flask_init_before(self, *args, **kwargs) - self.wsgi_app = ProxyFix(self.wsgi_app, x_proto=1) # For HTTPS-to-HTTP proxy + self.wsgi_app = ProxyFix( + self.wsgi_app, x_proto=1 + ) # For HTTPS-to-HTTP proxy + f.Flask.__init__ = _flask_init_after @@ -27,7 +36,7 @@ def _flask_init_after(self, *args, **kwargs): # If Flask wasn't imported else: - flask_loader = pkgutil.get_loader('flask') + flask_loader = pkgutil.get_loader("flask") if flask_loader: _exec_module_before = flask_loader.exec_module diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 8087657..81905bc 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -14,7 +14,6 @@ def _enable_logging(f): @functools.wraps(f) def decorator(*args, **kwargs): - # Infer whether Flask is installed try: import flask @@ -53,7 +52,7 @@ def __init__(self, url, **kwargs): import sqlalchemy import sqlalchemy.orm import threading - + # Temporary fix for missing sqlite3 module on the buildpack stack try: import sqlite3 @@ -71,19 +70,30 @@ def __init__(self, url, **kwargs): # Create engine, disabling SQLAlchemy's own autocommit mode raising exception if back end's module not installed; # without isolation_level, PostgreSQL warns with "there is already a transaction in progress" for our own BEGIN and # "there is no transaction in progress" for our own COMMIT - self._engine = sqlalchemy.create_engine(url, **kwargs).execution_options(autocommit=False, isolation_level="AUTOCOMMIT") + self._engine = sqlalchemy.create_engine(url, **kwargs).execution_options( + autocommit=False, isolation_level="AUTOCOMMIT", no_parameters=True + ) + + # Avoid doubly escaping percent signs, since no_parameters=True anyway + # https://github.com/cs50/python-cs50/issues/171 + self._engine.dialect.identifier_preparer._double_percents = False # Get logger self._logger = logging.getLogger("cs50") # Listener for connections def connect(dbapi_connection, connection_record): - # Enable foreign key constraints - if type(dbapi_connection) is sqlite3.Connection: # If back end is sqlite - cursor = dbapi_connection.cursor() - cursor.execute("PRAGMA foreign_keys=ON") - cursor.close() + try: + if isinstance( + dbapi_connection, sqlite3.Connection + ): # If back end is sqlite + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + except: + # Temporary fix for missing sqlite3 module on the buildpack stack + pass # Register listener sqlalchemy.event.listen(self._engine, "connect", connect) @@ -96,7 +106,7 @@ def connect(dbapi_connection, connection_record): self._logger.disabled = True try: connection = self._engine.connect() - connection.execute("SELECT 1") + connection.execute(sqlalchemy.text("SELECT 1")) connection.close() except sqlalchemy.exc.OperationalError as e: e = RuntimeError(_parse_exception(e)) @@ -145,15 +155,35 @@ def execute(self, sql, *args, **kwargs): if len(args) > 0 and len(kwargs) > 0: raise RuntimeError("cannot pass both positional and named parameters") - # Infer command from (unflattened) statement - for token in statements[0]: - if token.ttype in [sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML]: - token_value = token.value.upper() - if token_value in ["BEGIN", "DELETE", "INSERT", "SELECT", "START", "UPDATE"]: - command = token_value - break - else: - command = None + # Infer command from flattened statement to a single string separated by spaces + full_statement = " ".join( + str(token) + for token in statements[0].tokens + if token.ttype + in [ + sqlparse.tokens.Keyword, + sqlparse.tokens.Keyword.DDL, + sqlparse.tokens.Keyword.DML, + ] + ) + full_statement = full_statement.upper() + + # Set of possible commands + commands = { + "BEGIN", + "CREATE VIEW", + "DELETE", + "INSERT", + "SELECT", + "START", + "UPDATE", + "VACUUM", + } + + # Check if the full_statement starts with any command + command = next( + (cmd for cmd in commands if full_statement.startswith(cmd)), None + ) # Flatten statement tokens = list(statements[0].flatten()) @@ -162,10 +192,8 @@ def execute(self, sql, *args, **kwargs): placeholders = {} paramstyle = None for index, token in enumerate(tokens): - # If token is a placeholder if token.ttype == sqlparse.tokens.Name.Placeholder: - # Determine paramstyle, name _paramstyle, name = _parse_placeholder(token) @@ -182,7 +210,6 @@ def execute(self, sql, *args, **kwargs): # If no placeholders if not paramstyle: - # Error-check like qmark if args if args: paramstyle = "qmark" @@ -197,13 +224,20 @@ def execute(self, sql, *args, **kwargs): # qmark if paramstyle == "qmark": - # Validate number of placeholders if len(placeholders) != len(args): if len(placeholders) < len(args): - raise RuntimeError("fewer placeholders ({}) than values ({})".format(_placeholders, _args)) + raise RuntimeError( + "fewer placeholders ({}) than values ({})".format( + _placeholders, _args + ) + ) else: - raise RuntimeError("more placeholders ({}) than values ({})".format(_placeholders, _args)) + raise RuntimeError( + "more placeholders ({}) than values ({})".format( + _placeholders, _args + ) + ) # Escape values for i, index in enumerate(placeholders.keys()): @@ -211,27 +245,34 @@ def execute(self, sql, *args, **kwargs): # numeric elif paramstyle == "numeric": - # Escape values for index, i in placeholders.items(): if i >= len(args): - raise RuntimeError("missing value for placeholder (:{})".format(i + 1, len(args))) + raise RuntimeError( + "missing value for placeholder (:{})".format(i + 1, len(args)) + ) tokens[index] = self._escape(args[i]) # Check if any values unused indices = set(range(len(args))) - set(placeholders.values()) if indices: - raise RuntimeError("unused {} ({})".format( - "value" if len(indices) == 1 else "values", - ", ".join([str(self._escape(args[index])) for index in indices]))) + raise RuntimeError( + "unused {} ({})".format( + "value" if len(indices) == 1 else "values", + ", ".join( + [str(self._escape(args[index])) for index in indices] + ), + ) + ) # named elif paramstyle == "named": - # Escape values for index, name in placeholders.items(): if name not in kwargs: - raise RuntimeError("missing value for placeholder (:{})".format(name)) + raise RuntimeError( + "missing value for placeholder (:{})".format(name) + ) tokens[index] = self._escape(kwargs[name]) # Check if any keys unused @@ -241,13 +282,20 @@ def execute(self, sql, *args, **kwargs): # format elif paramstyle == "format": - # Validate number of placeholders if len(placeholders) != len(args): if len(placeholders) < len(args): - raise RuntimeError("fewer placeholders ({}) than values ({})".format(_placeholders, _args)) + raise RuntimeError( + "fewer placeholders ({}) than values ({})".format( + _placeholders, _args + ) + ) else: - raise RuntimeError("more placeholders ({}) than values ({})".format(_placeholders, _args)) + raise RuntimeError( + "more placeholders ({}) than values ({})".format( + _placeholders, _args + ) + ) # Escape values for i, index in enumerate(placeholders.keys()): @@ -255,40 +303,44 @@ def execute(self, sql, *args, **kwargs): # pyformat elif paramstyle == "pyformat": - # Escape values for index, name in placeholders.items(): if name not in kwargs: - raise RuntimeError("missing value for placeholder (%{}s)".format(name)) + raise RuntimeError( + "missing value for placeholder (%{}s)".format(name) + ) tokens[index] = self._escape(kwargs[name]) # Check if any keys unused keys = kwargs.keys() - placeholders.values() if keys: - raise RuntimeError("unused {} ({})".format( - "value" if len(keys) == 1 else "values", - ", ".join(keys))) + raise RuntimeError( + "unused {} ({})".format( + "value" if len(keys) == 1 else "values", ", ".join(keys) + ) + ) # For SQL statements where a colon is required verbatim, as within an inline string, use a backslash to escape # https://docs.sqlalchemy.org/en/13/core/sqlelement.html?highlight=text#sqlalchemy.sql.expression.text for index, token in enumerate(tokens): - # In string literal # https://www.sqlite.org/lang_keywords.html - if token.ttype in [sqlparse.tokens.Literal.String, sqlparse.tokens.Literal.String.Single]: - token.value = re.sub("(^'|\s+):", r"\1\:", token.value) + if token.ttype in [ + sqlparse.tokens.Literal.String, + sqlparse.tokens.Literal.String.Single, + ]: + token.value = re.sub(r"(^'|\s+):", r"\1\:", token.value) # In identifier # https://www.sqlite.org/lang_keywords.html elif token.ttype == sqlparse.tokens.Literal.String.Symbol: - token.value = re.sub("(^\"|\s+):", r"\1\:", token.value) + token.value = re.sub(r'(^"|\s+):', r"\1\:", token.value) # Join tokens into statement statement = "".join([str(token) for token in tokens]) # If no connection yet if not hasattr(_data, self._name()): - # Connect to database setattr(_data, self._name(), self._engine.connect()) @@ -298,9 +350,12 @@ def execute(self, sql, *args, **kwargs): # Disconnect if/when a Flask app is torn down try: import flask + assert flask.current_app + def teardown_appcontext(exception): self._disconnect() + if teardown_appcontext not in flask.current_app.teardown_appcontext_funcs: flask.current_app.teardown_appcontext(teardown_appcontext) except (ModuleNotFoundError, AssertionError): @@ -308,18 +363,23 @@ def teardown_appcontext(exception): # Catch SQLAlchemy warnings with warnings.catch_warnings(): - # Raise exceptions for warnings warnings.simplefilter("error") # Prepare, execute statement try: - # Join tokens into statement, abbreviating binary data as <class 'bytes'> - _statement = "".join([str(bytes) if token.ttype == sqlparse.tokens.Other else str(token) for token in tokens]) + _statement = "".join( + [ + str(bytes) + if token.ttype == sqlparse.tokens.Other + else str(token) + for token in tokens + ] + ) # Check for start of transaction - if command in ["BEGIN", "START"]: + if command in ["BEGIN", "START", "VACUUM"]: # cannot VACUUM from within a transaction self._autocommit = False # Execute statement @@ -330,7 +390,7 @@ def teardown_appcontext(exception): connection.execute(sqlalchemy.text("COMMIT")) # Check for end of transaction - if command in ["COMMIT", "ROLLBACK"]: + if command in ["COMMIT", "ROLLBACK", "VACUUM"]: # cannot VACUUM from within a transaction self._autocommit = True # Return value @@ -338,19 +398,17 @@ def teardown_appcontext(exception): # If SELECT, return result set as list of dict objects if command == "SELECT": - # Coerce types - rows = [dict(row) for row in result.fetchall()] + rows = [dict(row) for row in result.mappings().all()] for row in rows: for column in row: - # Coerce decimal.Decimal objects to float objects # https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ - if type(row[column]) is decimal.Decimal: + if isinstance(row[column], decimal.Decimal): row[column] = float(row[column]) # Coerce memoryview objects (as from PostgreSQL's bytea columns) to bytes - elif type(row[column]) is memoryview: + elif isinstance(row[column], memoryview): row[column] = bytes(row[column]) # Rows to be returned @@ -358,15 +416,15 @@ def teardown_appcontext(exception): # If INSERT, return primary key value for a newly inserted row (or None if none) elif command == "INSERT": - # If PostgreSQL if self._engine.url.get_backend_name() == "postgresql": - # Return LASTVAL() or NULL, avoiding # "(psycopg2.errors.ObjectNotInPrerequisiteState) lastval is not yet defined in this session", # a la https://stackoverflow.com/a/24186770/5156190; # cf. https://www.psycopg.org/docs/errors.html re 55000 - result = connection.execute(""" + result = connection.execute( + sqlalchemy.text( + """ CREATE OR REPLACE FUNCTION _LASTVAL() RETURNS integer LANGUAGE plpgsql AS $$ @@ -378,7 +436,9 @@ def teardown_appcontext(exception): END; END $$; SELECT _LASTVAL(); - """) + """ + ) + ) ret = result.first()[0] # If not PostgreSQL @@ -389,15 +449,24 @@ def teardown_appcontext(exception): elif command in ["DELETE", "UPDATE"]: ret = result.rowcount + # If CREATE VIEW, return True + elif command == "CREATE VIEW": + ret = True + # If constraint violated except sqlalchemy.exc.IntegrityError as e: + if self._autocommit: + connection.execute(sqlalchemy.text("ROLLBACK")) self._logger.error(termcolor.colored(_statement, "red")) e = ValueError(e.orig) e.__cause__ = None raise e # If user error - except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as e: + except ( + sqlalchemy.exc.OperationalError, + sqlalchemy.exc.ProgrammingError, + ) as e: self._disconnect() self._logger.error(termcolor.colored(_statement, "red")) e = RuntimeError(e.orig) @@ -422,75 +491,99 @@ def _escape(self, value): import sqlparse def __escape(value): - # Lazily import import datetime import sqlalchemy # bool - if type(value) is bool: + if isinstance(value, bool): return sqlparse.sql.Token( sqlparse.tokens.Number, - sqlalchemy.types.Boolean().literal_processor(self._engine.dialect)(value)) + sqlalchemy.types.Boolean().literal_processor(self._engine.dialect)( + value + ), + ) # bytes - elif type(value) is bytes: + elif isinstance(value, bytes): if self._engine.url.get_backend_name() in ["mysql", "sqlite"]: - return sqlparse.sql.Token(sqlparse.tokens.Other, f"x'{value.hex()}'") # https://dev.mysql.com/doc/refman/8.0/en/hexadecimal-literals.html + return sqlparse.sql.Token( + sqlparse.tokens.Other, f"x'{value.hex()}'" + ) # https://dev.mysql.com/doc/refman/8.0/en/hexadecimal-literals.html elif self._engine.url.get_backend_name() == "postgresql": - return sqlparse.sql.Token(sqlparse.tokens.Other, f"'\\x{value.hex()}'") # https://dba.stackexchange.com/a/203359 + return sqlparse.sql.Token( + sqlparse.tokens.Other, f"'\\x{value.hex()}'" + ) # https://dba.stackexchange.com/a/203359 else: raise RuntimeError("unsupported value: {}".format(value)) - # datetime.date - elif type(value) is datetime.date: + # datetime.datetime + elif isinstance(value, datetime.datetime): return sqlparse.sql.Token( sqlparse.tokens.String, - sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%Y-%m-%d"))) + sqlalchemy.types.String().literal_processor(self._engine.dialect)( + value.strftime("%Y-%m-%d %H:%M:%S") + ), + ) - # datetime.datetime - elif type(value) is datetime.datetime: + # datetime.date + elif isinstance(value, datetime.date): return sqlparse.sql.Token( sqlparse.tokens.String, - sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%Y-%m-%d %H:%M:%S"))) + sqlalchemy.types.String().literal_processor(self._engine.dialect)( + value.strftime("%Y-%m-%d") + ), + ) # datetime.time - elif type(value) is datetime.time: + elif isinstance(value, datetime.time): return sqlparse.sql.Token( sqlparse.tokens.String, - sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%H:%M:%S"))) + sqlalchemy.types.String().literal_processor(self._engine.dialect)( + value.strftime("%H:%M:%S") + ), + ) # float - elif type(value) is float: + elif isinstance(value, float): return sqlparse.sql.Token( sqlparse.tokens.Number, - sqlalchemy.types.Float().literal_processor(self._engine.dialect)(value)) + sqlalchemy.types.Float().literal_processor(self._engine.dialect)( + value + ), + ) # int - elif type(value) is int: + elif isinstance(value, int): return sqlparse.sql.Token( sqlparse.tokens.Number, - sqlalchemy.types.Integer().literal_processor(self._engine.dialect)(value)) + sqlalchemy.types.Integer().literal_processor(self._engine.dialect)( + value + ), + ) # str - elif type(value) is str: + elif isinstance(value, str): return sqlparse.sql.Token( sqlparse.tokens.String, - sqlalchemy.types.String().literal_processor(self._engine.dialect)(value)) + sqlalchemy.types.String().literal_processor(self._engine.dialect)( + value + ), + ) # None elif value is None: - return sqlparse.sql.Token( - sqlparse.tokens.Keyword, - sqlalchemy.null()) + return sqlparse.sql.Token(sqlparse.tokens.Keyword, sqlalchemy.null()) # Unsupported value else: raise RuntimeError("unsupported value: {}".format(value)) # Escape value(s), separating with commas as needed - if type(value) in [list, tuple]: - return sqlparse.sql.TokenList(sqlparse.parse(", ".join([str(__escape(v)) for v in value]))) + if isinstance(value, (list, tuple)): + return sqlparse.sql.TokenList( + sqlparse.parse(", ".join([str(__escape(v)) for v in value])) + ) else: return __escape(value) @@ -502,7 +595,9 @@ def _parse_exception(e): import re # MySQL - matches = re.search(r"^\(_mysql_exceptions\.OperationalError\) \(\d+, \"(.+)\"\)$", str(e)) + matches = re.search( + r"^\(_mysql_exceptions\.OperationalError\) \(\d+, \"(.+)\"\)$", str(e) + ) if matches: return matches.group(1) @@ -528,7 +623,10 @@ def _parse_placeholder(token): import sqlparse # Validate token - if not isinstance(token, sqlparse.sql.Token) or token.ttype != sqlparse.tokens.Name.Placeholder: + if ( + not isinstance(token, sqlparse.sql.Token) + or token.ttype != sqlparse.tokens.Name.Placeholder + ): raise TypeError() # qmark diff --git a/tests/sql.py b/tests/sql.py index 968f98b..bb37fd9 100644 --- a/tests/sql.py +++ b/tests/sql.py @@ -138,6 +138,12 @@ def test_lastrowid(self): self.assertEqual(self.db.execute("INSERT INTO foo (firstname, lastname) VALUES('firstname', 'lastname')"), 1) self.assertRaises(ValueError, self.db.execute, "INSERT INTO foo (id, firstname, lastname) VALUES(1, 'firstname', 'lastname')") + def test_url(self): + url = "https://www.amazon.es/Desesperaci%C3%B3n-BEST-SELLER-Stephen-King/dp/8497595890" + self.db.execute("CREATE TABLE foo(id SERIAL PRIMARY KEY, url TEXT)") + self.db.execute("INSERT INTO foo (url) VALUES(?)", url) + self.assertEqual(self.db.execute("SELECT url FROM foo")[0]["url"], url) + def tearDown(self): self.db.execute("DROP TABLE cs50") self.db.execute("DROP TABLE IF EXISTS foo") @@ -330,7 +336,7 @@ def test_cte(self): if __name__ == "__main__": suite = unittest.TestSuite([ unittest.TestLoader().loadTestsFromTestCase(SQLiteTests), - #unittest.TestLoader().loadTestsFromTestCase(MySQLTests), + unittest.TestLoader().loadTestsFromTestCase(MySQLTests), unittest.TestLoader().loadTestsFromTestCase(PostgresTests) ])