Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 1440ae5

Browse files
author
Kareem Zidane
committedApr 14, 2021
add docstrings
1 parent a6668c0 commit 1440ae5

File tree

9 files changed

+261
-70
lines changed

9 files changed

+261
-70
lines changed
 

‎src/cs50/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
"""Exposes API and sets up logging"""
2-
31
from .cs50 import get_float, get_int, get_string
42
from .sql import SQL
53
from ._logger import _setup_logger

‎src/cs50/_logger.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
"""Sets up logging for cs50 library"""
1+
"""Sets up logging for the library.
2+
"""
23

34
import logging
45
import os.path
@@ -9,6 +10,22 @@
910
import termcolor
1011

1112

13+
def green(msg):
14+
return _colored(msg, "green")
15+
16+
17+
def red(msg):
18+
return _colored(msg, "red")
19+
20+
21+
def yellow(msg):
22+
return _colored(msg, "yellow")
23+
24+
25+
def _colored(msg, color):
26+
return termcolor.colored(str(msg), color)
27+
28+
1229
def _setup_logger():
1330
_configure_default_logger()
1431
_patch_root_handler_format_exception()
@@ -17,11 +34,16 @@ def _setup_logger():
1734

1835

1936
def _configure_default_logger():
20-
"""Configure default handler and formatter to prevent flask and werkzeug from adding theirs"""
37+
"""Configures a default handler and formatter to prevent flask and werkzeug from adding theirs.
38+
"""
39+
2140
logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.DEBUG)
2241

2342

2443
def _patch_root_handler_format_exception():
44+
"""Patches formatException for the root handler to use ``_format_exception``.
45+
"""
46+
2547
try:
2648
formatter = logging.root.handlers[0].formatter
2749
formatter.formatException = lambda exc_info: _format_exception(*exc_info)
@@ -30,6 +52,10 @@ def _patch_root_handler_format_exception():
3052

3153

3254
def _configure_cs50_logger():
55+
"""Disables the cs50 logger by default. Disables logging propagation to prevent messages from
56+
being logged more than once. Sets the logging handler and formatter.
57+
"""
58+
3359
_logger = logging.getLogger("cs50")
3460
_logger.disabled = True
3561
_logger.setLevel(logging.DEBUG)
@@ -52,9 +78,8 @@ def _patch_excepthook():
5278

5379

5480
def _format_exception(type_, value, exc_tb):
55-
"""
56-
Format traceback, darkening entries from global site-packages directories
57-
and user-specific site-packages directory.
81+
"""Formats traceback, darkening entries from global site-packages directories and user-specific
82+
site-packages directory.
5883
https://stackoverflow.com/a/46071447/5156190
5984
"""
6085

@@ -69,6 +94,5 @@ def _format_exception(type_, value, exc_tb):
6994
lines += line
7095
else:
7196
matches = re.search(r"^(\s*)(.*?)(\s*)$", line, re.DOTALL)
72-
lines.append(
73-
matches.group(1) + termcolor.colored(matches.group(2), "yellow") + matches.group(3))
97+
lines.append(matches.group(1) + yellow(matches.group(2)) + matches.group(3))
7498
return "".join(lines).rstrip()

‎src/cs50/_session.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
"""Wraps a SQLAlchemy scoped session"""
2-
31
import sqlalchemy
42
import sqlalchemy.orm
53

@@ -11,7 +9,8 @@
119

1210

1311
class Session:
14-
"""Wraps a SQLAlchemy scoped session"""
12+
"""Wraps a SQLAlchemy scoped session.
13+
"""
1514

1615
def __init__(self, url, **engine_kwargs):
1716
if is_sqlite_url(url):
@@ -20,9 +19,16 @@ def __init__(self, url, **engine_kwargs):
2019
self._session = create_session(url, **engine_kwargs)
2120

2221
def execute(self, statement):
23-
"""Converts statement to str and executes it"""
22+
"""Converts statement to str and executes it.
23+
24+
:param statement: The SQL statement to be executed
25+
"""
26+
2427
# pylint: disable=no-member
2528
return self._session.execute(sqlalchemy.text(str(statement)))
2629

2730
def __getattr__(self, attr):
31+
"""Proxies any attributes to the underlying SQLAlchemy scoped session.
32+
"""
33+
2834
return getattr(self._session, attr)

‎src/cs50/_session_util.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
"""Utility functions used by _session.py"""
1+
"""Utility functions used by _session.py.
2+
"""
23

34
import os
45
import sqlite3
@@ -49,8 +50,11 @@ def _create_scoped_session(engine):
4950

5051

5152
def _disable_auto_begin_commit(dbapi_connection):
52-
# Disable underlying API's own emitting of BEGIN and COMMIT so we can ourselves
53-
# https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl
53+
"""Disables the underlying API's own emitting of BEGIN and COMMIT so we can support manual
54+
transactions.
55+
https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl
56+
"""
57+
5458
dbapi_connection.isolation_level = None
5559

5660

‎src/cs50/_sql_sanitizer.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
"""Escapes SQL values"""
2-
31
import datetime
42
import re
53

@@ -8,15 +6,19 @@
86

97

108
class SQLSanitizer:
11-
"""Escapes SQL values"""
9+
"""Sanitizes SQL values.
10+
"""
1211

1312
def __init__(self, dialect):
1413
self._dialect = dialect
1514

1615
def escape(self, value):
17-
"""
18-
Escapes value using engine's conversion function.
16+
"""Escapes value using engine's conversion function.
1917
https://docs.sqlalchemy.org/en/latest/core/type_api.html#sqlalchemy.types.TypeEngine.literal_processor
18+
19+
:param value: The value to be sanitized
20+
21+
:returns: The sanitized value
2022
"""
2123
# pylint: disable=too-many-return-statements
2224
if isinstance(value, (list, tuple)):
@@ -71,13 +73,22 @@ def escape(self, value):
7173
raise RuntimeError(f"unsupported value: {value}")
7274

