Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 1b644dd

Browse files
committedDec 17, 2023
updated IO wrapper, style, version
1 parent 62ad33c commit 1b644dd

File tree

5 files changed

+184
-77
lines changed

5 files changed

+184
-77
lines changed
 

‎setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,5 @@
1818
package_dir={"": "src"},
1919
packages=["cs50"],
2020
url="https://github.com/cs50/python-cs50",
21-
version="9.3.0"
21+
version="9.3.1"
2222
)

‎src/cs50/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
# Import cs50_*
1010
from .cs50 import get_char, get_float, get_int, get_string
11+
1112
try:
1213
from .cs50 import get_long
1314
except ImportError:

‎src/cs50/cs50.py

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717

1818
try:
1919
# Patch formatException
20-
logging.root.handlers[0].formatter.formatException = lambda exc_info: _formatException(*exc_info)
20+
logging.root.handlers[
21+
0
22+
].formatter.formatException = lambda exc_info: _formatException(*exc_info)
2123
except IndexError:
2224
pass
2325

@@ -37,26 +39,31 @@
3739
_logger.addHandler(handler)
3840

3941

40-
class _flushfile():
42+
class _Unbuffered:
4143
"""
4244
Disable buffering for standard output and standard error.
4345
44-
http://stackoverflow.com/a/231216
46+
https://stackoverflow.com/a/107717
47+
https://docs.python.org/3/library/io.html
4548
"""
4649

47-
def __init__(self, f):
48-
self.f = f
50+
def __init__(self, stream):
51+
self.stream = stream
4952

50-
def __getattr__(self, name):
51-
return getattr(self.f, name)
53+
def __getattr__(self, attr):
54+
return getattr(self.stream, attr)
5255

53-
def write(self, x):
54-
self.f.write(x)
55-
self.f.flush()
56+
def write(self, b):
57+
self.stream.write(b)
58+
self.stream.flush()
5659

60+
def writelines(self, lines):
61+
self.stream.writelines(lines)
62+
self.stream.flush()
5763

58-
sys.stderr = _flushfile(sys.stderr)
59-
sys.stdout = _flushfile(sys.stdout)
64+
65+
sys.stderr = _Unbuffered(sys.stderr)
66+
sys.stdout = _Unbuffered(sys.stdout)
6067

6168

6269
def _formatException(type, value, tb):
@@ -78,19 +85,29 @@ def _formatException(type, value, tb):
7885
lines += line
7986
else:
8087
matches = re.search(r"^(\s*)(.*?)(\s*)$", line, re.DOTALL)
81-
lines.append(matches.group(1) + colored(matches.group(2), "yellow") + matches.group(3))
88+
lines.append(
89+
matches.group(1)
90+
+ colored(matches.group(2), "yellow")
91+
+ matches.group(3)
92+
)
8293
return "".join(lines).rstrip()
8394

8495

85-
sys.excepthook = lambda type, value, tb: print(_formatException(type, value, tb), file=sys.stderr)
96+
sys.excepthook = lambda type, value, tb: print(
97+
_formatException(type, value, tb), file=sys.stderr
98+
)
8699

87100

88101
def eprint(*args, **kwargs):
89-
raise RuntimeError("The CS50 Library for Python no longer supports eprint, but you can use print instead!")
102+
raise RuntimeError(
103+
"The CS50 Library for Python no longer supports eprint, but you can use print instead!"
104+
)
90105

91106

92107
def get_char(prompt):
93-
raise RuntimeError("The CS50 Library for Python no longer supports get_char, but you can use get_string instead!")
108+
raise RuntimeError(
109+
"The CS50 Library for Python no longer supports get_char, but you can use get_string instead!"
110+
)
94111

95112

96113
def get_float(prompt):

‎src/cs50/flask.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pkgutil
33
import sys
44

5+
56
def _wrap_flask(f):
67
if f is None:
78
return
@@ -17,10 +18,15 @@ def _wrap_flask(f):
1718

1819
if os.getenv("CS50_IDE_TYPE") == "online":
1920
from werkzeug.middleware.proxy_fix import ProxyFix
21+
2022
_flask_init_before = f.Flask.__init__
23+
2124
def _flask_init_after(self, *args, **kwargs):
2225
_flask_init_before(self, *args, **kwargs)
23-
self.wsgi_app = ProxyFix(self.wsgi_app, x_proto=1) # For HTTPS-to-HTTP proxy
26+
self.wsgi_app = ProxyFix(
27+
self.wsgi_app, x_proto=1
28+
) # For HTTPS-to-HTTP proxy
29+
2430
f.Flask.__init__ = _flask_init_after
2531

