import sys
import threading

# Thread-local data: each thread stores its own open connection for each SQL instance
# here, as an attribute named by that instance's SQL._name() (see SQL._disconnect and
# SQL.execute), so connections are never shared across threads
_data = threading.local()


def _enable_logging(f):
    """Enable logging of SQL statements when Flask is in use."""

    import logging
    import functools
    import os

    @functools.wraps(f)
    def decorator(*args, **kwargs):
        # Infer whether Flask is installed
        try:
            import flask
        except ModuleNotFoundError:
            return f(*args, **kwargs)

        # Enable logging in development mode
        disabled = logging.getLogger("cs50").disabled
        if flask.current_app and os.getenv("FLASK_ENV") == "development":
            logging.getLogger("cs50").disabled = False
        try:
            return f(*args, **kwargs)
        finally:
            logging.getLogger("cs50").disabled = disabled

    return decorator


class SQL(object):
    """
    Wrap SQLAlchemy to provide a simple SQL API.

    Each instance owns one SQLAlchemy engine. Per-thread connections are kept in the
    module-level thread-local ``_data`` under the key returned by ``_name()``. A connection
    is closed after each successful statement unless an explicit transaction (BEGIN or
    START) is open.
    """

    def __init__(self, url, **kwargs):
        """
        Create instance of sqlalchemy.engine.Engine.

        URL should be a string that indicates database dialect and connection arguments.
        Any keyword arguments are passed through to sqlalchemy.create_engine.

        http://docs.sqlalchemy.org/en/latest/core/engines.html#sqlalchemy.create_engine
        http://docs.sqlalchemy.org/en/latest/dialects/index.html

        Raises RuntimeError if a sqlite:/// URL names a path that does not exist or is not
        a file, or if a test query fails with sqlalchemy.exc.OperationalError.
        """

        # Lazily import
        import logging
        import os
        import re
        import sqlalchemy
        import sqlalchemy.orm

        # The sqlite3 module may be missing (e.g., on some buildpack stacks). Record its
        # absence so the connect listener below can skip SQLite-specific setup, instead of
        # hitting a NameError there.
        try:
            import sqlite3
        except ImportError:
            sqlite3 = None

        # Require that file already exist for SQLite
        matches = re.search(r"^sqlite:///(.+)$", url)
        if matches:
            path = matches.group(1)
            if not os.path.exists(path):
                raise RuntimeError("does not exist: {}".format(path))
            if not os.path.isfile(path):
                raise RuntimeError("not a file: {}".format(path))

        # Create engine, disabling SQLAlchemy's own autocommit mode (this raises if the back
        # end's module is not installed). Without isolation_level, PostgreSQL warns with
        # "there is already a transaction in progress" for our own BEGIN and "there is no
        # transaction in progress" for our own COMMIT.
        self._engine = sqlalchemy.create_engine(url, **kwargs).execution_options(
            autocommit=False, isolation_level="AUTOCOMMIT", no_parameters=True
        )

        # Avoid doubly escaping percent signs, since no_parameters=True anyway
        # https://github.com/cs50/python-cs50/issues/171
        self._engine.dialect.identifier_preparer._double_percents = False

        # Get logger
        self._logger = logging.getLogger("cs50")

        # Listener for new DBAPI connections
        def connect(dbapi_connection, connection_record):
            # Enable foreign key constraints on SQLite connections
            if sqlite3 is not None and isinstance(
                dbapi_connection, sqlite3.Connection
            ):
                cursor = dbapi_connection.cursor()
                try:
                    cursor.execute("PRAGMA foreign_keys=ON")
                finally:
                    cursor.close()

        # Register listener
        sqlalchemy.event.listen(self._engine, "connect", connect)

        # Autocommit by default, i.e., until a BEGIN or START opens a transaction
        self._autocommit = True

        # Test database with the logger silenced. Use a context manager so the test
        # connection is closed even if the query fails.
        disabled = self._logger.disabled
        self._logger.disabled = True
        try:
            with self._engine.connect() as connection:
                connection.execute(sqlalchemy.text("SELECT 1"))
        except sqlalchemy.exc.OperationalError as e:
            raise RuntimeError(_parse_exception(e)) from None
        finally:
            self._logger.disabled = disabled

    def __del__(self):
        """Disconnect from database."""
        self._disconnect()

    def _disconnect(self):
        """Close and forget the current thread's connection for this instance, if any."""
        if hasattr(_data, self._name()):
            getattr(_data, self._name()).close()
            delattr(_data, self._name())

    def _name(self):
        """Return object's hash as a str."""
        return str(hash(self))

    def _teardown_appcontext(self, exception):
        """Flask teardown callback: disconnect when an application context ends."""
        self._disconnect()

    @_enable_logging
    def execute(self, sql, *args, **kwargs):
        """
        Execute a single SQL statement, substituting escaped values for its placeholders.

        Placeholders must all use one paramstyle: qmark "?", numeric ":1", named ":name",
        format "%s", or pyformat "%(name)s". Positional values go in args and named values
        in kwargs, never both.

        Returns:
          - a list of dicts for SELECT;
          - for INSERT, the primary key of the newly inserted row (or None);
          - for DELETE/UPDATE, the number of rows matched;
          - True otherwise.

        Raises:
          - RuntimeError for malformed statements, placeholder/value mismatches,
            unsupported values, and OperationalError/ProgrammingError from the database;
          - ValueError for IntegrityError.

        Warnings raised while executing are turned into exceptions.
        """

        # Lazily import
        import decimal
        import re
        import sqlalchemy
        import sqlparse
        import termcolor
        import warnings

        # Parse statement, stripping comments and then leading/trailing whitespace
        statements = sqlparse.parse(sqlparse.format(sql, strip_comments=True).strip())

        # Allow only one statement at a time, since SQLite doesn't support multiple
        # https://docs.python.org/3/library/sqlite3.html#sqlite3.Cursor.execute
        if len(statements) > 1:
            raise RuntimeError("too many statements at once")
        elif len(statements) == 0:
            raise RuntimeError("missing statement")

        # Ensure named and positional parameters are mutually exclusive
        if len(args) > 0 and len(kwargs) > 0:
            raise RuntimeError("cannot pass both positional and named parameters")

        # Infer command from the statement's top-level keywords, joined by spaces
        full_statement = " ".join(
            str(token)
            for token in statements[0].tokens
            if token.ttype
            in [
                sqlparse.tokens.Keyword,
                sqlparse.tokens.Keyword.DDL,
                sqlparse.tokens.Keyword.DML,
            ]
        ).upper()

        # Set of possible commands. COMMIT and ROLLBACK must be recognized, or autocommit
        # mode would never be restored after a BEGIN/START.
        commands = {
            "BEGIN",
            "COMMIT",
            "CREATE VIEW",
            "DELETE",
            "INSERT",
            "ROLLBACK",
            "SELECT",
            "START",
            "UPDATE",
        }

        # Check if the full_statement starts with any command
        command = next(
            (cmd for cmd in commands if full_statement.startswith(cmd)), None
        )

        # Flatten statement
        tokens = list(statements[0].flatten())

        # Validate paramstyle, remembering each placeholder's token index and name
        placeholders = {}
        paramstyle = None
        for index, token in enumerate(tokens):
            if token.ttype == sqlparse.tokens.Name.Placeholder:
                _paramstyle, name = _parse_placeholder(token)

                # Remember the first paramstyle seen; all others must agree with it
                if not paramstyle:
                    paramstyle = _paramstyle
                elif _paramstyle != paramstyle:
                    raise RuntimeError("inconsistent paramstyle")

                placeholders[index] = name

        # If no placeholders, error-check like qmark if args, like named if kwargs
        if not paramstyle:
            if args:
                paramstyle = "qmark"
            elif kwargs:
                paramstyle = "named"

        # In case of errors
        _placeholders = ", ".join([str(tokens[index]) for index in placeholders])
        _args = ", ".join([str(self._escape(arg)) for arg in args])

        # qmark and format: values fill placeholders in order of appearance
        if paramstyle in ("qmark", "format"):
            # Validate number of placeholders
            if len(placeholders) < len(args):
                raise RuntimeError(
                    "fewer placeholders ({}) than values ({})".format(
                        _placeholders, _args
                    )
                )
            if len(placeholders) > len(args):
                raise RuntimeError(
                    "more placeholders ({}) than values ({})".format(
                        _placeholders, _args
                    )
                )

            # Escape values
            for index, arg in zip(placeholders, args):
                tokens[index] = self._escape(arg)

        # numeric
        elif paramstyle == "numeric":
            # Escape values (placeholder names are zero-based indices into args)
            for index, i in placeholders.items():
                if i >= len(args):
                    raise RuntimeError(
                        "missing value for placeholder (:{})".format(i + 1)
                    )
                tokens[index] = self._escape(args[i])

            # Check if any values unused
            indices = set(range(len(args))) - set(placeholders.values())
            if indices:
                raise RuntimeError(
                    "unused {} ({})".format(
                        "value" if len(indices) == 1 else "values",
                        ", ".join(
                            [str(self._escape(args[index])) for index in indices]
                        ),
                    )
                )

        # named
        elif paramstyle == "named":
            # Escape values
            for index, name in placeholders.items():
                if name not in kwargs:
                    raise RuntimeError(
                        "missing value for placeholder (:{})".format(name)
                    )
                tokens[index] = self._escape(kwargs[name])

            # Check if any keys unused
            keys = kwargs.keys() - placeholders.values()
            if keys:
                raise RuntimeError("unused values ({})".format(", ".join(keys)))

        # pyformat
        elif paramstyle == "pyformat":
            # Escape values
            for index, name in placeholders.items():
                if name not in kwargs:
                    raise RuntimeError(
                        "missing value for placeholder (%{}s)".format(name)
                    )
                tokens[index] = self._escape(kwargs[name])

            # Check if any keys unused
            keys = kwargs.keys() - placeholders.values()
            if keys:
                raise RuntimeError(
                    "unused {} ({})".format(
                        "value" if len(keys) == 1 else "values", ", ".join(keys)
                    )
                )

        # For SQL statements where a colon is required verbatim, as within an inline
        # string, use a backslash to escape
        # https://docs.sqlalchemy.org/en/13/core/sqlelement.html?highlight=text#sqlalchemy.sql.expression.text
        # https://www.sqlite.org/lang_keywords.html
        for token in tokens:
            # In string literal
            if token.ttype in [
                sqlparse.tokens.Literal.String,
                sqlparse.tokens.Literal.String.Single,
            ]:
                token.value = re.sub(r"(^'|\s+):", r"\1\:", token.value)

            # In identifier
            elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
                token.value = re.sub(r'(^"|\s+):', r"\1\:", token.value)

        # Join tokens into statement
        statement = "".join([str(token) for token in tokens])

        # Join tokens into a loggable statement, abbreviating binary data as <class 'bytes'>
        _statement = "".join(
            [
                str(bytes) if token.ttype == sqlparse.tokens.Other else str(token)
                for token in tokens
            ]
        )

        # If no connection yet, connect to database
        if not hasattr(_data, self._name()):
            setattr(_data, self._name(), self._engine.connect())

        # Use this connection
        connection = getattr(_data, self._name())

        # Disconnect if/when a Flask app context is torn down. Register a bound method,
        # which compares equal across calls for this instance, so the callback is added at
        # most once per app instead of once per execute. AssertionError is still tolerated
        # from the registration call, as before. NOTE(review): presumably Flask's
        # setup-method guard raises it when registration happens too late — confirm.
        try:
            import flask

            if flask.current_app and (
                self._teardown_appcontext
                not in flask.current_app.teardown_appcontext_funcs
            ):
                flask.current_app.teardown_appcontext(self._teardown_appcontext)
        except (ModuleNotFoundError, AssertionError):
            pass

        # Catch SQLAlchemy warnings
        with warnings.catch_warnings():
            # Raise exceptions for warnings
            warnings.simplefilter("error")

            # Prepare, execute statement
            try:
                # Check for start of transaction
                if command in ["BEGIN", "START"]:
                    self._autocommit = False

                # Execute statement, wrapped in our own transaction unless the caller
                # has one open
                if self._autocommit:
                    connection.execute(sqlalchemy.text("BEGIN"))
                result = connection.execute(sqlalchemy.text(statement))
                if self._autocommit:
                    connection.execute(sqlalchemy.text("COMMIT"))

                # Check for end of transaction
                if command in ["COMMIT", "ROLLBACK"]:
                    self._autocommit = True

                # Return value
                ret = True

                # If SELECT, return result set as list of dict objects
                if command == "SELECT":
                    rows = [dict(row) for row in result.mappings().all()]
                    for row in rows:
                        for column in row:
                            # Coerce decimal.Decimal objects to float objects
                            # https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ
                            if isinstance(row[column], decimal.Decimal):
                                row[column] = float(row[column])

                            # Coerce memoryview objects (as from PostgreSQL's bytea columns) to bytes
                            elif isinstance(row[column], memoryview):
                                row[column] = bytes(row[column])

                    ret = rows

                # If INSERT, return primary key value for a newly inserted row (or None if none)
                elif command == "INSERT":
                    if self._engine.url.get_backend_name() == "postgresql":
                        # Return LASTVAL() or NULL, avoiding
                        # "(psycopg2.errors.ObjectNotInPrerequisiteState) lastval is not yet defined in this session",
                        # a la https://stackoverflow.com/a/24186770/5156190;
                        # cf. https://www.psycopg.org/docs/errors.html re 55000
                        result = connection.execute(
                            sqlalchemy.text(
                                """
                            CREATE OR REPLACE FUNCTION _LASTVAL()
                            RETURNS integer LANGUAGE plpgsql
                            AS $$
                            BEGIN
                                BEGIN
                                    RETURN (SELECT LASTVAL());
                                EXCEPTION
                                    WHEN SQLSTATE '55000' THEN RETURN NULL;
                                END;
                            END $$;
                            SELECT _LASTVAL();
                        """
                            )
                        )
                        ret = result.first()[0]
                    else:
                        ret = result.lastrowid if result.rowcount == 1 else None

                # If DELETE or UPDATE, return number of rows matched
                elif command in ["DELETE", "UPDATE"]:
                    ret = result.rowcount

                # If CREATE VIEW, return True
                elif command == "CREATE VIEW":
                    ret = True

            # If constraint violated
            except sqlalchemy.exc.IntegrityError as e:
                if self._autocommit:
                    connection.execute(sqlalchemy.text("ROLLBACK"))
                self._logger.error(termcolor.colored(_statement, "red"))
                raise ValueError(e.orig) from None

            # If user error
            except (
                sqlalchemy.exc.OperationalError,
                sqlalchemy.exc.ProgrammingError,
            ) as e:
                self._disconnect()
                self._logger.error(termcolor.colored(_statement, "red"))
                raise RuntimeError(e.orig) from None

            # Return value
            else:
                self._logger.info(termcolor.colored(_statement, "green"))
                if self._autocommit:  # Don't stay connected unnecessarily
                    self._disconnect()
                return ret

    def _escape(self, value):
        """
        Escapes value using engine's conversion function.

        Returns a sqlparse Token holding the SQL literal for value, or, for a list/tuple,
        a TokenList of the comma-separated escaped items. Supports bool, bytes (MySQL,
        SQLite, PostgreSQL only), datetime.datetime/date/time, float, int, str and None.
        Raises RuntimeError for anything else.

        https://docs.sqlalchemy.org/en/latest/core/type_api.html#sqlalchemy.types.TypeEngine.literal_processor
        """

        # Lazily import
        import sqlparse

        def __escape(value):
            # Lazily import
            import datetime
            import sqlalchemy

            dialect = self._engine.dialect

            # bool (checked before int, since bool is a subclass of int)
            if isinstance(value, bool):
                return sqlparse.sql.Token(
                    sqlparse.tokens.Number,
                    sqlalchemy.types.Boolean().literal_processor(dialect)(value),
                )

            # bytes, as a hexadecimal literal in the back end's syntax
            elif isinstance(value, bytes):
                backend = self._engine.url.get_backend_name()
                if backend in ["mysql", "sqlite"]:
                    # https://dev.mysql.com/doc/refman/8.0/en/hexadecimal-literals.html
                    return sqlparse.sql.Token(sqlparse.tokens.Other, f"x'{value.hex()}'")
                elif backend == "postgresql":
                    # https://dba.stackexchange.com/a/203359
                    return sqlparse.sql.Token(
                        sqlparse.tokens.Other, f"'\\x{value.hex()}'"
                    )
                else:
                    raise RuntimeError("unsupported value: {}".format(value))

            # datetime.datetime (checked before date, since datetime subclasses date)
            elif isinstance(value, datetime.datetime):
                return sqlparse.sql.Token(
                    sqlparse.tokens.String,
                    sqlalchemy.types.String().literal_processor(dialect)(
                        value.strftime("%Y-%m-%d %H:%M:%S")
                    ),
                )

            # datetime.date
            elif isinstance(value, datetime.date):
                return sqlparse.sql.Token(
                    sqlparse.tokens.String,
                    sqlalchemy.types.String().literal_processor(dialect)(
                        value.strftime("%Y-%m-%d")
                    ),
                )

            # datetime.time
            elif isinstance(value, datetime.time):
                return sqlparse.sql.Token(
                    sqlparse.tokens.String,
                    sqlalchemy.types.String().literal_processor(dialect)(
                        value.strftime("%H:%M:%S")
                    ),
                )

            # float
            elif isinstance(value, float):
                return sqlparse.sql.Token(
                    sqlparse.tokens.Number,
                    sqlalchemy.types.Float().literal_processor(dialect)(value),
                )

            # int
            elif isinstance(value, int):
                return sqlparse.sql.Token(
                    sqlparse.tokens.Number,
                    sqlalchemy.types.Integer().literal_processor(dialect)(value),
                )

            # str
            elif isinstance(value, str):
                return sqlparse.sql.Token(
                    sqlparse.tokens.String,
                    sqlalchemy.types.String().literal_processor(dialect)(value),
                )

            # None
            elif value is None:
                return sqlparse.sql.Token(sqlparse.tokens.Keyword, sqlalchemy.null())

            # Unsupported value
            else:
                raise RuntimeError("unsupported value: {}".format(value))

        # Escape value(s), separating with commas as needed
        if isinstance(value, (list, tuple)):
            return sqlparse.sql.TokenList(
                sqlparse.parse(", ".join([str(__escape(v)) for v in value]))
            )
        else:
            return __escape(value)