7375
def escape_iterable(self, iterable):
74-
"""Escapes a collection of values (e.g., list, tuple)"""
76+
"""Escapes each value in iterable and joins all the escaped values with ", ", formatted for
77+
SQL's ``IN`` operator.
78+
79+
:param: An iterable of values to be escaped
80+
81+
:returns: A comma-separated list of escaped values from ``iterable``
82+
:rtype: :class:`sqlparse.sql.TokenList`
83+
"""
84+
7585
return sqlparse.sql.TokenList(
7686
sqlparse.parse(", ".join([str(self.escape(v)) for v in iterable])))
7787

7888

7989
def escape_verbatim_colon(value):
80-
"""Escapes verbatim colon from a value so as it is not confused with a placeholder"""
90+
"""Escapes verbatim colon from a value so as it is not confused with a parameter marker.
91+
"""
8192

8293
# E.g., ':foo, ":foo, :foo will be replaced with
8394
# '\:foo, "\:foo, \:foo respectively

‎src/cs50/_sql_util.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
1-
"""Utility functions used by sql.py"""
1+
"""Utility functions used by sql.py.
2+
"""
23

34
import contextlib
45
import decimal
56
import warnings
67

78

8-
def fetch_select_result(result):
9+
def process_select_result(result):
10+
"""Converts a SQLAlchemy result to a ``list`` of ``dict`` objects, each of which represents a
11+
row in the result set.
12+
13+
:param result: A SQLAlchemy result
14+
:type result: :class:`sqlalchemy.engine.Result`
15+
"""
916
rows = [dict(row) for row in result.fetchall()]
1017
for row in rows:
1118
for column in row:
@@ -23,6 +30,9 @@ def fetch_select_result(result):
2330

2431
@contextlib.contextmanager
2532
def raise_errors_for_warnings():
33+
"""Catches warnings and raises errors instead.
34+
"""
35+
2636
with warnings.catch_warnings():
2737
warnings.simplefilter("error")
2838
yield

‎src/cs50/_statement.py

Lines changed: 69 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
"""Parses a SQL statement and replaces the placeholders with the corresponding parameters"""
2-
31
import collections
42

53
from ._sql_sanitizer import SQLSanitizer, escape_verbatim_colon
@@ -17,6 +15,13 @@
1715

1816

1917
def statement_factory(dialect):
18+
"""Creates a sanitizer for ``dialect`` and injects it into ``Statement``, exposing a simpler
19+
interface for ``Statement``.
20+
21+
:param dialect: a SQLAlchemy dialect
22+
:type dialect: :class:`sqlalchemy.engine.Dialect`
23+
"""
24+
2025
sql_sanitizer = SQLSanitizer(dialect)
2126

2227
def statement(sql, *args, **kwargs):
@@ -26,9 +31,23 @@ def statement(sql, *args, **kwargs):
2631

2732

2833
class Statement:
29-
"""Parses a SQL statement and replaces the placeholders with the corresponding parameters"""
34+
"""Parses a SQL statement and substitutes any parameter markers with their corresponding
35+
placeholders.
36+
"""
3037