2632

@@ -30,7 +36,7 @@ def _flask_init_after(self, *args, **kwargs):
3036

3137
# If Flask wasn't imported
3238
else:
33-
flask_loader = pkgutil.get_loader('flask')
39+
flask_loader = pkgutil.get_loader("flask")
3440
if flask_loader:
3541
_exec_module_before = flask_loader.exec_module
3642

‎src/cs50/sql.py

Lines changed: 141 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ def _enable_logging(f):
1414

1515
@functools.wraps(f)
1616
def decorator(*args, **kwargs):
17-
1817
# Infer whether Flask is installed
1918
try:
2019
import flask
@@ -71,17 +70,20 @@ def __init__(self, url, **kwargs):
7170
# Create engine, disabling SQLAlchemy's own autocommit mode raising exception if back end's module not installed;
7271
# without isolation_level, PostgreSQL warns with "there is already a transaction in progress" for our own BEGIN and
7372
# "there is no transaction in progress" for our own COMMIT
74-
self._engine = sqlalchemy.create_engine(url, **kwargs).execution_options(autocommit=False, isolation_level="AUTOCOMMIT")
73+
self._engine = sqlalchemy.create_engine(url, **kwargs).execution_options(
74+
autocommit=False, isolation_level="AUTOCOMMIT"
75+
)
7576

7677
# Get logger
7778
self._logger = logging.getLogger("cs50")
7879

7980
# Listener for connections
8081
def connect(dbapi_connection, connection_record):
81-
8282
# Enable foreign key constraints
8383
try:
84-
if isinstance(dbapi_connection, sqlite3.Connection): # If back end is sqlite
84+
if isinstance(
85+
dbapi_connection, sqlite3.Connection
86+
): # If back end is sqlite
8587
cursor = dbapi_connection.cursor()
8688
cursor.execute("PRAGMA foreign_keys=ON")
8789
cursor.close()
@@ -150,14 +152,33 @@ def execute(self, sql, *args, **kwargs):
150152
raise RuntimeError("cannot pass both positional and named parameters")
151153

152154
# Infer command from flattened statement to a single string separated by spaces
153-
full_statement = ' '.join(str(token) for token in statements[0].tokens if token.ttype in [sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML])
155+
full_statement = " ".join(
156+
str(token)
157+
for token in statements[0].tokens
158+
if token.ttype
159+
in [
160+
sqlparse.tokens.Keyword,
161+
sqlparse.tokens.Keyword.DDL,
162+
sqlparse.tokens.Keyword.DML,
163+
]
164+
)
154165
full_statement = full_statement.upper()
155166

156167
# Set of possible commands
157-
commands = {"BEGIN", "CREATE VIEW", "DELETE", "INSERT", "SELECT", "START", "UPDATE"}
168+
commands = {
169+
"BEGIN",
170+
"CREATE VIEW",
171+
"DELETE",
172+
"INSERT",
173+
"SELECT",
174+
"START",
175+
"UPDATE",
176+
}
158177

159178
# Check if the full_statement starts with any command
160-
command = next((cmd for cmd in commands if full_statement.startswith(cmd)), None)
179+
command = next(
180+
(cmd for cmd in commands if full_statement.startswith(cmd)), None
181+
)
161182

162183
# Flatten statement
163184
tokens = list(statements[0].flatten())
@@ -166,10 +187,8 @@ def execute(self, sql, *args, **kwargs):
166187
placeholders = {}
167188
paramstyle = None
168189
for index, token in enumerate(tokens):
169-
170190
# If token is a placeholder
171191
if token.ttype == sqlparse.tokens.Name.Placeholder:
172-
173192
# Determine paramstyle, name
174193
_paramstyle, name = _parse_placeholder(token)
175194

@@ -186,7 +205,6 @@ def execute(self, sql, *args, **kwargs):
186205

187206
# If no placeholders
188207
if not paramstyle:
189-
190208
# Error-check like qmark if args
191209
if args:
192210
paramstyle = "qmark"
@@ -201,41 +219,55 @@ def execute(self, sql, *args, **kwargs):
201219

202220
# qmark
203221
if paramstyle == "qmark":
204-
205222
# Validate number of placeholders
206223
if len(placeholders) != len(args):
207224
if len(placeholders) < len(args):
208-
raise RuntimeError("fewer placeholders ({}) than values ({})".format(_placeholders, _args))
225+
raise RuntimeError(
226+
"fewer placeholders ({}) than values ({})".format(
227+
_placeholders, _args
228+
)
229+
)
209230
else:
210-
raise RuntimeError("more placeholders ({}) than values ({})".format(_placeholders, _args))
231+
raise RuntimeError(
232+
"more placeholders ({}) than values ({})".format(
233+
_placeholders, _args
234+
)
235+
)
211236

