diff --git a/setup.py b/setup.py index 550e65d..de271f8 100644 --- a/setup.py +++ b/setup.py @@ -10,11 +10,11 @@ "Topic :: Software Development :: Libraries :: Python Modules" ], description="CS50 library for Python", - install_requires=["Flask>=1.0", "SQLAlchemy", "sqlparse", "termcolor"], + install_requires=["Flask>=1.0", "SQLAlchemy<2", "sqlparse", "termcolor"], keywords="cs50", name="cs50", package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="6.0.4" + version="7.0.0" ) diff --git a/src/cs50/__init__.py b/src/cs50/__init__.py index aaec161..e5ec787 100644 --- a/src/cs50/__init__.py +++ b/src/cs50/__init__.py @@ -1,20 +1,5 @@ -import logging -import os -import sys - - -# Disable cs50 logger by default -logging.getLogger("cs50").disabled = True - -# Import cs50_* -from .cs50 import get_char, get_float, get_int, get_string -try: - from .cs50 import get_long -except ImportError: - pass - -# Hook into flask importing -from . import flask - -# Wrap SQLAlchemy +from .cs50 import get_float, get_int, get_string from .sql import SQL +from ._logger import _setup_logger + +_setup_logger() diff --git a/src/cs50/_engine.py b/src/cs50/_engine.py new file mode 100644 index 0000000..d74992c --- /dev/null +++ b/src/cs50/_engine.py @@ -0,0 +1,66 @@ +import threading + +from ._engine_util import create_engine + + +thread_local_data = threading.local() + + +class Engine: + """Wraps a SQLAlchemy engine. + """ + + def __init__(self, url): + self._engine = create_engine(url) + + def get_transaction_connection(self): + """ + :returns: A new connection with autocommit disabled (to be used for transactions). + """ + + _thread_local_connections()[self] = self._engine.connect().execution_options( + autocommit=False) + return self.get_existing_transaction_connection() + + def get_connection(self): + """ + :returns: A new connection with autocommit enabled + """ + + return self._engine.connect().execution_options(autocommit=True) + + def get_existing_transaction_connection(self): + """ + :returns: The transaction connection bound to this Engine instance, if one exists, or None. + """ + + return _thread_local_connections().get(self) + + def close_transaction_connection(self): + """Closes the transaction connection bound to this Engine instance, if one exists and + removes it. + """ + + connection = self.get_existing_transaction_connection() + if connection: + connection.close() + del _thread_local_connections()[self] + + def is_postgres(self): + return self._engine.dialect.name in {"postgres", "postgresql"} + + def __getattr__(self, attr): + return getattr(self._engine, attr) + +def _thread_local_connections(): + """ + :returns: A thread local dict to keep track of transaction connection. If one does not exist, + creates one. + """ + + try: + connections = thread_local_data.connections + except AttributeError: + connections = thread_local_data.connections = {} + + return connections diff --git a/src/cs50/_engine_util.py b/src/cs50/_engine_util.py new file mode 100644 index 0000000..c55b8f2 --- /dev/null +++ b/src/cs50/_engine_util.py @@ -0,0 +1,43 @@ +"""Utility functions used by _session.py. +""" + +import os +import sqlite3 + +import sqlalchemy + +sqlite_url_prefix = "sqlite:///" + + +def create_engine(url, **kwargs): + """Creates a new SQLAlchemy engine. If ``url`` is a URL for a SQLite database, makes sure that + the SQLite file exits and enables foreign key constraints. + """ + + try: + engine = sqlalchemy.create_engine(url, **kwargs) + except sqlalchemy.exc.ArgumentError: + raise RuntimeError(f"invalid URL: {url}") from None + + if _is_sqlite_url(url): + _assert_sqlite_file_exists(url) + sqlalchemy.event.listen(engine, "connect", _enable_sqlite_foreign_key_constraints) + + return engine + +def _is_sqlite_url(url): + return url.startswith(sqlite_url_prefix) + + +def _assert_sqlite_file_exists(url): + path = url[len(sqlite_url_prefix):] + if not os.path.exists(path): + raise RuntimeError(f"does not exist: {path}") + if not os.path.isfile(path): + raise RuntimeError(f"not a file: {path}") + + +def _enable_sqlite_foreign_key_constraints(dbapi_connection, _): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() diff --git a/src/cs50/_logger.py b/src/cs50/_logger.py new file mode 100644 index 0000000..e7b03ca --- /dev/null +++ b/src/cs50/_logger.py @@ -0,0 +1,98 @@ +"""Sets up logging for the library. +""" + +import logging +import os.path +import re +import sys +import traceback + +import termcolor + + +def green(msg): + return _colored(msg, "green") + + +def red(msg): + return _colored(msg, "red") + + +def yellow(msg): + return _colored(msg, "yellow") + + +def _colored(msg, color): + return termcolor.colored(str(msg), color) + + +def _setup_logger(): + _configure_default_logger() + _patch_root_handler_format_exception() + _configure_cs50_logger() + _patch_excepthook() + + +def _configure_default_logger(): + """Configures a default handler and formatter to prevent flask and werkzeug from adding theirs. + """ + + logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.DEBUG) + + +def _patch_root_handler_format_exception(): + """Patches formatException for the root handler to use ``_format_exception``. + """ + + try: + formatter = logging.root.handlers[0].formatter + formatter.formatException = lambda exc_info: _format_exception(*exc_info) + except IndexError: + pass + + +def _configure_cs50_logger(): + """Disables the cs50 logger by default. Disables logging propagation to prevent messages from + being logged more than once. Sets the logging handler and formatter. + """ + + _logger = logging.getLogger("cs50") + _logger.disabled = True + _logger.setLevel(logging.DEBUG) + + # Log messages once + _logger.propagate = False + + handler = logging.StreamHandler() + handler.setLevel(logging.DEBUG) + + formatter = logging.Formatter("%(levelname)s: %(message)s") + formatter.formatException = lambda exc_info: _format_exception(*exc_info) + handler.setFormatter(formatter) + _logger.addHandler(handler) + + +def _patch_excepthook(): + sys.excepthook = lambda type_, value, exc_tb: print( + _format_exception(type_, value, exc_tb), file=sys.stderr) + + +def _format_exception(type_, value, exc_tb): + """Formats traceback, darkening entries from global site-packages directories and user-specific + site-packages directory. + https://stackoverflow.com/a/46071447/5156190 + """ + + # Absolute paths to site-packages + packages = tuple(os.path.join(os.path.abspath(p), "") for p in sys.path[1:]) + + # Highlight lines not referring to files in site-packages + lines = [] + for line in traceback.format_exception(type_, value, exc_tb): + matches = re.search(r"^ File \"([^\"]+)\", line \d+, in .+", line) + if matches and matches.group(1).startswith(packages): + lines += line + else: + matches = re.search(r"^(\s*)(.*?)(\s*)$", line, re.DOTALL) + lines.append(matches.group(1) + yellow(matches.group(2)) + matches.group(3)) + return "".join(lines).rstrip() diff --git a/src/cs50/_sql_sanitizer.py b/src/cs50/_sql_sanitizer.py new file mode 100644 index 0000000..17fc5fa --- /dev/null +++ b/src/cs50/_sql_sanitizer.py @@ -0,0 +1,95 @@ +import datetime +import re + +import sqlalchemy +import sqlparse + + +class SQLSanitizer: + """Sanitizes SQL values. + """ + + def __init__(self, dialect): + self._dialect = dialect + + def escape(self, value): + """Escapes value using engine's conversion function. + https://docs.sqlalchemy.org/en/latest/core/type_api.html#sqlalchemy.types.TypeEngine.literal_processor + + :param value: The value to be sanitized + + :returns: The sanitized value + """ + # pylint: disable=too-many-return-statements + if isinstance(value, (list, tuple)): + return self.escape_iterable(value) + + if isinstance(value, bool): + return sqlparse.sql.Token( + sqlparse.tokens.Number, + sqlalchemy.types.Boolean().literal_processor(self._dialect)(value)) + + if isinstance(value, bytes): + if self._dialect.name in {"mysql", "sqlite"}: + # https://dev.mysql.com/doc/refman/8.0/en/hexadecimal-literals.html + return sqlparse.sql.Token(sqlparse.tokens.Other, f"x'{value.hex()}'") + if self._dialect.name in {"postgres", "postgresql"}: + # https://dba.stackexchange.com/a/203359 + return sqlparse.sql.Token(sqlparse.tokens.Other, f"'\\x{value.hex()}'") + + raise RuntimeError(f"unsupported value: {value}") + + string_processor = sqlalchemy.types.String().literal_processor(self._dialect) + if isinstance(value, datetime.date): + return sqlparse.sql.Token( + sqlparse.tokens.String, string_processor(value.strftime("%Y-%m-%d"))) + + if isinstance(value, datetime.datetime): + return sqlparse.sql.Token( + sqlparse.tokens.String, string_processor(value.strftime("%Y-%m-%d %H:%M:%S"))) + + if isinstance(value, datetime.time): + return sqlparse.sql.Token( + sqlparse.tokens.String, string_processor(value.strftime("%H:%M:%S"))) + + if isinstance(value, float): + return sqlparse.sql.Token( + sqlparse.tokens.Number, + sqlalchemy.types.Float().literal_processor(self._dialect)(value)) + + if isinstance(value, int): + return sqlparse.sql.Token( + sqlparse.tokens.Number, + sqlalchemy.types.Integer().literal_processor(self._dialect)(value)) + + if isinstance(value, str): + return sqlparse.sql.Token(sqlparse.tokens.String, string_processor(value)) + + if value is None: + return sqlparse.sql.Token( + sqlparse.tokens.Keyword, + sqlalchemy.types.NullType().literal_processor(self._dialect)(value)) + + raise RuntimeError(f"unsupported value: {value}") + + def escape_iterable(self, iterable): + """Escapes each value in iterable and joins all the escaped values with ", ", formatted for + SQL's ``IN`` operator. + + :param: An iterable of values to be escaped + + :returns: A comma-separated list of escaped values from ``iterable`` + :rtype: :class:`sqlparse.sql.TokenList` + """ + + return sqlparse.sql.TokenList( + sqlparse.parse(", ".join([str(self.escape(v)) for v in iterable]))) + + +def escape_verbatim_colon(value): + """Escapes verbatim colon from a value so as it is not confused with a parameter marker. + """ + + # E.g., ':foo, ":foo, :foo will be replaced with + # '\:foo, "\:foo, \:foo respectively + return re.sub(r"(^(?:'|\")|\s+):", r"\1\:", value) diff --git a/src/cs50/_sql_util.py b/src/cs50/_sql_util.py new file mode 100644 index 0000000..2dbfecf --- /dev/null +++ b/src/cs50/_sql_util.py @@ -0,0 +1,51 @@ +"""Utility functions used by sql.py. +""" + +import contextlib +import decimal +import warnings + +import sqlalchemy + + +def process_select_result(result): + """Converts a SQLAlchemy result to a ``list`` of ``dict`` objects, each of which represents a + row in the result set. + + :param result: A SQLAlchemy result + :type result: :class:`sqlalchemy.engine.Result` + """ + rows = [dict(row) for row in result.fetchall()] + for row in rows: + for column in row: + # Coerce decimal.Decimal objects to float objects + # https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ + if isinstance(row[column], decimal.Decimal): + row[column] = float(row[column]) + + # Coerce memoryview objects (as from PostgreSQL's bytea columns) to bytes + elif isinstance(row[column], memoryview): + row[column] = bytes(row[column]) + + return rows + + +@contextlib.contextmanager +def raise_errors_for_warnings(): + """Catches warnings and raises errors instead. + """ + + with warnings.catch_warnings(): + warnings.simplefilter("error") + yield + + +def postgres_lastval(connection): + """ + :returns: The ID of the last inserted row, if defined in this session, or None + """ + + try: + return connection.execute("SELECT LASTVAL()").first()[0] + except sqlalchemy.exc.OperationalError: + return None diff --git a/src/cs50/_statement.py b/src/cs50/_statement.py new file mode 100644 index 0000000..2de956a --- /dev/null +++ b/src/cs50/_statement.py @@ -0,0 +1,244 @@ +import collections + +from ._sql_sanitizer import SQLSanitizer, escape_verbatim_colon +from ._statement_util import ( + format_and_parse, + get_human_readable_list, + is_identifier, + is_operation_token, + is_placeholder, + is_string_literal, + operation_keywords, + Paramstyle, + parse_placeholder, +) + + +def statement_factory(dialect): + """Creates a sanitizer for ``dialect`` and injects it into ``Statement``, exposing a simpler + interface for ``Statement``. + + :param dialect: a SQLAlchemy dialect + :type dialect: :class:`sqlalchemy.engine.Dialect` + """ + + sql_sanitizer = SQLSanitizer(dialect) + + def statement(sql, *args, **kwargs): + return Statement(sql_sanitizer, sql, *args, **kwargs) + + return statement + + +class Statement: + """Parses a SQL statement and substitutes any parameter markers with their corresponding + placeholders. + """ + + def __init__(self, sql_sanitizer, sql, *args, **kwargs): + """ + :param sql_sanitizer: The SQL sanitizer used to sanitize the parameters + :type sql_sanitizer: :class:`_sql_sanitizer.SQLSanitizer` + + :param sql: The SQL statement + :type sql: str + + :param *args: Zero or more positional parameters to be substituted for the parameter markers + + :param *kwargs: Zero or more keyword arguments to be substituted for the parameter markers + """ + + if len(args) > 0 and len(kwargs) > 0: + raise RuntimeError("cannot pass both positional and named parameters") + + self._sql_sanitizer = sql_sanitizer + + self._args = self._get_escaped_args(args) + self._kwargs = self._get_escaped_kwargs(kwargs) + + self._statement = format_and_parse(sql) + self._tokens = self._tokenize() + + self._operation_keyword = self._get_operation_keyword() + + self._paramstyle = self._get_paramstyle() + self._placeholders = self._get_placeholders() + self._substitute_markers_with_escaped_params() + # self._escape_verbatim_colons() + + def _get_escaped_args(self, args): + return [self._sql_sanitizer.escape(arg) for arg in args] + + def _get_escaped_kwargs(self, kwargs): + return {k: self._sql_sanitizer.escape(v) for k, v in kwargs.items()} + + def _tokenize(self): + """ + :returns: A flattened list of SQLParse tokens that represent the SQL statement + """ + + return list(self._statement.flatten()) + + def _get_operation_keyword(self): + """ + :returns: The operation keyword of the SQL statement (e.g., ``SELECT``, ``DELETE``, etc) + :rtype: str + """ + + for token in self._statement: + if is_operation_token(token.ttype): + token_value = token.value.upper() + if token_value in operation_keywords: + operation_keyword = token_value + break + else: + operation_keyword = None + + return operation_keyword + + def _get_paramstyle(self): + """ + :returns: The paramstyle used in the SQL statement (if any) + :rtype: :class:_statement_util.Paramstyle`` + """ + + paramstyle = None + for token in self._tokens: + if is_placeholder(token.ttype): + paramstyle, _ = parse_placeholder(token.value) + break + else: + paramstyle = self._default_paramstyle() + + return paramstyle + + def _default_paramstyle(self): + """ + :returns: If positional args were passed, returns ``Paramstyle.QMARK``; if keyword arguments + were passed, returns ``Paramstyle.NAMED``; otherwise, returns ``None`` + """ + + paramstyle = None + if self._args: + paramstyle = Paramstyle.QMARK + elif self._kwargs: + paramstyle = Paramstyle.NAMED + + return paramstyle + + def _get_placeholders(self): + """ + :returns: A dict that maps the index of each parameter marker in the tokens list to the name + of that parameter marker (if applicable) or ``None`` + :rtype: dict + """ + + placeholders = collections.OrderedDict() + for index, token in enumerate(self._tokens): + if is_placeholder(token.ttype): + paramstyle, name = parse_placeholder(token.value) + if paramstyle != self._paramstyle: + raise RuntimeError("inconsistent paramstyle") + + placeholders[index] = name + + return placeholders + + def _substitute_markers_with_escaped_params(self): + if self._paramstyle in {Paramstyle.FORMAT, Paramstyle.QMARK}: + self._substitute_format_or_qmark_markers() + elif self._paramstyle == Paramstyle.NUMERIC: + self._substitue_numeric_markers() + if self._paramstyle in {Paramstyle.NAMED, Paramstyle.PYFORMAT}: + self._substitute_named_or_pyformat_markers() + + def _substitute_format_or_qmark_markers(self): + """Substitutes format or qmark parameter markers with their corresponding parameters. + """ + + self._assert_valid_arg_count() + for arg_index, token_index in enumerate(self._placeholders.keys()): + self._tokens[token_index] = self._args[arg_index] + + def _assert_valid_arg_count(self): + """Raises a ``RuntimeError`` if the number of arguments does not match the number of + placeholders. + """ + + if len(self._placeholders) != len(self._args): + placeholders = get_human_readable_list(self._placeholders.values()) + args = get_human_readable_list(self._args) + if len(self._placeholders) < len(self._args): + raise RuntimeError(f"fewer placeholders ({placeholders}) than values ({args})") + + raise RuntimeError(f"more placeholders ({placeholders}) than values ({args})") + + def _substitue_numeric_markers(self): + """Substitutes numeric parameter markers with their corresponding parameters. Raises a + ``RuntimeError`` if any parameters are missing or unused. + """ + + unused_arg_indices = set(range(len(self._args))) + for token_index, num in self._placeholders.items(): + if num >= len(self._args): + raise RuntimeError(f"missing value for placeholder ({num + 1})") + + self._tokens[token_index] = self._args[num] + unused_arg_indices.remove(num) + + if len(unused_arg_indices) > 0: + unused_args = get_human_readable_list( + [self._args[i] for i in sorted(unused_arg_indices)]) + raise RuntimeError( + f"unused value{'' if len(unused_args) == 1 else 's'} ({unused_args})") + + def _substitute_named_or_pyformat_markers(self): + """Substitutes named or pyformat parameter markers with their corresponding parameters. + Raises a ``RuntimeError`` if any parameters are missing or unused. + """ + + unused_params = set(self._kwargs.keys()) + for token_index, param_name in self._placeholders.items(): + if param_name not in self._kwargs: + raise RuntimeError(f"missing value for placeholder ({param_name})") + + self._tokens[token_index] = self._kwargs[param_name] + unused_params.remove(param_name) + + if len(unused_params) > 0: + joined_unused_params = get_human_readable_list(sorted(unused_params)) + raise RuntimeError( + f"unused value{'' if len(unused_params) == 1 else 's'} ({joined_unused_params})") + + def _escape_verbatim_colons(self): + """Escapes verbatim colons from string literal and identifier tokens so they aren't treated + as parameter markers. + """ + + for token in self._tokens: + if is_string_literal(token.ttype) or is_identifier(token.ttype): + token.value = escape_verbatim_colon(token.value) + + def is_transaction_start(self): + return self._operation_keyword in {"BEGIN", "START"} + + def is_transaction_end(self): + return self._operation_keyword in {"COMMIT", "ROLLBACK"} + + def is_delete(self): + return self._operation_keyword == "DELETE" + + def is_insert(self): + return self._operation_keyword == "INSERT" + + def is_select(self): + return self._operation_keyword == "SELECT" + + def is_update(self): + return self._operation_keyword == "UPDATE" + + def __str__(self): + """Joins the statement tokens into a string. + """ + + return "".join([str(token) for token in self._tokens]) diff --git a/src/cs50/_statement_util.py b/src/cs50/_statement_util.py new file mode 100644 index 0000000..34ca6ff --- /dev/null +++ b/src/cs50/_statement_util.py @@ -0,0 +1,101 @@ +"""Utility functions used by _statement.py. +""" + +import enum +import re + +import sqlparse + + +operation_keywords = { + "BEGIN", + "COMMIT", + "DELETE", + "INSERT", + "ROLLBACK", + "SELECT", + "START", + "UPDATE" +} + + +class Paramstyle(enum.Enum): + """Represents the supported parameter marker styles. + """ + + FORMAT = enum.auto() + NAMED = enum.auto() + NUMERIC = enum.auto() + PYFORMAT = enum.auto() + QMARK = enum.auto() + + +def format_and_parse(sql): + """Formats and parses a SQL statement. Raises ``RuntimeError`` if ``sql`` represents more than + one statement. + + :param sql: The SQL statement to be formatted and parsed + :type sql: str + + :returns: A list of unflattened SQLParse tokens that represent the parsed statement + """ + + formatted_statements = sqlparse.format(sql, strip_comments=True).strip() + parsed_statements = sqlparse.parse(formatted_statements) + statement_count = len(parsed_statements) + if statement_count == 0: + raise RuntimeError("missing statement") + if statement_count > 1: + raise RuntimeError("too many statements at once") + + return parsed_statements[0] + + +def is_placeholder(ttype): + return ttype == sqlparse.tokens.Name.Placeholder + + +def parse_placeholder(value): + """ + :returns: A tuple of the paramstyle and the name of the parameter marker (if any) or ``None`` + :rtype: tuple + """ + if value == "?": + return Paramstyle.QMARK, None + + # E.g., :1 + matches = re.search(r"^:([1-9]\d*)$", value) + if matches: + return Paramstyle.NUMERIC, int(matches.group(1)) - 1 + + # E.g., :foo + matches = re.search(r"^:([a-zA-Z]\w*)$", value) + if matches: + return Paramstyle.NAMED, matches.group(1) + + if value == "%s": + return Paramstyle.FORMAT, None + + # E.g., %(foo)s + matches = re.search(r"%\((\w+)\)s$", value) + if matches: + return Paramstyle.PYFORMAT, matches.group(1) + + raise RuntimeError(f"{value}: invalid placeholder") + + +def is_operation_token(ttype): + return ttype in { + sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML} + + +def is_string_literal(ttype): + return ttype in [sqlparse.tokens.Literal.String, sqlparse.tokens.Literal.String.Single] + + +def is_identifier(ttype): + return ttype == sqlparse.tokens.Literal.String.Symbol + + +def get_human_readable_list(iterable): + return ", ".join(str(v) for v in iterable) diff --git a/src/cs50/cs50.py b/src/cs50/cs50.py index 1d7b6ea..11fa20a 100644 --- a/src/cs50/cs50.py +++ b/src/cs50/cs50.py @@ -1,143 +1,104 @@ -from __future__ import print_function +"""Exposes simple API for getting and validating user input""" -import inspect -import logging -import os import re import sys -from distutils.sysconfig import get_python_lib -from os.path import abspath, join -from termcolor import colored -from traceback import format_exception +def get_float(prompt): + """Reads a line of text from standard input and returns the equivalent float as precisely as + possible; if text does not represent a float, user is prompted to retry. If line can't be read, + returns None. -# Configure default logging handler and formatter -# Prevent flask, werkzeug, etc from adding default handler -logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.DEBUG) + :type prompt: str -try: - # Patch formatException - logging.root.handlers[0].formatter.formatException = lambda exc_info: _formatException(*exc_info) -except IndexError: - pass + """ -# Configure cs50 logger -_logger = logging.getLogger("cs50") -_logger.setLevel(logging.DEBUG) + while True: + try: + return _get_float(prompt) + except (OverflowError, ValueError): + pass -# Log messages once -_logger.propagate = False -handler = logging.StreamHandler() -handler.setLevel(logging.DEBUG) +def _get_float(prompt): + user_input = get_string(prompt) + if user_input is None: + return None -formatter = logging.Formatter("%(levelname)s: %(message)s") -formatter.formatException = lambda exc_info: _formatException(*exc_info) -handler.setFormatter(formatter) -_logger.addHandler(handler) + if len(user_input) > 0 and re.search(r"^[+-]?\d*(?:\.\d*)?$", user_input): + return float(user_input) + raise ValueError(f"invalid float literal: {user_input}") -class _flushfile(): - """ - Disable buffering for standard output and standard error. - http://stackoverflow.com/a/231216 +def get_int(prompt): + """Reads a line of text from standard input and return the equivalent int; if text does not + represent an int, user is prompted to retry. If line can't be read, returns None. + + :type prompt: str """ - def __init__(self, f): - self.f = f + while True: + try: + return _get_int(prompt) + except (MemoryError, ValueError): + pass - def __getattr__(self, name): - return object.__getattribute__(self.f, name) - def write(self, x): - self.f.write(x) - self.f.flush() +def _get_int(prompt): + user_input = get_string(prompt) + if user_input is None: + return None + if re.search(r"^[+-]?\d+$", user_input): + return int(user_input, 10) -sys.stderr = _flushfile(sys.stderr) -sys.stdout = _flushfile(sys.stdout) + raise ValueError(f"invalid int literal for base 10: {user_input}") -def _formatException(type, value, tb): - """ - Format traceback, darkening entries from global site-packages directories - and user-specific site-packages directory. +def get_string(prompt): + """Reads a line of text from standard input and returns it as a string, sans trailing line + ending. Supports CR (\r), LF (\n), and CRLF (\r\n) as line endings. If user inputs only a line + ending, returns "", not None. Returns None upon error or no input whatsoever (i.e., just EOF). - https://stackoverflow.com/a/46071447/5156190 + :type prompt: str """ - # Absolute paths to site-packages - packages = tuple(join(abspath(p), "") for p in sys.path[1:]) - - # Highlight lines not referring to files in site-packages - lines = [] - for line in format_exception(type, value, tb): - matches = re.search(r"^ File \"([^\"]+)\", line \d+, in .+", line) - if matches and matches.group(1).startswith(packages): - lines += line - else: - matches = re.search(r"^(\s*)(.*?)(\s*)$", line, re.DOTALL) - lines.append(matches.group(1) + colored(matches.group(2), "yellow") + matches.group(3)) - return "".join(lines).rstrip() + if not isinstance(prompt, str): + raise TypeError("prompt must be of type str") + try: + return _get_input(prompt) + except EOFError: + return None -sys.excepthook = lambda type, value, tb: print(_formatException(type, value, tb), file=sys.stderr) +def _get_input(prompt): + return input(prompt) -def eprint(*args, **kwargs): - raise RuntimeError("The CS50 Library for Python no longer supports eprint, but you can use print instead!") +class _flushfile(): + """ Disable buffering for standard output and standard error. + http://stackoverflow.com/a/231216 + """ -def get_char(prompt): - raise RuntimeError("The CS50 Library for Python no longer supports get_char, but you can use get_string instead!") + def __init__(self, stream): + self.stream = stream + def __getattr__(self, name): + return object.__getattribute__(self.stream, name) -def get_float(prompt): - """ - Read a line of text from standard input and return the equivalent float - as precisely as possible; if text does not represent a double, user is - prompted to retry. If line can't be read, return None. - """ - while True: - s = get_string(prompt) - if s is None: - return None - if len(s) > 0 and re.search(r"^[+-]?\d*(?:\.\d*)?$", s): - try: - return float(s) - except (OverflowError, ValueError): - pass + def write(self, data): + """Writes data to stream""" + self.stream.write(data) + self.stream.flush() -def get_int(prompt): +def disable_output_buffering(): + """Disables output buffering to prevent prompts from being buffered. """ - Read a line of text from standard input and return the equivalent int; - if text does not represent an int, user is prompted to retry. If line - can't be read, return None. - """ - while True: - s = get_string(prompt) - if s is None: - return None - if re.search(r"^[+-]?\d+$", s): - try: - return int(s, 10) - except ValueError: - pass + sys.stderr = _flushfile(sys.stderr) + sys.stdout = _flushfile(sys.stdout) -def get_string(prompt): - """ - Read a line of text from standard input and return it as a string, - sans trailing line ending. Supports CR (\r), LF (\n), and CRLF (\r\n) - as line endings. If user inputs only a line ending, returns "", not None. - Returns None upon error or no input whatsoever (i.e., just EOF). - """ - if type(prompt) is not str: - raise TypeError("prompt must be of type str") - try: - return input(prompt) - except EOFError: - return None +disable_output_buffering() diff --git a/src/cs50/flask.py b/src/cs50/flask.py deleted file mode 100644 index 324ec30..0000000 --- a/src/cs50/flask.py +++ /dev/null @@ -1,38 +0,0 @@ -import os -import pkgutil -import sys - -def _wrap_flask(f): - if f is None: - return - - from distutils.version import StrictVersion - from .cs50 import _formatException - - if f.__version__ < StrictVersion("1.0"): - return - - if os.getenv("CS50_IDE_TYPE") == "online": - from werkzeug.middleware.proxy_fix import ProxyFix - _flask_init_before = f.Flask.__init__ - def _flask_init_after(self, *args, **kwargs): - _flask_init_before(self, *args, **kwargs) - self.wsgi_app = ProxyFix(self.wsgi_app, x_proto=1) # For HTTPS-to-HTTP proxy - f.Flask.__init__ = _flask_init_after - - -# If Flask was imported before cs50 -if "flask" in sys.modules: - _wrap_flask(sys.modules["flask"]) - -# If Flask wasn't imported -else: - flask_loader = pkgutil.get_loader('flask') - if flask_loader: - _exec_module_before = flask_loader.exec_module - - def _exec_module_after(*args, **kwargs): - _exec_module_before(*args, **kwargs) - _wrap_flask(sys.modules["flask"]) - - flask_loader.exec_module = _exec_module_after diff --git a/src/cs50/sql.py b/src/cs50/sql.py index f95b347..64d30e3 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -1,545 +1,116 @@ -def _enable_logging(f): - """Enable logging of SQL statements when Flask is in use.""" +import logging - import logging - import functools +import sqlalchemy - @functools.wraps(f) - def decorator(*args, **kwargs): +from ._logger import green, red, yellow +from ._engine import Engine +from ._statement import statement_factory +from ._sql_util import postgres_lastval, process_select_result, raise_errors_for_warnings - # Infer whether Flask is installed - try: - import flask - except ModuleNotFoundError: - return f(*args, **kwargs) - # Enable logging - disabled = logging.getLogger("cs50").disabled - if flask.current_app: - logging.getLogger("cs50").disabled = False - try: - return f(*args, **kwargs) - finally: - logging.getLogger("cs50").disabled = disabled +_logger = logging.getLogger("cs50") - return decorator +class SQL: + """An API for executing SQL Statements. + """ -class SQL(object): - """Wrap SQLAlchemy to provide a simple SQL API.""" - - def __init__(self, url, **kwargs): + def __init__(self, url): """ - Create instance of sqlalchemy.engine.Engine. - - URL should be a string that indicates database dialect and connection arguments. - - http://docs.sqlalchemy.org/en/latest/core/engines.html#sqlalchemy.create_engine - http://docs.sqlalchemy.org/en/latest/dialects/index.html + :param url: The database URL """ - # Lazily import - import logging - import os - import re - import sqlalchemy - import sqlalchemy.orm - import sqlite3 - - # Require that file already exist for SQLite - matches = re.search(r"^sqlite:///(.+)$", url) - if matches: - if not os.path.exists(matches.group(1)): - raise RuntimeError("does not exist: {}".format(matches.group(1))) - if not os.path.isfile(matches.group(1)): - raise RuntimeError("not a file: {}".format(matches.group(1))) - - # Create engine, disabling SQLAlchemy's own autocommit mode, raising exception if back end's module not installed - self._engine = sqlalchemy.create_engine(url, **kwargs).execution_options(autocommit=False) - - # Get logger - self._logger = logging.getLogger("cs50") - - # Listener for connections - def connect(dbapi_connection, connection_record): - - # Disable underlying API's own emitting of BEGIN and COMMIT so we can ourselves - # https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl - dbapi_connection.isolation_level = None - - # Enable foreign key constraints - if type(dbapi_connection) is sqlite3.Connection: # If back end is sqlite - cursor = dbapi_connection.cursor() - cursor.execute("PRAGMA foreign_keys=ON") - cursor.close() - - # Register listener - sqlalchemy.event.listen(self._engine, "connect", connect) - - # Autocommit by default - self._autocommit = True - - # Test database - disabled = self._logger.disabled - self._logger.disabled = True - try: - self.execute("SELECT 1") - except sqlalchemy.exc.OperationalError as e: - e = RuntimeError(_parse_exception(e)) - e.__cause__ = None - raise e - finally: - self._logger.disabled = disabled - - def __del__(self): - """Disconnect from database.""" - self._disconnect() - - def _disconnect(self): - """Close database connection.""" - if hasattr(self, "_session"): - self._session.remove() - delattr(self, "_session") + self._engine = Engine(url) + self._substitute_markers_with_params = statement_factory(self._engine.dialect) - @_enable_logging def execute(self, sql, *args, **kwargs): - """Execute a SQL statement.""" - - # Lazily import - import decimal - import re - import sqlalchemy - import sqlparse - import termcolor - import warnings - - # Parse statement, stripping comments and then leading/trailing whitespace - statements = sqlparse.parse(sqlparse.format(sql, strip_comments=True).strip()) - - # Allow only one statement at a time, since SQLite doesn't support multiple - # https://docs.python.org/3/library/sqlite3.html#sqlite3.Cursor.execute - if len(statements) > 1: - raise RuntimeError("too many statements at once") - elif len(statements) == 0: - raise RuntimeError("missing statement") - - # Ensure named and positional parameters are mutually exclusive - if len(args) > 0 and len(kwargs) > 0: - raise RuntimeError("cannot pass both positional and named parameters") - - # Infer command from (unflattened) statement - for token in statements[0]: - if token.ttype in [sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML]: - token_value = token.value.upper() - if token_value in ["BEGIN", "DELETE", "INSERT", "SELECT", "START", "UPDATE"]: - command = token_value - break - else: - command = None - - # Flatten statement - tokens = list(statements[0].flatten()) - - # Validate paramstyle - placeholders = {} - paramstyle = None - for index, token in enumerate(tokens): - - # If token is a placeholder - if token.ttype == sqlparse.tokens.Name.Placeholder: - - # Determine paramstyle, name - _paramstyle, name = _parse_placeholder(token) - - # Remember paramstyle - if not paramstyle: - paramstyle = _paramstyle - - # Ensure paramstyle is consistent - elif _paramstyle != paramstyle: - raise RuntimeError("inconsistent paramstyle") - - # Remember placeholder's index, name - placeholders[index] = name - - # If no placeholders - if not paramstyle: - - # Error-check like qmark if args - if args: - paramstyle = "qmark" - - # Error-check like named if kwargs - elif kwargs: - paramstyle = "named" - - # In case of errors - _placeholders = ", ".join([str(tokens[index]) for index in placeholders]) - _args = ", ".join([str(self._escape(arg)) for arg in args]) - - # qmark - if paramstyle == "qmark": - - # Validate number of placeholders - if len(placeholders) != len(args): - if len(placeholders) < len(args): - raise RuntimeError("fewer placeholders ({}) than values ({})".format(_placeholders, _args)) - else: - raise RuntimeError("more placeholders ({}) than values ({})".format(_placeholders, _args)) - - # Escape values - for i, index in enumerate(placeholders.keys()): - tokens[index] = self._escape(args[i]) - - # numeric - elif paramstyle == "numeric": - - # Escape values - for index, i in placeholders.items(): - if i >= len(args): - raise RuntimeError("missing value for placeholder (:{})".format(i + 1, len(args))) - tokens[index] = self._escape(args[i]) - - # Check if any values unused - indices = set(range(len(args))) - set(placeholders.values()) - if indices: - raise RuntimeError("unused {} ({})".format( - "value" if len(indices) == 1 else "values", - ", ".join([str(self._escape(args[index])) for index in indices]))) - - # named - elif paramstyle == "named": - - # Escape values - for index, name in placeholders.items(): - if name not in kwargs: - raise RuntimeError("missing value for placeholder (:{})".format(name)) - tokens[index] = self._escape(kwargs[name]) - - # Check if any keys unused - keys = kwargs.keys() - placeholders.values() - if keys: - raise RuntimeError("unused values ({})".format(", ".join(keys))) - - # format - elif paramstyle == "format": - - # Validate number of placeholders - if len(placeholders) != len(args): - if len(placeholders) < len(args): - raise RuntimeError("fewer placeholders ({}) than values ({})".format(_placeholders, _args)) - else: - raise RuntimeError("more placeholders ({}) than values ({})".format(_placeholders, _args)) - - # Escape values - for i, index in enumerate(placeholders.keys()): - tokens[index] = self._escape(args[i]) - - # pyformat - elif paramstyle == "pyformat": - - # Escape values - for index, name in placeholders.items(): - if name not in kwargs: - raise RuntimeError("missing value for placeholder (%{}s)".format(name)) - tokens[index] = self._escape(kwargs[name]) - - # Check if any keys unused - keys = kwargs.keys() - placeholders.values() - if keys: - raise RuntimeError("unused {} ({})".format( - "value" if len(keys) == 1 else "values", - ", ".join(keys))) - - # For SQL statements where a colon is required verbatim, as within an inline string, use a backslash to escape - # https://docs.sqlalchemy.org/en/13/core/sqlelement.html?highlight=text#sqlalchemy.sql.expression.text - for index, token in enumerate(tokens): - - # In string literal - # https://www.sqlite.org/lang_keywords.html - if token.ttype in [sqlparse.tokens.Literal.String, sqlparse.tokens.Literal.String.Single]: - token.value = re.sub("(^'|\s+):", r"\1\:", token.value) - - # In identifier - # https://www.sqlite.org/lang_keywords.html - elif token.ttype == sqlparse.tokens.Literal.String.Symbol: - token.value = re.sub("(^\"|\s+):", r"\1\:", token.value) - - # Join tokens into statement - statement = "".join([str(token) for token in tokens]) - - # Connect to database - try: - - # Infer whether Flask is installed - import flask - - # Infer whether app is defined - assert flask.current_app + """Executes a SQL statement. - # If no sessions for any databases yet - if not hasattr(flask.g, "_sessions"): - setattr(flask.g, "_sessions", {}) - sessions = getattr(flask.g, "_sessions") + :param sql: a SQL statement, possibly with parameters markers + :type sql: str + :param *args: zero or more positional arguments to substitute the parameter markers with + :param **kwargs: zero or more keyword arguments to substitute the parameter markers with - # If no session yet for this database - # https://flask.palletsprojects.com/en/1.1.x/appcontext/#storing-data - # https://stackoverflow.com/a/34010159 - if self not in sessions: + :returns: For ``SELECT``, a :py:class:`list` of :py:class:`dict` objects, each of which + represents a row in the result set; for ``INSERT``, the primary key of a newly inserted row + (or ``None`` if none); for ``UPDATE``, the number of rows updated; for ``DELETE``, the + number of rows deleted; for other statements, ``True``; on integrity errors, a + :py:class:`ValueError` is raised, on other errors, a :py:class:`RuntimeError` is raised - # Connect to database - sessions[self] = sqlalchemy.orm.scoping.scoped_session(sqlalchemy.orm.sessionmaker(bind=self._engine)) - - # Remove session later - if _teardown_appcontext not in flask.current_app.teardown_appcontext_funcs: - flask.current_app.teardown_appcontext(_teardown_appcontext) - - # Use this session - session = sessions[self] - - except (ModuleNotFoundError, AssertionError): - - # If no connection yet - if not hasattr(self, "_session"): - self._session = sqlalchemy.orm.scoping.scoped_session(sqlalchemy.orm.sessionmaker(bind=self._engine)) - - # Use this session - session = self._session - - # Catch SQLAlchemy warnings - with warnings.catch_warnings(): - - # Raise exceptions for warnings - warnings.simplefilter("error") - - # Prepare, execute statement - try: - - # Join tokens into statement, abbreviating binary data as <class 'bytes'> - _statement = "".join([str(bytes) if token.ttype == sqlparse.tokens.Other else str(token) for token in tokens]) - - # Check for start of transaction - if command in ["BEGIN", "START"]: - self._autocommit = False - - # Execute statement - if self._autocommit: - session.execute(sqlalchemy.text("BEGIN")) - result = session.execute(sqlalchemy.text(statement)) - if self._autocommit: - session.execute(sqlalchemy.text("COMMIT")) - - # Check for end of transaction - if command in ["COMMIT", "ROLLBACK"]: - self._autocommit = True - - # Return value - ret = True - - # If SELECT, return result set as list of dict objects - if command == "SELECT": - - # Coerce types - rows = [dict(row) for row in result.fetchall()] - for row in rows: - for column in row: - - # Coerce decimal.Decimal objects to float objects - # https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ - if type(row[column]) is decimal.Decimal: - row[column] = float(row[column]) - - # Coerce memoryview objects (as from PostgreSQL's bytea columns) to bytes - elif type(row[column]) is memoryview: - row[column] = bytes(row[column]) - - # Rows to be returned - ret = rows - - # If INSERT, return primary key value for a newly inserted row (or None if none) - elif command == "INSERT": - if self._engine.url.get_backend_name() in ["postgres", "postgresql"]: - try: - result = session.execute("SELECT LASTVAL()") - ret = result.first()[0] - except sqlalchemy.exc.OperationalError: # If lastval is not yet defined in this session - ret = None - else: - ret = result.lastrowid if result.rowcount == 1 else None - - # If DELETE or UPDATE, return number of rows matched - elif command in ["DELETE", "UPDATE"]: - ret = result.rowcount - - # If constraint violated, return None - except sqlalchemy.exc.IntegrityError as e: - self._logger.debug(termcolor.colored(statement, "yellow")) - e = ValueError(e.orig) - e.__cause__ = None - raise e - - # If user error - except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as e: - self._disconnect() - self._logger.debug(termcolor.colored(statement, "red")) - e = RuntimeError(e.orig) - e.__cause__ = None - raise e - - # Return value - else: - self._logger.debug(termcolor.colored(_statement, "green")) - return ret - - def _escape(self, value): """ - Escapes value using engine's conversion function. - https://docs.sqlalchemy.org/en/latest/core/type_api.html#sqlalchemy.types.TypeEngine.literal_processor - """ - - # Lazily import - import sqlparse - - def __escape(value): - - # Lazily import - import datetime - import sqlalchemy + statement = self._substitute_markers_with_params(sql, *args, **kwargs) + connection = self._engine.get_existing_transaction_connection() + if connection is None: + if statement.is_transaction_start(): + connection = self._engine.get_transaction_connection() + else: + connection = self._engine.get_connection() + elif statement.is_transaction_start(): + raise RuntimeError("nested transactions are not supported") - # bool - if type(value) is bool: - return sqlparse.sql.Token( - sqlparse.tokens.Number, - sqlalchemy.types.Boolean().literal_processor(self._engine.dialect)(value)) + return self._execute(statement, connection) - # bytes - elif type(value) is bytes: - if self._engine.url.get_backend_name() in ["mysql", "sqlite"]: - return sqlparse.sql.Token(sqlparse.tokens.Other, f"x'{value.hex()}'") # https://dev.mysql.com/doc/refman/8.0/en/hexadecimal-literals.html - elif self._engine.url.get_backend_name() == "postgresql": - return sqlparse.sql.Token(sqlparse.tokens.Other, f"'\\x{value.hex()}'") # https://dba.stackexchange.com/a/203359 + def _execute(self, statement, connection): + with raise_errors_for_warnings(): + try: + result = connection.execute(str(statement)) + # E.g., failed constraint + except sqlalchemy.exc.IntegrityError as exc: + _logger.debug(yellow(statement)) + if self._engine.get_existing_transaction_connection() is None: + connection.close() + raise ValueError(exc.orig) from None + # E.g., connection error or syntax error + except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as exc: + if self._engine.get_existing_transaction_connection(): + self._engine.close_transaction_connection() else: - raise RuntimeError("unsupported value: {}".format(value)) - - # datetime.date - elif type(value) is datetime.date: - return sqlparse.sql.Token( - sqlparse.tokens.String, - sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%Y-%m-%d"))) - - # datetime.datetime - elif type(value) is datetime.datetime: - return sqlparse.sql.Token( - sqlparse.tokens.String, - sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%Y-%m-%d %H:%M:%S"))) - - # datetime.time - elif type(value) is datetime.time: - return sqlparse.sql.Token( - sqlparse.tokens.String, - sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%H:%M:%S"))) - - # float - elif type(value) is float: - return sqlparse.sql.Token( - sqlparse.tokens.Number, - sqlalchemy.types.Float().literal_processor(self._engine.dialect)(value)) - - # int - elif type(value) is int: - return sqlparse.sql.Token( - sqlparse.tokens.Number, - sqlalchemy.types.Integer().literal_processor(self._engine.dialect)(value)) - - # str - elif type(value) is str: - return sqlparse.sql.Token( - sqlparse.tokens.String, - sqlalchemy.types.String().literal_processor(self._engine.dialect)(value)) - - # None - elif value is None: - return sqlparse.sql.Token( - sqlparse.tokens.Keyword, - sqlalchemy.types.NullType().literal_processor(self._engine.dialect)(value)) - - # Unsupported value + connection.close() + _logger.debug(red(statement)) + raise RuntimeError(exc.orig) from None + + _logger.debug(green(statement)) + + if statement.is_select(): + ret = process_select_result(result) + elif statement.is_insert(): + ret = self._last_row_id_or_none(result) + elif statement.is_delete() or statement.is_update(): + ret = result.rowcount else: - raise RuntimeError("unsupported value: {}".format(value)) + ret = True - # Escape value(s), separating with commas as needed - if type(value) in [list, tuple]: - return sqlparse.sql.TokenList(sqlparse.parse(", ".join([str(__escape(v)) for v in value]))) + if self._engine.get_existing_transaction_connection(): + if statement.is_transaction_end(): + self._engine.close_transaction_connection() else: - return __escape(value) + connection.close() + return ret -def _parse_exception(e): - """Parses an exception, returns its message.""" - - # Lazily import - import re - - # MySQL - matches = re.search(r"^\(_mysql_exceptions\.OperationalError\) \(\d+, \"(.+)\"\)$", str(e)) - if matches: - return matches.group(1) - - # PostgreSQL - matches = re.search(r"^\(psycopg2\.OperationalError\) (.+)$", str(e)) - if matches: - return matches.group(1) - - # SQLite - matches = re.search(r"^\(sqlite3\.OperationalError\) (.+)$", str(e)) - if matches: - return matches.group(1) - - # Default - return str(e) - - -def _parse_placeholder(token): - """Infers paramstyle, name from sqlparse.tokens.Name.Placeholder.""" - - # Lazily load - import re - import sqlparse - - # Validate token - if not isinstance(token, sqlparse.sql.Token) or token.ttype != sqlparse.tokens.Name.Placeholder: - raise TypeError() - - # qmark - if token.value == "?": - return "qmark", None + def _last_row_id_or_none(self, result): + """ + :param result: A SQLAlchemy result object + :type result: :class:`sqlalchemy.engine.Result` - # numeric - matches = re.search(r"^:([1-9]\d*)$", token.value) - if matches: - return "numeric", int(matches.group(1)) - 1 + :returns: The ID of the last inserted row or ``None`` + """ - # named - matches = re.search(r"^:([a-zA-Z]\w*)$", token.value) - if matches: - return "named", matches.group(1) + if self._engine.is_postgres(): + return postgres_lastval(result.connection) + return result.lastrowid if result.rowcount == 1 else None - # format - if token.value == "%s": - return "format", None + def init_app(self, app): + """Enables logging and registers a ``teardown_appcontext`` listener to remove the session. - # pyformat - matches = re.search(r"%\((\w+)\)s$", token.value) - if matches: - return "pyformat", matches.group(1) + :param app: a Flask application instance + :type app: :class:`flask.Flask` + """ - # Invalid - raise RuntimeError("{}: invalid placeholder".format(token.value)) + @app.teardown_appcontext + def _(_): + self._engine.close_transaction_connection() -def _teardown_appcontext(exception=None): - """Closes context's database connection, if any.""" - import flask - for session in flask.g.pop("_sessions", {}).values(): - session.remove() + logging.getLogger("cs50").disabled = False diff --git a/tests/flask/application.py b/tests/flask/application.py deleted file mode 100644 index 939a8f9..0000000 --- a/tests/flask/application.py +++ /dev/null @@ -1,22 +0,0 @@ -import requests -import sys -from flask import Flask, render_template - -sys.path.insert(0, "../../src") - -import cs50 -import cs50.flask - -app = Flask(__name__) - -db = cs50.SQL("sqlite:///../sqlite.db") - -@app.route("/") -def index(): - db.execute("SELECT 1") - """ - def f(): - res = requests.get("cs50.harvard.edu") - f() - """ - return render_template("index.html") diff --git a/tests/flask/requirements.txt b/tests/flask/requirements.txt deleted file mode 100644 index 7d0c101..0000000 --- a/tests/flask/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -cs50 -Flask diff --git a/tests/flask/templates/error.html b/tests/flask/templates/error.html deleted file mode 100644 index 3302040..0000000 --- a/tests/flask/templates/error.html +++ /dev/null @@ -1,10 +0,0 @@ -<!DOCTYPE html> - -<html> - <head> - <title>error</title> - </head> - <body> - error - </body> -</html> diff --git a/tests/flask/templates/index.html b/tests/flask/templates/index.html deleted file mode 100644 index 2f6a145..0000000 --- a/tests/flask/templates/index.html +++ /dev/null @@ -1,10 +0,0 @@ -<!DOCTYPE html> - -<html> - <head> - <title>flask</title> - </head> - <body> - flask - </body> -</html> diff --git a/tests/foo.py b/tests/foo.py deleted file mode 100644 index 7f32a00..0000000 --- a/tests/foo.py +++ /dev/null @@ -1,48 +0,0 @@ -import logging -import sys - -sys.path.insert(0, "../src") - -import cs50 - -""" -db = cs50.SQL("sqlite:///foo.db") - -logging.getLogger("cs50").disabled = False - -#db.execute("SELECT ? FROM ? ORDER BY ?", "a", "tbl", "c") -db.execute("CREATE TABLE IF NOT EXISTS bar (firstname STRING)") - -db.execute("INSERT INTO bar VALUES (?)", "baz") -db.execute("INSERT INTO bar VALUES (?)", "qux") -db.execute("SELECT * FROM bar WHERE firstname IN (?)", ("baz", "qux")) -db.execute("DELETE FROM bar") -""" - -db = cs50.SQL("postgresql://postgres@localhost/test") - -""" -print(db.execute("DROP TABLE IF EXISTS cs50")) -print(db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)")) -print(db.execute("INSERT INTO cs50 (val) VALUES('foo')")) -print(db.execute("SELECT * FROM cs50")) - -print(db.execute("DROP TABLE IF EXISTS cs50")) -print(db.execute("CREATE TABLE cs50 (val VARCHAR(16), bin BYTEA)")) -print(db.execute("INSERT INTO cs50 (val) VALUES('foo')")) -print(db.execute("SELECT * FROM cs50")) -""" - -print(db.execute("DROP TABLE IF EXISTS cs50")) -print(db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)")) -print(db.execute("INSERT INTO cs50 (val) VALUES('foo')")) -print(db.execute("INSERT INTO cs50 (val) VALUES('bar')")) -print(db.execute("INSERT INTO cs50 (val) VALUES('baz')")) -print(db.execute("SELECT * FROM cs50")) -try: - print(db.execute("INSERT INTO cs50 (id, val) VALUES(1, 'bar')")) -except Exception as e: - print(e) - pass -print(db.execute("INSERT INTO cs50 (val) VALUES('qux')")) -#print(db.execute("DELETE FROM cs50")) diff --git a/tests/mysql.py b/tests/mysql.py deleted file mode 100644 index 2a431c3..0000000 --- a/tests/mysql.py +++ /dev/null @@ -1,8 +0,0 @@ -import sys - -sys.path.insert(0, "../src") - -from cs50 import SQL - -db = SQL("mysql://root@localhost/test") -db.execute("SELECT 1") diff --git a/tests/python.py b/tests/python.py deleted file mode 100644 index 6a265cb..0000000 --- a/tests/python.py +++ /dev/null @@ -1,8 +0,0 @@ -import sys - -sys.path.insert(0, "../src") - -import cs50 - -i = cs50.get_int("Input: ") -print(f"Output: {i}") diff --git a/tests/redirect/application.py b/tests/redirect/application.py deleted file mode 100644 index 6aff187..0000000 --- a/tests/redirect/application.py +++ /dev/null @@ -1,12 +0,0 @@ -import cs50 -from flask import Flask, redirect, render_template - -app = Flask(__name__) - -@app.route("/") -def index(): - return redirect("/foo") - -@app.route("/foo") -def foo(): - return render_template("foo.html") diff --git a/tests/redirect/templates/foo.html b/tests/redirect/templates/foo.html deleted file mode 100644 index 257cc56..0000000 --- a/tests/redirect/templates/foo.html +++ /dev/null @@ -1 +0,0 @@ -foo diff --git a/tests/sqlite.py b/tests/sqlite.py deleted file mode 100644 index 05c2cea..0000000 --- a/tests/sqlite.py +++ /dev/null @@ -1,44 +0,0 @@ -import logging -import sys - -sys.path.insert(0, "../src") - -from cs50 import SQL - -logging.getLogger("cs50").disabled = False - -db = SQL("sqlite:///sqlite.db") -db.execute("SELECT 1") - -# TODO -#db.execute("SELECT * FROM Employee WHERE FirstName = ?", b'\x00') - -db.execute("SELECT * FROM Employee WHERE FirstName = ?", "' OR 1 = 1") - -db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", "Andrew") -db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", ["Andrew"]) -db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", ("Andrew",)) -db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", ["Andrew", "Nancy"]) -db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", ("Andrew", "Nancy")) -db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", []) -db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", ()) - -db.execute("SELECT * FROM Employee WHERE FirstName = ? AND LastName = ?", "Andrew", "Adams") -db.execute("SELECT * FROM Employee WHERE FirstName = ? AND LastName = ?", ["Andrew", "Adams"]) -db.execute("SELECT * FROM Employee WHERE FirstName = ? AND LastName = ?", ("Andrew", "Adams")) - -db.execute("SELECT * FROM Employee WHERE FirstName = :1 AND LastName = :2", "Andrew", "Adams") -db.execute("SELECT * FROM Employee WHERE FirstName = :1 AND LastName = :2", ["Andrew", "Adams"]) -db.execute("SELECT * FROM Employee WHERE FirstName = :1 AND LastName = :2", ("Andrew", "Adams")) - -db.execute("SELECT * FROM Employee WHERE FirstName = ':Andrew :Adams'") - -db.execute("SELECT * FROM Employee WHERE FirstName = :first AND LastName = :last", first="Andrew", last="Adams") -db.execute("SELECT * FROM Employee WHERE FirstName = :first AND LastName = :last", {"first": "Andrew", "last": "Adams"}) - -db.execute("SELECT * FROM Employee WHERE FirstName = %s AND LastName = %s", "Andrew", "Adams") -db.execute("SELECT * FROM Employee WHERE FirstName = %s AND LastName = %s", ["Andrew", "Adams"]) -db.execute("SELECT * FROM Employee WHERE FirstName = %s AND LastName = %s", ("Andrew", "Adams")) - -db.execute("SELECT * FROM Employee WHERE FirstName = %(first)s AND LastName = %(last)s", first="Andrew", last="Adams") -db.execute("SELECT * FROM Employee WHERE FirstName = %(first)s AND LastName = %(last)s", {"first": "Andrew", "last": "Adams"}) diff --git a/tests/tb.py b/tests/tb.py deleted file mode 100644 index 3ad8175..0000000 --- a/tests/tb.py +++ /dev/null @@ -1,10 +0,0 @@ -import sys - -sys.path.insert(0, "../src") - -import cs50 -import requests - -def f(): - res = requests.get("cs50.harvard.edu") -f() diff --git a/tests/test_cs50.py b/tests/test_cs50.py new file mode 100644 index 0000000..9a0faca --- /dev/null +++ b/tests/test_cs50.py @@ -0,0 +1,141 @@ +import sys +import unittest + +from unittest.mock import patch + +from cs50.cs50 import get_string, _get_int, _get_float + + +class TestCS50(unittest.TestCase): + @patch("cs50.cs50._get_input", return_value="") + def test_get_string_empty_input(self, mock_get_input): + """Returns empty string when input is empty""" + self.assertEqual(get_string("Answer: "), "") + mock_get_input.assert_called_with("Answer: ") + + @patch("cs50.cs50._get_input", return_value="test") + def test_get_string_nonempty_input(self, mock_get_input): + """Returns the provided non-empty input""" + self.assertEqual(get_string("Answer: "), "test") + mock_get_input.assert_called_with("Answer: ") + + @patch("cs50.cs50._get_input", side_effect=EOFError) + def test_get_string_eof(self, mock_get_input): + """Returns None on EOF""" + self.assertIs(get_string("Answer: "), None) + mock_get_input.assert_called_with("Answer: ") + + def test_get_string_invalid_prompt(self): + """Raises TypeError when prompt is not str""" + with self.assertRaises(TypeError): + get_string(1) + + @patch("cs50.cs50.get_string", return_value=None) + def test_get_int_eof(self, mock_get_string): + """Returns None on EOF""" + self.assertIs(_get_int("Answer: "), None) + mock_get_string.assert_called_with("Answer: ") + + def test_get_int_valid_input(self): + """Returns the provided integer input""" + + def assert_equal(return_value, expected_value): + with patch("cs50.cs50.get_string", return_value=return_value) as mock_get_string: + self.assertEqual(_get_int("Answer: "), expected_value) + mock_get_string.assert_called_with("Answer: ") + + values = [ + ("0", 0), + ("50", 50), + ("+50", 50), + ("+42", 42), + ("-42", -42), + ("42", 42), + ] + + for return_value, expected_value in values: + assert_equal(return_value, expected_value) + + def test_get_int_invalid_input(self): + """Raises ValueError when input is invalid base-10 int""" + + def assert_raises_valueerror(return_value): + with patch("cs50.cs50.get_string", return_value=return_value) as mock_get_string: + with self.assertRaises(ValueError): + _get_int("Answer: ") + + mock_get_string.assert_called_with("Answer: ") + + return_values = [ + "++50", + "--50", + "50+", + "50-", + " 50", + " +50", + " -50", + "50 ", + "ab50", + "50ab", + "ab50ab", + ] + + for return_value in return_values: + assert_raises_valueerror(return_value) + + @patch("cs50.cs50.get_string", return_value=None) + def test_get_float_eof(self, mock_get_string): + """Returns None on EOF""" + self.assertIs(_get_float("Answer: "), None) + mock_get_string.assert_called_with("Answer: ") + + def test_get_float_valid_input(self): + """Returns the provided integer input""" + def assert_equal(return_value, expected_value): + with patch("cs50.cs50.get_string", return_value=return_value) as mock_get_string: + f = _get_float("Answer: ") + self.assertAlmostEqual(f, expected_value) + mock_get_string.assert_called_with("Answer: ") + + values = [ + (".0", 0.0), + ("0.", 0.0), + (".42", 0.42), + ("42.", 42.0), + ("50", 50.0), + ("+50", 50.0), + ("-50", -50.0), + ("+3.14", 3.14), + ("-3.14", -3.14), + ] + + for return_value, expected_value in values: + assert_equal(return_value, expected_value) + + def test_get_float_invalid_input(self): + """Raises ValueError when input is invalid float""" + + def assert_raises_valueerror(return_value): + with patch("cs50.cs50.get_string", return_value=return_value) as mock_get_string: + with self.assertRaises(ValueError): + _get_float("Answer: ") + + mock_get_string.assert_called_with("Answer: ") + + return_values = [ + ".", + "..5", + "a.5", + ".5a" + "0.5a", + "a0.42", + " .42", + "3.14 ", + "++3.14", + "3.14+", + "--3.14", + "3.14--", + ] + + for return_value in return_values: + assert_raises_valueerror(return_value)