3138
def __init__(self, sql_sanitizer, sql, *args, **kwargs):
39+
"""
40+
:param sql_sanitizer: The SQL sanitizer used to sanitize the parameters
41+
:type sql_sanitizer: :class:`_sql_sanitizer.SQLSanitizer`
42+
43+
:param sql: The SQL statement
44+
:type sql: str
45+
46+
:param *args: Zero or more positional parameters to be substituted for the parameter markers
47+
48+
:param *kwargs: Zero or more keyword arguments to be substituted for the parameter markers
49+
"""
50+
3251
if len(args) > 0 and len(kwargs) > 0:
3352
raise RuntimeError("cannot pass both positional and named parameters")
3453

@@ -54,9 +73,18 @@ def _get_escaped_kwargs(self, kwargs):
5473
return {k: self._sql_sanitizer.escape(v) for k, v in kwargs.items()}
5574

5675
def _tokenize(self):
76+
"""
77+
:returns: A flattened list of SQLParse tokens that represent the SQL statement
78+
"""
79+
5780
return list(self._statement.flatten())
5881

5982
def _get_operation_keyword(self):
83+
"""
84+
:returns: The operation keyword of the SQL statement (e.g., ``SELECT``, ``DELETE``, etc)
85+
:rtype: str
86+
"""
87+
6088
for token in self._statement:
6189
if is_operation_token(token.ttype):
6290
token_value = token.value.upper()
@@ -69,6 +97,11 @@ def _get_operation_keyword(self):
6997
return operation_keyword
7098

7199
def _get_paramstyle(self):
100+
"""
101+
:returns: The paramstyle used in the SQL statement (if any)
102+
:rtype: :class:_statement_util.Paramstyle``
103+
"""
104+
72105
paramstyle = None
73106
for token in self._tokens:
74107
if is_placeholder(token.ttype):
@@ -80,6 +113,11 @@ def _get_paramstyle(self):
80113
return paramstyle
81114

82115
def _default_paramstyle(self):
116+
"""
117+
:returns: If positional args were passed, returns ``Paramstyle.QMARK``; if keyword arguments
118+
were passed, returns ``Paramstyle.NAMED``; otherwise, returns ``None``
119+
"""
120+
83121
paramstyle = None
84122
if self._args:
85123
paramstyle = Paramstyle.QMARK
@@ -89,6 +127,12 @@ def _default_paramstyle(self):
89127
return paramstyle
90128

91129
def _get_placeholders(self):
130+
"""
131+
:returns: A dict that maps the index of each parameter marker in the tokens list to the name
132+
of that parameter marker (if applicable) or ``None``
133+
:rtype: dict
134+
"""
135+
92136
placeholders = collections.OrderedDict()
93137
for index, token in enumerate(self._tokens):
94138
if is_placeholder(token.ttype):
@@ -109,11 +153,18 @@ def _substitute_markers_with_escaped_params(self):
109153
self._substitute_named_or_pyformat_markers()
110154

111155
def _substitute_format_or_qmark_markers(self):
156+
"""Substitutes format or qmark parameter markers with their corresponding parameters.
157+
"""
158+
112159
self._assert_valid_arg_count()
113160
for arg_index, token_index in enumerate(self._placeholders.keys()):
114161
self._tokens[token_index] = self._args[arg_index]
115162

116163
def _assert_valid_arg_count(self):
164+
"""Raises a ``RuntimeError`` if the number of arguments does not match the number of
165+
placeholders.
166+
"""
167+
117168
if len(self._placeholders) != len(self._args):
118169
placeholders = get_human_readable_list(self._placeholders.values())
119170
args = get_human_readable_list(self._args)
@@ -123,6 +174,10 @@ def _assert_valid_arg_count(self):
123174
raise RuntimeError(f"more placeholders ({placeholders}) than values ({args})")
124175

125176
def _substitue_numeric_markers(self):
177+
"""Substitutes numeric parameter markers with their corresponding parameters. Raises a
178+
``RuntimeError`` if any parameters are missing or unused.
179+
"""
180+
126181
unused_arg_indices = set(range(len(self._args)))
127182
for token_index, num in self._placeholders.items():
128183
if num >= len(self._args):
@@ -138,6 +193,10 @@ def _substitue_numeric_markers(self):
138193
f"unused value{'' if len(unused_args) == 1 else 's'} ({unused_args})")
139194

140195
def _substitute_named_or_pyformat_markers(self):
196+
"""Substitutes named or pyformat parameter markers with their corresponding parameters.
197+
Raises a ``RuntimeError`` if any parameters are missing or unused.
198+
"""
199+
141200
unused_params = set(self._kwargs.keys())
142201
for token_index, param_name in self._placeholders.items():
143202
if param_name not in self._kwargs:
@@ -152,6 +211,10 @@ def _substitute_named_or_pyformat_markers(self):
152211
f"unused value{'' if len(unused_params) == 1 else 's'} ({joined_unused_params})")
153212