212237
# Escape values
213238
for i, index in enumerate(placeholders.keys()):
214239
tokens[index] = self._escape(args[i])
215240

216241
# numeric
217242
elif paramstyle == "numeric":
218-
219243
# Escape values
220244
for index, i in placeholders.items():
221245
if i >= len(args):
222-
raise RuntimeError("missing value for placeholder (:{})".format(i + 1, len(args)))
246+
raise RuntimeError(
247+
"missing value for placeholder (:{})".format(i + 1, len(args))
248+
)
223249
tokens[index] = self._escape(args[i])
224250

225251
# Check if any values unused
226252
indices = set(range(len(args))) - set(placeholders.values())
227253
if indices:
228-
raise RuntimeError("unused {} ({})".format(
229-
"value" if len(indices) == 1 else "values",
230-
", ".join([str(self._escape(args[index])) for index in indices])))
254+
raise RuntimeError(
255+
"unused {} ({})".format(
256+
"value" if len(indices) == 1 else "values",
257+
", ".join(
258+
[str(self._escape(args[index])) for index in indices]
259+
),
260+
)
261+
)
231262

232263
# named
233264
elif paramstyle == "named":
234-
235265
# Escape values
236266
for index, name in placeholders.items():
237267
if name not in kwargs:
238-
raise RuntimeError("missing value for placeholder (:{})".format(name))
268+
raise RuntimeError(
269+
"missing value for placeholder (:{})".format(name)
270+
)
239271
tokens[index] = self._escape(kwargs[name])
240272

241273
# Check if any keys unused
@@ -245,54 +277,65 @@ def execute(self, sql, *args, **kwargs):
245277

246278
# format
247279
elif paramstyle == "format":
248-
249280
# Validate number of placeholders
250281
if len(placeholders) != len(args):
251282
if len(placeholders) < len(args):
252-
raise RuntimeError("fewer placeholders ({}) than values ({})".format(_placeholders, _args))
283+
raise RuntimeError(
284+
"fewer placeholders ({}) than values ({})".format(
285+
_placeholders, _args
286+
)
287+
)
253288
else:
254-
raise RuntimeError("more placeholders ({}) than values ({})".format(_placeholders, _args))
289+
raise RuntimeError(
290+
"more placeholders ({}) than values ({})".format(
291+
_placeholders, _args
292+
)
293+
)
255294

256295
# Escape values
257296
for i, index in enumerate(placeholders.keys()):
258297
tokens[index] = self._escape(args[i])
259298

260299
# pyformat
261300
elif paramstyle == "pyformat":
262-
263301
# Escape values
264302
for index, name in placeholders.items():
265303
if name not in kwargs:
266-
raise RuntimeError("missing value for placeholder (%{}s)".format(name))
304+
raise RuntimeError(
305+
"missing value for placeholder (%{}s)".format(name)
306+
)
267307
tokens[index] = self._escape(kwargs[name])
268308

269309
# Check if any keys unused
270310
keys = kwargs.keys() - placeholders.values()
271311
if keys:
272-
raise RuntimeError("unused {} ({})".format(
273-
"value" if len(keys) == 1 else "values",
274-
", ".join(keys)))
312+
raise RuntimeError(
313+
"unused {} ({})".format(
314+
"value" if len(keys) == 1 else "values", ", ".join(keys)
315+
)
316+
)
275317

276318
# For SQL statements where a colon is required verbatim, as within an inline string, use a backslash to escape
277319
# https://docs.sqlalchemy.org/en/13/core/sqlelement.html?highlight=text#sqlalchemy.sql.expression.text
278320
for index, token in enumerate(tokens):
279-
280321
# In string literal
281322
# https://www.sqlite.org/lang_keywords.html
282-
if token.ttype in [sqlparse.tokens.Literal.String, sqlparse.tokens.Literal.String.Single]:
323+
if token.ttype in [
324+
sqlparse.tokens.Literal.String,
325+
sqlparse.tokens.Literal.String.Single,
326+
]:
283327
token.value = re.sub("(^'|\s+):", r"\1\:", token.value)
284328

285329
# In identifier
286330
# https://www.sqlite.org/lang_keywords.html
287331
elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
288-
token.value = re.sub("(^\"|\s+):", r"\1\:", token.value)
332+
token.value = re.sub('(^"|\s+):', r"\1\:", token.value)
289333

290334
# Join tokens into statement
291335
statement = "".join([str(token) for token in tokens])
292336