def _parse_exception(e):
    """Parses an exception, returns its message."""

    # Lazily import
    import re

    # MySQL
    matches = re.search(
        r"^\(_mysql_exceptions\.OperationalError\) \(\d+, \"(.+)\"\)$", str(e)
    )
    if matches:
        return matches.group(1)

    # PostgreSQL
    matches = re.search(r"^\(psycopg2\.OperationalError\) (.+)$", str(e))
    if matches:
        return matches.group(1)

    # SQLite
    matches = re.search(r"^\(sqlite3\.OperationalError\) (.+)$", str(e))
    if matches:
        return matches.group(1)

    # Default
    return str(e)


def _parse_placeholder(token):
    """
    Infers paramstyle, name from sqlparse.tokens.Name.Placeholder.

    Returns a (paramstyle, name) pair:
      "?"          -> ("qmark", None)
      ":1", ":2"   -> ("numeric", 0), ("numeric", 1)   (zero-based index)
      ":name"      -> ("named", "name")
      "%s"         -> ("format", None)
      "%(name)s"   -> ("pyformat", "name")

    Raises TypeError if token is not a placeholder token, and RuntimeError if its text
    matches no supported paramstyle.
    """

    # Lazily load
    import re
    import sqlparse

    # Validate token
    if (
        not isinstance(token, sqlparse.sql.Token)
        or token.ttype != sqlparse.tokens.Name.Placeholder
    ):
        raise TypeError()

    value = token.value

    # qmark
    if value == "?":
        return "qmark", None

    # numeric; ":1" refers to the first value, hence the zero-based index
    matches = re.search(r"^:([1-9]\d*)$", value)
    if matches:
        return "numeric", int(matches.group(1)) - 1

    # named; like Python identifiers (and thus keyword arguments), names may begin with
    # a letter or an underscore
    matches = re.search(r"^:([a-zA-Z_]\w*)$", value)
    if matches:
        return "named", matches.group(1)

    # format
    if value == "%s":
        return "format", None

    # pyformat, anchored at both ends like the other patterns
    matches = re.search(r"^%\((\w+)\)s$", value)
    if matches:
        return "pyformat", matches.group(1)

    # Invalid
    raise RuntimeError("{}: invalid placeholder".format(value))