154213
def _escape_verbatim_colons(self):
214+
"""Escapes verbatim colons from string literal and identifier tokens so they aren't treated
215+
as parameter markers.
216+
"""
217+
155218
for token in self._tokens:
156219
if is_string_literal(token.ttype) or is_identifier(token.ttype):
157220
token.value = escape_verbatim_colon(token.value)
@@ -175,4 +238,7 @@ def is_update(self):
175238
return self._operation_keyword == "UPDATE"
176239

177240
def __str__(self):
241+
"""Joins the statement tokens into a string.
242+
"""
243+
178244
return "".join([str(token) for token in self._tokens])

‎src/cs50/_statement_util.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
"""Utility functions used by _statement.py"""
1+
"""Utility functions used by _statement.py.
2+
"""
23

34
import enum
45
import re
@@ -19,6 +20,9 @@
1920

2021

2122
class Paramstyle(enum.Enum):
23+
"""Represents the supported parameter marker styles.
24+
"""
25+
2226
FORMAT = enum.auto()
2327
NAMED = enum.auto()
2428
NUMERIC = enum.auto()
@@ -27,6 +31,15 @@ class Paramstyle(enum.Enum):
2731

2832

2933
def format_and_parse(sql):
34+
"""Formats and parses a SQL statement. Raises ``RuntimeError`` if ``sql`` represents more than
35+
one statement.
36+
37+
:param sql: The SQL statement to be formatted and parsed
38+
:type sql: str
39+
40+
:returns: A list of unflattened SQLParse tokens that represent the parsed statement
41+
"""
42+
3043
formatted_statements = sqlparse.format(sql, strip_comments=True).strip()
3144
parsed_statements = sqlparse.parse(formatted_statements)
3245
statement_count = len(parsed_statements)
@@ -43,6 +56,10 @@ def is_placeholder(ttype):
4356

4457

4558
def parse_placeholder(value):
59+
"""
60+
:returns: A tuple of the paramstyle and the name of the parameter marker (if any) or ``None``
61+
:rtype: tuple
62+
"""
4663
if value == "?":
4764
return Paramstyle.QMARK, None
4865

‎src/cs50/sql.py

Lines changed: 96 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,90 +1,145 @@
1-
"""Wraps SQLAlchemy"""
2-
31
import logging
42

53
import sqlalchemy
6-
import termcolor
74

5+
from ._logger import green, red, yellow
86
from ._session import Session
97
from ._statement import statement_factory
10-
from ._sql_util import fetch_select_result, raise_errors_for_warnings
8+
from ._sql_util import process_select_result, raise_errors_for_warnings
9+
1110

1211
_logger = logging.getLogger("cs50")
1312

1413

1514
class SQL:
16-
"""Wraps SQLAlchemy"""
15+
"""An API for executing SQL Statements.
16+
"""
17+
18+
def __init__(self, url):
19+
"""
20+
:param url: The database URL
21+
"""
1722

18-
def __init__(self, url, **engine_kwargs):
19-
self._session = Session(url, **engine_kwargs)
20-
dialect = self._session.get_bind().dialect
23+
self._session = Session(url)
24+
dialect = self._get_dialect()
2125
self._is_postgres = dialect.name in {"postgres", "postgresql"}
22-
self._sanitize_statement = statement_factory(dialect)
26+
self._substitute_markers_with_params = statement_factory(dialect)
2327
self._autocommit = False
2428

29+
def _get_dialect(self):
30+
return self._session.get_bind().dialect
31+
2532
def execute(self, sql, *args, **kwargs):
26-
"""Execute a SQL statement."""
27-
statement = self._sanitize_statement(sql, *args, **kwargs)
28-
if statement.is_transaction_start():
29-
self._autocommit = False
33+
"""Executes a SQL statement.
3034
31-
if self._autocommit:
32-
self._session.execute("BEGIN")
35+
:param sql: a SQL statement, possibly with parameters markers
36+
:type sql: str
37+
:param *args: zero or more positional arguments to substitute the parameter markers with
38+
:param **kwargs: zero or more keyword arguments to substitute the parameter markers with
3339
34-
result = self._execute(statement)
40+
:returns: For ``SELECT``, a :py:class:`list` of :py:class:`dict` objects, each of which
41+
represents a row in the result set; for ``INSERT``, the primary key of a newly inserted row
42+
(or ``None`` if none); for ``UPDATE``, the number of rows updated; for ``DELETE``, the
43+
number of rows deleted; for other statements, ``True``; on integrity errors, a
44+
:py:class:`ValueError` is raised, on other errors, a :py:class:`RuntimeError` is raised
3545
36-
if self._autocommit:
37-
self._session.execute("COMMIT")
46+
"""
3847