293337
# If no connection yet
294338
if not hasattr(_data, self._name()):
295-
296339
# Connect to database
297340
setattr(_data, self._name(), self._engine.connect())
298341

@@ -302,25 +345,33 @@ def execute(self, sql, *args, **kwargs):
302345
# Disconnect if/when a Flask app is torn down
303346
try:
304347
import flask
348+
305349
assert flask.current_app
350+
306351
def teardown_appcontext(exception):
307352
self._disconnect()
353+
308354
if teardown_appcontext not in flask.current_app.teardown_appcontext_funcs:
309355
flask.current_app.teardown_appcontext(teardown_appcontext)
310356
except (ModuleNotFoundError, AssertionError):
311357
pass
312358

313359
# Catch SQLAlchemy warnings
314360
with warnings.catch_warnings():
315-
316361
# Raise exceptions for warnings
317362
warnings.simplefilter("error")
318363

319364
# Prepare, execute statement
320365
try:
321-
322366
# Join tokens into statement, abbreviating binary data as <class 'bytes'>
323-
_statement = "".join([str(bytes) if token.ttype == sqlparse.tokens.Other else str(token) for token in tokens])
367+
_statement = "".join(
368+
[
369+
str(bytes)
370+
if token.ttype == sqlparse.tokens.Other
371+
else str(token)
372+
for token in tokens
373+
]
374+
)
324375

325376
# Check for start of transaction
326377
if command in ["BEGIN", "START"]:
@@ -342,12 +393,10 @@ def teardown_appcontext(exception):
342393

343394
# If SELECT, return result set as list of dict objects
344395
if command == "SELECT":
345-
346396
# Coerce types
347397
rows = [dict(row) for row in result.mappings().all()]
348398
for row in rows:
349399
for column in row:
350-
351400
# Coerce decimal.Decimal objects to float objects
352401
# https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ
353402
if isinstance(row[column], decimal.Decimal):
@@ -362,15 +411,15 @@ def teardown_appcontext(exception):
362411

363412
# If INSERT, return primary key value for a newly inserted row (or None if none)
364413
elif command == "INSERT":
365-
366414
# If PostgreSQL
367415
if self._engine.url.get_backend_name() == "postgresql":
368-
369416
# Return LASTVAL() or NULL, avoiding
370417
# "(psycopg2.errors.ObjectNotInPrerequisiteState) lastval is not yet defined in this session",
371418
# a la https://stackoverflow.com/a/24186770/5156190;
372419
# cf. https://www.psycopg.org/docs/errors.html re 55000
373-
result = connection.execute(sqlalchemy.text("""
420+
result = connection.execute(
421+
sqlalchemy.text(
422+
"""
374423
CREATE OR REPLACE FUNCTION _LASTVAL()
375424
RETURNS integer LANGUAGE plpgsql
376425
AS $$
@@ -382,7 +431,9 @@ def teardown_appcontext(exception):
382431
END;
383432
END $$;
384433
SELECT _LASTVAL();
385-
"""))
434+
"""
435+
)
436+
)
386437
ret = result.first()[0]
387438

388439
# If not PostgreSQL
@@ -405,7 +456,10 @@ def teardown_appcontext(exception):
405456
raise e
406457

407458
# If user error
408-
except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as e:
459+
except (
460+
sqlalchemy.exc.OperationalError,
461+
sqlalchemy.exc.ProgrammingError,
462+
) as e:
409463
self._disconnect()
410464
self._logger.error(termcolor.colored(_statement, "red"))
411465
e = RuntimeError(e.orig)
@@ -430,7 +484,6 @@ def _escape(self, value):
430484
import sqlparse
431485

432486
def __escape(value):
433-
434487
# Lazily import
435488
import datetime
436489
import sqlalchemy
@@ -439,66 +492,91 @@ def __escape(value):
439492
if isinstance(value, bool):
440493
return sqlparse.sql.Token(
441494
sqlparse.tokens.Number,
442-
sqlalchemy.types.Boolean().literal_processor(self._engine.dialect)(value))
495+
sqlalchemy.types.Boolean().literal_processor(self._engine.dialect)(
496+
value
497+
),
498+
)
443499

