def _enable_logging(f):
    """
    Enable logging of SQL statements when Flask is in use.

    Returns a wrapper around f (metadata preserved via functools.wraps). When
    Flask can be imported and flask.current_app is truthy, the wrapper
    re-enables the "cs50" logger for the duration of the call. The logger's
    previous disabled state is restored afterwards, even if f raises. When
    Flask is not installed, f is simply called and the logger is left alone.
    """

    import functools
    import logging

    @functools.wraps(f)
    def decorator(*args, **kwargs):

        # Infer whether Flask is installed
        try:
            import flask
        except ModuleNotFoundError:
            flask = None

        # Call f only after the except block above has been exited, so that any
        # exception f raises isn't implicitly chained to the irrelevant
        # ModuleNotFoundError
        if flask is None:
            return f(*args, **kwargs)

        # Enable logging, remembering prior state so that it can be restored
        logger = logging.getLogger("cs50")
        disabled = logger.disabled
        if flask.current_app:
            logger.disabled = False
        try:
            return f(*args, **kwargs)
        finally:
            logger.disabled = disabled

    return decorator


class SQL(object):
    """
    Wrap SQLAlchemy to provide a simple SQL API.

    Statements are run one at a time via execute(). Unless the caller has
    opened a transaction with BEGIN or START, each statement is committed as
    soon as it has been executed.
    """

    def __init__(self, url, **kwargs):
        """
        Create instance of sqlalchemy.engine.Engine.

        URL should be a string that indicates database dialect and connection arguments.
        Any keyword arguments are passed through to sqlalchemy.create_engine.

        Raises RuntimeError if URL refers to a SQLite file that does not exist or
        is not a regular file, or if a test query (SELECT 1) fails.

        http://docs.sqlalchemy.org/en/latest/core/engines.html#sqlalchemy.create_engine
        http://docs.sqlalchemy.org/en/latest/dialects/index.html
        """

        # Lazily import
        import logging
        import os
        import re
        import sqlalchemy
        import sqlalchemy.orm
        import sqlite3

        # Get logger
        self._logger = logging.getLogger("cs50")

        # Require that file already exist for SQLite
        matches = re.search(r"^sqlite:///(.+)$", url)
        if matches:
            if not os.path.exists(matches.group(1)):
                raise RuntimeError("does not exist: {}".format(matches.group(1)))
            if not os.path.isfile(matches.group(1)):
                raise RuntimeError("not a file: {}".format(matches.group(1)))

        # Create engine, disabling SQLAlchemy's own autocommit mode, raising exception if back end's module not installed
        self._engine = sqlalchemy.create_engine(url, **kwargs).execution_options(autocommit=False)

        # Create a variable to hold the session. If None, autocommit is on.
        self._Session = sqlalchemy.orm.session.sessionmaker(bind=self._engine)
        self._session = None
        self._in_transaction = False

        # Listener for connections
        def connect(dbapi_connection, connection_record):

            # Disable underlying API's own emitting of BEGIN and COMMIT
            dbapi_connection.isolation_level = None

            # Enable foreign key constraints
            if type(dbapi_connection) is sqlite3.Connection:  # If back end is sqlite
                cursor = dbapi_connection.cursor()
                cursor.execute("PRAGMA foreign_keys=ON")
                cursor.close()

        # Register listener
        sqlalchemy.event.listen(self._engine, "connect", connect)

        # Log statements to standard error
        logging.basicConfig(level=logging.DEBUG)

        # Remember logger's state before the try block, so that the finally
        # clause below can always restore it
        disabled = self._logger.disabled

        # Test database, without logging the test query itself
        try:
            self._logger.disabled = True
            self.execute("SELECT 1")
        except sqlalchemy.exc.OperationalError as e:
            e = RuntimeError(_parse_exception(e))
            e.__cause__ = None
            raise e
        finally:
            self._logger.disabled = disabled

    def __del__(self):
        """Close database session and connection."""
        self._close_session()

    @_enable_logging
    def execute(self, sql, *args, **kwargs):
        """
        Execute a SQL statement.

        sql must contain exactly one statement. Values for its placeholders are
        passed either positionally (qmark, numeric, format paramstyles) or by
        name (named, pyformat paramstyles), but not both. If the statement has
        more than one placeholder and a single list, tuple, or dict is passed
        positionally, its contents are used as the values instead.

        Returns a list of dict objects (one per row) for SELECT, the primary key
        of the newly inserted row (or None) for INSERT, the number of rows
        matched for DELETE and UPDATE, and True for any other statement.

        Raises RuntimeError if the statement or its parameters are invalid, if a
        COMMIT or ROLLBACK is issued without an open transaction, if BEGIN or
        START is issued while a transaction is already open, or if the database
        reports an integrity or operational error.
        """

        # Lazily import
        import decimal
        import re
        import sqlalchemy
        import sqlparse
        import termcolor
        import warnings

        # Parse statement, stripping comments and then leading/trailing whitespace
        statements = sqlparse.parse(sqlparse.format(sql, strip_comments=True).strip())

        # Allow only one statement at a time, since SQLite doesn't support multiple
        # https://docs.python.org/3/library/sqlite3.html#sqlite3.Cursor.execute
        if len(statements) > 1:
            raise RuntimeError("too many statements at once")
        elif len(statements) == 0:
            raise RuntimeError("missing statement")

        # Ensure named and positional parameters are mutually exclusive
        if len(args) > 0 and len(kwargs) > 0:
            raise RuntimeError("cannot pass both named and positional parameters")

        # Infer command from (unflattened) statement
        for token in statements[0]:
            if token.ttype in [sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML]:
                command = token.value.upper()
                break

            # Begin a new session, if transaction started by caller (not using autocommit)
            elif token.value.upper() in ["BEGIN", "START"]:
                if self._in_transaction:
                    raise RuntimeError("transaction already open")

                self._in_transaction = True

        # If no DDL or DML keyword was found
        else:
            command = None

        # Flatten statement
        tokens = list(statements[0].flatten())

        # Validate paramstyle
        placeholders = {}
        paramstyle = None
        for index, token in enumerate(tokens):

            # If token is a placeholder
            if token.ttype == sqlparse.tokens.Name.Placeholder:

                # Determine paramstyle, name
                _paramstyle, name = _parse_placeholder(token)

                # Remember paramstyle
                if not paramstyle:
                    paramstyle = _paramstyle

                # Ensure paramstyle is consistent
                elif _paramstyle != paramstyle:
                    raise RuntimeError("inconsistent paramstyle")

                # Remember placeholder's index, name
                placeholders[index] = name

        # If more placeholders than arguments
        if len(args) == 1 and len(placeholders) > 1:

            # If user passed args as list or tuple, explode values into args
            if isinstance(args[0], (list, tuple)):
                args = args[0]

            # If user passed kwargs as dict, migrate values from args to kwargs
            elif len(kwargs) == 0 and isinstance(args[0], dict):
                kwargs = args[0]
                args = []

        # If no placeholders
        if not paramstyle:

            # Error-check like qmark if args
            if args:
                paramstyle = "qmark"

            # Error-check like named if kwargs
            elif kwargs:
                paramstyle = "named"

        # In case of errors
        _placeholders = ", ".join([str(tokens[index]) for index in placeholders])
        _args = ", ".join([str(self._escape(arg)) for arg in args])

        # qmark
        if paramstyle == "qmark":

            # Validate number of placeholders
            if len(placeholders) != len(args):
                if len(placeholders) < len(args):
                    raise RuntimeError("fewer placeholders ({}) than values ({})".format(_placeholders, _args))
                else:
                    raise RuntimeError("more placeholders ({}) than values ({})".format(_placeholders, _args))

            # Escape values
            for i, index in enumerate(placeholders.keys()):
                tokens[index] = self._escape(args[i])

        # numeric
        elif paramstyle == "numeric":

            # Escape values (placeholders map to 0-based indices, so report as 1-based)
            for index, i in placeholders.items():
                if i >= len(args):
                    raise RuntimeError("missing value for placeholder (:{})".format(i + 1))
                tokens[index] = self._escape(args[i])

            # Check if any values unused
            indices = set(range(len(args))) - set(placeholders.values())
            if indices:
                raise RuntimeError("unused {} ({})".format(
                    "value" if len(indices) == 1 else "values",
                    ", ".join([str(self._escape(args[index])) for index in indices])))

        # named
        elif paramstyle == "named":

            # Escape values
            for index, name in placeholders.items():
                if name not in kwargs:
                    raise RuntimeError("missing value for placeholder (:{})".format(name))
                tokens[index] = self._escape(kwargs[name])

            # Check if any keys unused
            keys = kwargs.keys() - placeholders.values()
            if keys:
                raise RuntimeError("unused values ({})".format(", ".join(keys)))

        # format
        elif paramstyle == "format":

            # Validate number of placeholders
            if len(placeholders) != len(args):
                if len(placeholders) < len(args):
                    raise RuntimeError("fewer placeholders ({}) than values ({})".format(_placeholders, _args))
                else:
                    raise RuntimeError("more placeholders ({}) than values ({})".format(_placeholders, _args))

            # Escape values
            for i, index in enumerate(placeholders.keys()):
                tokens[index] = self._escape(args[i])

        # pyformat
        elif paramstyle == "pyformat":

            # Escape values
            for index, name in placeholders.items():
                if name not in kwargs:
                    raise RuntimeError("missing value for placeholder (%{}s)".format(name))
                tokens[index] = self._escape(kwargs[name])

            # Check if any keys unused
            keys = kwargs.keys() - placeholders.values()
            if keys:
                raise RuntimeError("unused {} ({})".format(
                    "value" if len(keys) == 1 else "values",
                    ", ".join(keys)))

        # For SQL statements where a colon is required verbatim, as within an inline string, use a backslash to escape
        # https://docs.sqlalchemy.org/en/13/core/sqlelement.html?highlight=text#sqlalchemy.sql.expression.text
        for index, token in enumerate(tokens):

            # In string literal
            # https://www.sqlite.org/lang_keywords.html
            if token.ttype in [sqlparse.tokens.Literal.String, sqlparse.tokens.Literal.String.Single]:
                token.value = re.sub(r"(^'|\s+):", r"\1\:", token.value)

            # In identifier
            # https://www.sqlite.org/lang_keywords.html
            elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
                token.value = re.sub(r'(^"|\s+):', r"\1\:", token.value)

        # Join tokens into statement
        statement = "".join([str(token) for token in tokens])

        # Connect to database (for transactions' sake)
        if self._session is None:
            self._session = self._Session()

        # Set up a Flask app teardown function to close session at teardown
        try:

            # Infer whether Flask is installed
            import flask

            # Infer whether app is defined, checking explicitly rather than via
            # assert, since assert statements are stripped under python -O
            if flask.current_app and not hasattr(self, "_teardown_appcontext_added"):

                # Disconnect later - but only once
                self._teardown_appcontext_added = True

                @flask.current_app.teardown_appcontext
                def shutdown_session(exception=None):
                    """Close any existing session on app context teardown."""
                    self._close_session()

        # AssertionError is still tolerated, as it was before, so that a failure
        # while registering the teardown function remains best-effort
        except (ModuleNotFoundError, AssertionError):
            pass

        # Catch SQLAlchemy warnings
        with warnings.catch_warnings():

            # Raise exceptions for warnings
            warnings.simplefilter("error")

            # Prepare, execute statement
            try:

                # Join tokens into statement, abbreviating binary data as <class 'bytes'>
                _statement = "".join([str(bytes) if token.ttype == sqlparse.tokens.Other else str(token) for token in tokens])

                # If COMMIT or ROLLBACK, turn on autocommit mode
                if command in ["COMMIT", "ROLLBACK"] and "TO" not in (token.value for token in tokens):
                    if not self._in_transaction:
                        raise RuntimeError("transactions must be initiated with BEGIN or START TRANSACTION")

                    self._in_transaction = False

                # Execute statement
                result = self._session.execute(sqlalchemy.text(statement))

                # Return value
                ret = True

                # If SELECT, return result set as list of dict objects
                if command == "SELECT":

                    # Coerce types
                    rows = [dict(row) for row in result.fetchall()]
                    for row in rows:
                        for column in row:

                            # Coerce decimal.Decimal objects to float objects
                            # https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ
                            if type(row[column]) is decimal.Decimal:
                                row[column] = float(row[column])

                            # Coerce memoryview objects (as from PostgreSQL's bytea columns) to bytes
                            elif type(row[column]) is memoryview:
                                row[column] = bytes(row[column])

                    # Rows to be returned
                    ret = rows

                # If INSERT, return primary key value for a newly inserted row (or None if none)
                elif command == "INSERT":
                    if self._engine.url.get_backend_name() in ["postgres", "postgresql"]:
                        try:
                            result = self._session.execute("SELECT LASTVAL()")
                            ret = result.first()[0]
                        except sqlalchemy.exc.OperationalError:  # If lastval is not yet defined in this session
                            ret = None
                    else:
                        ret = result.lastrowid if result.rowcount == 1 else None

                # If DELETE or UPDATE, return number of rows matched
                elif command in ["DELETE", "UPDATE"]:
                    ret = result.rowcount

                # If autocommit is on, commit
                if not self._in_transaction:
                    self._session.commit()

            # If constraint violated, log statement and raise RuntimeError
            except sqlalchemy.exc.IntegrityError as e:
                self._logger.debug(termcolor.colored(statement, "yellow"))
                e = RuntimeError(e.orig)
                e.__cause__ = None
                raise e

            # If user error
            except sqlalchemy.exc.OperationalError as e:
                self._logger.debug(termcolor.colored(statement, "red"))
                e = RuntimeError(e.orig)
                e.__cause__ = None
                raise e

            # Return value
            else:
                self._logger.debug(termcolor.colored(_statement, "green"))
                return ret

    def _close_session(self):
        """
        Closes any existing session and resets instance variables.

        Safe to call on an instance whose __init__ did not run to completion
        (as when invoked via __del__ after __init__ raised), in which case
        there is no session to close.
        """

        # Tolerate a missing attribute, since __init__ may have raised before setting it
        session = getattr(self, "_session", None)
        if session is not None:
            session.close()

        self._session = None
        self._in_transaction = False

    def _escape(self, value):
        """
        Escapes value using engine's conversion function.

        Supports bool, bytes (for mysql, sqlite, and postgresql back ends only),
        datetime.date, datetime.datetime, datetime.time, float, int, str, and
        None. A list or tuple of such values is escaped element by element, with
        the results separated by commas.

        Returns a sqlparse.sql.Token, or a sqlparse.sql.TokenList if value is a
        list or tuple. Raises RuntimeError if a value is unsupported.

        https://docs.sqlalchemy.org/en/latest/core/type_api.html#sqlalchemy.types.TypeEngine.literal_processor
        """

        # Lazily import
        import sqlparse

        def __escape(value):

            # Lazily import
            import datetime
            import sqlalchemy

            # bool
            if type(value) is bool:
                return sqlparse.sql.Token(
                    sqlparse.tokens.Number,
                    sqlalchemy.types.Boolean().literal_processor(self._engine.dialect)(value))

            # bytes
            elif type(value) is bytes:
                if self._engine.url.get_backend_name() in ["mysql", "sqlite"]:
                    return sqlparse.sql.Token(sqlparse.tokens.Other, f"x'{value.hex()}'")  # https://dev.mysql.com/doc/refman/8.0/en/hexadecimal-literals.html
                elif self._engine.url.get_backend_name() == "postgresql":
                    return sqlparse.sql.Token(sqlparse.tokens.Other, f"'\\x{value.hex()}'")  # https://dba.stackexchange.com/a/203359
                else:
                    raise RuntimeError("unsupported value: {}".format(value))

            # datetime.date
            elif type(value) is datetime.date:
                return sqlparse.sql.Token(
                    sqlparse.tokens.String,
                    sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%Y-%m-%d")))

            # datetime.datetime
            elif type(value) is datetime.datetime:
                return sqlparse.sql.Token(
                    sqlparse.tokens.String,
                    sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%Y-%m-%d %H:%M:%S")))

            # datetime.time
            elif type(value) is datetime.time:
                return sqlparse.sql.Token(
                    sqlparse.tokens.String,
                    sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%H:%M:%S")))

            # float
            elif type(value) is float:
                return sqlparse.sql.Token(
                    sqlparse.tokens.Number,
                    sqlalchemy.types.Float().literal_processor(self._engine.dialect)(value))

            # int
            elif type(value) is int:
                return sqlparse.sql.Token(
                    sqlparse.tokens.Number,
                    sqlalchemy.types.Integer().literal_processor(self._engine.dialect)(value))

            # str
            elif type(value) is str:
                return sqlparse.sql.Token(
                    sqlparse.tokens.String,
                    sqlalchemy.types.String().literal_processor(self._engine.dialect)(value))

            # None
            elif value is None:
                return sqlparse.sql.Token(
                    sqlparse.tokens.Keyword,
                    sqlalchemy.types.NullType().literal_processor(self._engine.dialect)(value))

            # Unsupported value
            else:
                raise RuntimeError("unsupported value: {}".format(value))

        # Escape value(s), separating with commas as needed; the escaped values are
        # joined with ", " and re-parsed, since a TokenList built directly from the
        # escaped tokens would render them with no separator at all
        if type(value) in [list, tuple]:
            return sqlparse.sql.TokenList(sqlparse.parse(", ".join([str(__escape(v)) for v in value])))
        else:
            return __escape(value)


def _parse_exception(e):
    """
    Extract the human-readable message from a database exception.

    The string form of e is tried against the known OperationalError formats
    for MySQL, PostgreSQL, and SQLite, in that order; the first format that
    matches determines the message returned. If none matches, the string form
    of e is returned unchanged.
    """

    # Lazily import
    import re

    text = str(e)

    # Known formats, each capturing the message as its first group
    formats = (
        r"^\(_mysql_exceptions\.OperationalError\) \(\d+, \"(.+)\"\)$",  # MySQL
        r"^\(psycopg2\.OperationalError\) (.+)$",  # PostgreSQL
        r"^\(sqlite3\.OperationalError\) (.+)$",  # SQLite
    )
    for pattern in formats:
        found = re.search(pattern, text)
        if found:
            return found.group(1)

    # Default
    return text


def _parse_placeholder(token):
    """
    Infers paramstyle, name from sqlparse.tokens.Name.Placeholder.

    Returns a (paramstyle, name) tuple, where paramstyle is one of "qmark",
    "numeric", "named", "format", or "pyformat". For "qmark" and "format", name
    is None; for "numeric", name is the placeholder's number less one (i.e., a
    0-based index); for "named" and "pyformat", name is the placeholder's name.

    Raises TypeError if token is not a sqlparse.sql.Token of type
    sqlparse.tokens.Name.Placeholder, and RuntimeError if the placeholder's
    syntax is not recognized.
    """

    # Lazily load
    import re
    import sqlparse

    # Validate token
    is_placeholder = isinstance(token, sqlparse.sql.Token) and token.ttype == sqlparse.tokens.Name.Placeholder
    if not is_placeholder:
        raise TypeError()

    text = token.value

    # qmark
    if text == "?":
        return "qmark", None

    # numeric (1-based in SQL, 0-based in return value)
    numbered = re.search(r"^:([1-9]\d*)$", text)
    if numbered:
        position = int(numbered.group(1))
        return "numeric", position - 1

    # named
    labeled = re.search(r"^:([a-zA-Z]\w*)$", text)
    if labeled:
        return "named", labeled.group(1)

    # format
    if text == "%s":
        return "format", None

    # pyformat
    keyed = re.search(r"%\((\w+)\)s$", text)
    if keyed:
        return "pyformat", keyed.group(1)

    # Invalid
    raise RuntimeError("{}: invalid placeholder".format(text))