39-
if statement.is_transaction_end():
40-
self._autocommit = True
48+
statement = self._substitute_markers_with_params(sql, *args, **kwargs)
49+
if statement.is_transaction_start():
50+
self._disable_autocommit()
51+
52+
self._begin_transaction_in_autocommit_mode()
53+
result = self._execute(statement)
54+
self._commit_transaction_in_autocommit_mode()
4155

4256
if statement.is_select():
43-
ret = fetch_select_result(result)
57+
ret = process_select_result(result)
4458
elif statement.is_insert():
4559
ret = self._last_row_id_or_none(result)
4660
elif statement.is_delete() or statement.is_update():
4761
ret = result.rowcount
4862
else:
4963
ret = True
5064

51-
if self._autocommit:
52-
self._session.remove()
65+
if statement.is_transaction_end():
66+
self._enable_autocommit()
5367

68+
self._shutdown_session_in_autocommit_mode()
5469
return ret
5570

71+
def _disable_autocommit(self):
72+
self._autocommit = False
73+
74+
def _begin_transaction_in_autocommit_mode(self):
75+
if self._autocommit:
76+
self._session.execute("BEGIN")
77+
5678
def _execute(self, statement):
57-
with raise_errors_for_warnings():
58-
try:
79+
"""
80+
:param statement: a SQL statement represented as a ``str`` or a
81+
:class:`_statement.Statement`
82+
83+
:rtype: :class:`sqlalchemy.engine.Result`
84+
"""
85+
try:
86+
with raise_errors_for_warnings():
5987
result = self._session.execute(statement)
60-
except sqlalchemy.exc.IntegrityError as exc:
61-
_logger.debug(termcolor.colored(str(statement), "yellow"))
62-
if self._autocommit:
63-
self._session.remove()
64-
raise ValueError(exc.orig) from None
65-
except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as exc:
66-
self._session.remove()
67-
_logger.debug(termcolor.colored(statement, "red"))
68-
raise RuntimeError(exc.orig) from None
69-
70-
_logger.debug(termcolor.colored(str(statement), "green"))
71-
return result
88+
# E.g., failed constraint
89+
except sqlalchemy.exc.IntegrityError as exc:
90+
_logger.debug(yellow(statement))
91+
self._shutdown_session_in_autocommit_mode()
92+
raise ValueError(exc.orig) from None
93+
# E.g., connection error or syntax error
94+
except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as exc:
95+
self._shutdown_session()
96+
_logger.debug(red(statement))
97+
raise RuntimeError(exc.orig) from None
98+
99+
_logger.debug(green(statement))
100+
return result
101+
102+
def _shutdown_session_in_autocommit_mode(self):
103+
if self._autocommit:
104+
self._shutdown_session()
105+
106+
def _shutdown_session(self):
107+
self._session.remove()
108+
109+
def _commit_transaction_in_autocommit_mode(self):
110+
if self._autocommit:
111+
self._session.execute("COMMIT")
112+
113+
def _enable_autocommit(self):
114+
self._autocommit = True
72115

73116
def _last_row_id_or_none(self, result):
117+
"""
118+
:param result: A SQLAlchemy result object
119+
:type result: :class:`sqlalchemy.engine.Result`
120+
121+
:returns: The ID of the last inserted row or ``None``
122+
"""
123+
74124
if self._is_postgres:
75-
return self._get_last_val()
125+
return self._postgres_lastval()
76126
return result.lastrowid if result.rowcount == 1 else None
77127

78-
def _get_last_val(self):
128+
def _postgres_lastval(self):
79129
try:
80130
return self._session.execute("SELECT LASTVAL()").first()[0]
81131
except sqlalchemy.exc.OperationalError: # If lastval is not yet defined in this session
82132
return None
83133

84134
def init_app(self, app):
85-
"""Registers a teardown_appcontext listener to remove session and enables logging"""
135+
"""Enables logging and registers a ``teardown_appcontext`` listener to remove the session.
136+
137+
:param app: a Flask application instance
138+
:type app: :class:`flask.Flask`
139+
"""
140+
86141
@app.teardown_appcontext
87142
def _(_):
88-
self._session.remove()
143+
self._shutdown_session()
89144

90145
logging.getLogger("cs50").disabled = False

0 commit comments

Comments
 (0)
Please sign in to comment.