444500
# bytes
445501
elif isinstance(value, bytes):
446502
if self._engine.url.get_backend_name() in ["mysql", "sqlite"]:
447-
return sqlparse.sql.Token(sqlparse.tokens.Other, f"x'{value.hex()}'") # https://dev.mysql.com/doc/refman/8.0/en/hexadecimal-literals.html
503+
return sqlparse.sql.Token(
504+
sqlparse.tokens.Other, f"x'{value.hex()}'"
505+
) # https://dev.mysql.com/doc/refman/8.0/en/hexadecimal-literals.html
448506
elif self._engine.url.get_backend_name() == "postgresql":
449-
return sqlparse.sql.Token(sqlparse.tokens.Other, f"'\\x{value.hex()}'") # https://dba.stackexchange.com/a/203359
507+
return sqlparse.sql.Token(
508+
sqlparse.tokens.Other, f"'\\x{value.hex()}'"
509+
) # https://dba.stackexchange.com/a/203359
450510
else:
451511
raise RuntimeError("unsupported value: {}".format(value))
452512

453513
# datetime.datetime
454514
elif isinstance(value, datetime.datetime):
455515
return sqlparse.sql.Token(
456516
sqlparse.tokens.String,
457-
sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%Y-%m-%d %H:%M:%S")))
517+
sqlalchemy.types.String().literal_processor(self._engine.dialect)(
518+
value.strftime("%Y-%m-%d %H:%M:%S")
519+
),
520+
)
458521

459522
# datetime.date
460523
elif isinstance(value, datetime.date):
461524
return sqlparse.sql.Token(
462525
sqlparse.tokens.String,
463-
sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%Y-%m-%d")))
526+
sqlalchemy.types.String().literal_processor(self._engine.dialect)(
527+
value.strftime("%Y-%m-%d")
528+
),
529+
)
464530

465531
# datetime.time
466532
elif isinstance(value, datetime.time):
467533
return sqlparse.sql.Token(
468534
sqlparse.tokens.String,
469-
sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%H:%M:%S")))
535+
sqlalchemy.types.String().literal_processor(self._engine.dialect)(
536+
value.strftime("%H:%M:%S")
537+
),
538+
)
470539

471540
# float
472541
elif isinstance(value, float):
473542
return sqlparse.sql.Token(
474543
sqlparse.tokens.Number,
475-
sqlalchemy.types.Float().literal_processor(self._engine.dialect)(value))
544+
sqlalchemy.types.Float().literal_processor(self._engine.dialect)(
545+
value
546+
),
547+
)
476548

477549
# int
478550
elif isinstance(value, int):
479551
return sqlparse.sql.Token(
480552
sqlparse.tokens.Number,
481-
sqlalchemy.types.Integer().literal_processor(self._engine.dialect)(value))
553+
sqlalchemy.types.Integer().literal_processor(self._engine.dialect)(
554+
value
555+
),
556+
)
482557

483558
# str
484559
elif isinstance(value, str):
485560
return sqlparse.sql.Token(
486561
sqlparse.tokens.String,
487-
sqlalchemy.types.String().literal_processor(self._engine.dialect)(value))
562+
sqlalchemy.types.String().literal_processor(self._engine.dialect)(
563+
value
564+
),
565+
)
488566

489567
# None
490568
elif value is None:
491-
return sqlparse.sql.Token(
492-
sqlparse.tokens.Keyword,
493-
sqlalchemy.null())
569+
return sqlparse.sql.Token(sqlparse.tokens.Keyword, sqlalchemy.null())
494570

495571
# Unsupported value
496572
else:
497573
raise RuntimeError("unsupported value: {}".format(value))
498574

499575
# Escape value(s), separating with commas as needed
500576
if isinstance(value, (list, tuple)):
501-
return sqlparse.sql.TokenList(sqlparse.parse(", ".join([str(__escape(v)) for v in value])))
577+
return sqlparse.sql.TokenList(
578+
sqlparse.parse(", ".join([str(__escape(v)) for v in value]))
579+
)
502580
else:
503581
return __escape(value)
504582

@@ -510,7 +588,9 @@ def _parse_exception(e):
510588
import re
511589

512590
# MySQL
513-
matches = re.search(r"^\(_mysql_exceptions\.OperationalError\) \(\d+, \"(.+)\"\)$", str(e))
591+
matches = re.search(
592+
r"^\(_mysql_exceptions\.OperationalError\) \(\d+, \"(.+)\"\)$", str(e)
593+
)
514594
if matches:
515595
return matches.group(1)
516596

@@ -536,7 +616,10 @@ def _parse_placeholder(token):
536616
import sqlparse
537617

538618
# Validate token
539-
if not isinstance(token, sqlparse.sql.Token) or token.ttype != sqlparse.tokens.Name.Placeholder:
619+
if (
620+
not isinstance(token, sqlparse.sql.Token)
621+
or token.ttype != sqlparse.tokens.Name.Placeholder
622+
):
540623
raise TypeError()
541624

542625
# qmark

0 commit comments

Comments
 (0)
Please sign in to comment.