Skip to content

Commit c73d1d2

Browse files
authoredDec 17, 2023
Merge pull request #180 from cs50/updates
updated IO wrapper, style, version
2 parents 781c1c2 + d981368 commit c73d1d2

File tree

5 files changed

+184
-77
lines changed

5 files changed

+184
-77
lines changed
 

‎setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,5 @@
1818
package_dir={"": "src"},
1919
packages=["cs50"],
2020
url="https://github.com/cs50/python-cs50",
21-
version="9.3.1"
21+
version="9.3.2"
2222
)

‎src/cs50/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
# Import cs50_*
1010
from .cs50 import get_char, get_float, get_int, get_string
11+
1112
try:
1213
from .cs50 import get_long
1314
except ImportError:

‎src/cs50/cs50.py

+33-16
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717

1818
try:
1919
# Patch formatException
20-
logging.root.handlers[0].formatter.formatException = lambda exc_info: _formatException(*exc_info)
20+
logging.root.handlers[
21+
0
22+
].formatter.formatException = lambda exc_info: _formatException(*exc_info)
2123
except IndexError:
2224
pass
2325

@@ -37,26 +39,31 @@
3739
_logger.addHandler(handler)
3840

3941

40-
class _flushfile():
42+
class _Unbuffered:
4143
"""
4244
Disable buffering for standard output and standard error.
4345
44-
http://stackoverflow.com/a/231216
46+
https://stackoverflow.com/a/107717
47+
https://docs.python.org/3/library/io.html
4548
"""
4649

47-
def __init__(self, f):
48-
self.f = f
50+
def __init__(self, stream):
51+
self.stream = stream
4952

50-
def __getattr__(self, name):
51-
return getattr(self.f, name)
53+
def __getattr__(self, attr):
54+
return getattr(self.stream, attr)
5255

53-
def write(self, x):
54-
self.f.write(x)
55-
self.f.flush()
56+
def write(self, b):
57+
self.stream.write(b)
58+
self.stream.flush()
5659

60+
def writelines(self, lines):
61+
self.stream.writelines(lines)
62+
self.stream.flush()
5763

58-
sys.stderr = _flushfile(sys.stderr)
59-
sys.stdout = _flushfile(sys.stdout)
64+
65+
sys.stderr = _Unbuffered(sys.stderr)
66+
sys.stdout = _Unbuffered(sys.stdout)
6067

6168

6269
def _formatException(type, value, tb):
@@ -78,19 +85,29 @@ def _formatException(type, value, tb):
7885
lines += line
7986
else:
8087
matches = re.search(r"^(\s*)(.*?)(\s*)$", line, re.DOTALL)
81-
lines.append(matches.group(1) + colored(matches.group(2), "yellow") + matches.group(3))
88+
lines.append(
89+
matches.group(1)
90+
+ colored(matches.group(2), "yellow")
91+
+ matches.group(3)
92+
)
8293
return "".join(lines).rstrip()
8394

8495

85-
sys.excepthook = lambda type, value, tb: print(_formatException(type, value, tb), file=sys.stderr)
96+
sys.excepthook = lambda type, value, tb: print(
97+
_formatException(type, value, tb), file=sys.stderr
98+
)
8699

87100

88101
def eprint(*args, **kwargs):
89-
raise RuntimeError("The CS50 Library for Python no longer supports eprint, but you can use print instead!")
102+
raise RuntimeError(
103+
"The CS50 Library for Python no longer supports eprint, but you can use print instead!"
104+
)
90105

91106

92107
def get_char(prompt):
93-
raise RuntimeError("The CS50 Library for Python no longer supports get_char, but you can use get_string instead!")
108+
raise RuntimeError(
109+
"The CS50 Library for Python no longer supports get_char, but you can use get_string instead!"
110+
)
94111

95112

96113
def get_float(prompt):

‎src/cs50/flask.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pkgutil
33
import sys
44

5+
56
def _wrap_flask(f):
67
if f is None:
78
return
@@ -17,10 +18,15 @@ def _wrap_flask(f):
1718

1819
if os.getenv("CS50_IDE_TYPE") == "online":
1920
from werkzeug.middleware.proxy_fix import ProxyFix
21+
2022
_flask_init_before = f.Flask.__init__
23+
2124
def _flask_init_after(self, *args, **kwargs):
2225
_flask_init_before(self, *args, **kwargs)
23-
self.wsgi_app = ProxyFix(self.wsgi_app, x_proto=1) # For HTTPS-to-HTTP proxy
26+
self.wsgi_app = ProxyFix(
27+
self.wsgi_app, x_proto=1
28+
) # For HTTPS-to-HTTP proxy
29+
2430
f.Flask.__init__ = _flask_init_after
2531

2632

@@ -30,7 +36,7 @@ def _flask_init_after(self, *args, **kwargs):
3036

3137
# If Flask wasn't imported
3238
else:
33-
flask_loader = pkgutil.get_loader('flask')
39+
flask_loader = pkgutil.get_loader("flask")
3440
if flask_loader:
3541
_exec_module_before = flask_loader.exec_module
3642

‎src/cs50/sql.py

+141-58
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ def _enable_logging(f):
1414

1515
@functools.wraps(f)
1616
def decorator(*args, **kwargs):
17-
1817
# Infer whether Flask is installed
1918
try:
2019
import flask
@@ -71,17 +70,20 @@ def __init__(self, url, **kwargs):
7170
# Create engine, disabling SQLAlchemy's own autocommit mode raising exception if back end's module not installed;
7271
# without isolation_level, PostgreSQL warns with "there is already a transaction in progress" for our own BEGIN and
7372
# "there is no transaction in progress" for our own COMMIT
74-
self._engine = sqlalchemy.create_engine(url, **kwargs).execution_options(autocommit=False, isolation_level="AUTOCOMMIT")
73+
self._engine = sqlalchemy.create_engine(url, **kwargs).execution_options(
74+
autocommit=False, isolation_level="AUTOCOMMIT"
75+
)
7576

7677
# Get logger
7778
self._logger = logging.getLogger("cs50")
7879

7980
# Listener for connections
8081
def connect(dbapi_connection, connection_record):
81-
8282
# Enable foreign key constraints
8383
try:
84-
if isinstance(dbapi_connection, sqlite3.Connection): # If back end is sqlite
84+
if isinstance(
85+
dbapi_connection, sqlite3.Connection
86+
): # If back end is sqlite
8587
cursor = dbapi_connection.cursor()
8688
cursor.execute("PRAGMA foreign_keys=ON")
8789
cursor.close()
@@ -150,14 +152,33 @@ def execute(self, sql, *args, **kwargs):
150152
raise RuntimeError("cannot pass both positional and named parameters")
151153

152154
# Infer command from flattened statement to a single string separated by spaces
153-
full_statement = ' '.join(str(token) for token in statements[0].tokens if token.ttype in [sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML])
155+
full_statement = " ".join(
156+
str(token)
157+
for token in statements[0].tokens
158+
if token.ttype
159+
in [
160+
sqlparse.tokens.Keyword,
161+
sqlparse.tokens.Keyword.DDL,
162+
sqlparse.tokens.Keyword.DML,
163+
]
164+
)
154165
full_statement = full_statement.upper()
155166

156167
# Set of possible commands
157-
commands = {"BEGIN", "CREATE VIEW", "DELETE", "INSERT", "SELECT", "START", "UPDATE"}
168+
commands = {
169+
"BEGIN",
170+
"CREATE VIEW",
171+
"DELETE",
172+
"INSERT",
173+
"SELECT",
174+
"START",
175+
"UPDATE",
176+
}
158177

159178
# Check if the full_statement starts with any command
160-
command = next((cmd for cmd in commands if full_statement.startswith(cmd)), None)
179+
command = next(
180+
(cmd for cmd in commands if full_statement.startswith(cmd)), None
181+
)
161182

162183
# Flatten statement
163184
tokens = list(statements[0].flatten())
@@ -166,10 +187,8 @@ def execute(self, sql, *args, **kwargs):
166187
placeholders = {}
167188
paramstyle = None
168189
for index, token in enumerate(tokens):
169-
170190
# If token is a placeholder
171191
if token.ttype == sqlparse.tokens.Name.Placeholder:
172-
173192
# Determine paramstyle, name
174193
_paramstyle, name = _parse_placeholder(token)
175194

@@ -186,7 +205,6 @@ def execute(self, sql, *args, **kwargs):
186205

187206
# If no placeholders
188207
if not paramstyle:
189-
190208
# Error-check like qmark if args
191209
if args:
192210
paramstyle = "qmark"
@@ -201,41 +219,55 @@ def execute(self, sql, *args, **kwargs):
201219

202220
# qmark
203221
if paramstyle == "qmark":
204-
205222
# Validate number of placeholders
206223
if len(placeholders) != len(args):
207224
if len(placeholders) < len(args):
208-
raise RuntimeError("fewer placeholders ({}) than values ({})".format(_placeholders, _args))
225+
raise RuntimeError(
226+
"fewer placeholders ({}) than values ({})".format(
227+
_placeholders, _args
228+
)
229+
)
209230
else:
210-
raise RuntimeError("more placeholders ({}) than values ({})".format(_placeholders, _args))
231+
raise RuntimeError(
232+
"more placeholders ({}) than values ({})".format(
233+
_placeholders, _args
234+
)
235+
)
211236

212237
# Escape values
213238
for i, index in enumerate(placeholders.keys()):
214239
tokens[index] = self._escape(args[i])
215240

216241
# numeric
217242
elif paramstyle == "numeric":
218-
219243
# Escape values
220244
for index, i in placeholders.items():
221245
if i >= len(args):
222-
raise RuntimeError("missing value for placeholder (:{})".format(i + 1, len(args)))
246+
raise RuntimeError(
247+
"missing value for placeholder (:{})".format(i + 1, len(args))
248+
)
223249
tokens[index] = self._escape(args[i])
224250

225251
# Check if any values unused
226252
indices = set(range(len(args))) - set(placeholders.values())
227253
if indices:
228-
raise RuntimeError("unused {} ({})".format(
229-
"value" if len(indices) == 1 else "values",
230-
", ".join([str(self._escape(args[index])) for index in indices])))
254+
raise RuntimeError(
255+
"unused {} ({})".format(
256+
"value" if len(indices) == 1 else "values",
257+
", ".join(
258+
[str(self._escape(args[index])) for index in indices]
259+
),
260+
)
261+
)
231262

232263
# named
233264
elif paramstyle == "named":
234-
235265
# Escape values
236266
for index, name in placeholders.items():
237267
if name not in kwargs:
238-
raise RuntimeError("missing value for placeholder (:{})".format(name))
268+
raise RuntimeError(
269+
"missing value for placeholder (:{})".format(name)
270+
)
239271
tokens[index] = self._escape(kwargs[name])
240272

241273
# Check if any keys unused
@@ -245,54 +277,65 @@ def execute(self, sql, *args, **kwargs):
245277

246278
# format
247279
elif paramstyle == "format":
248-
249280
# Validate number of placeholders
250281
if len(placeholders) != len(args):
251282
if len(placeholders) < len(args):
252-
raise RuntimeError("fewer placeholders ({}) than values ({})".format(_placeholders, _args))
283+
raise RuntimeError(
284+
"fewer placeholders ({}) than values ({})".format(
285+
_placeholders, _args
286+
)
287+
)
253288
else:
254-
raise RuntimeError("more placeholders ({}) than values ({})".format(_placeholders, _args))
289+
raise RuntimeError(
290+
"more placeholders ({}) than values ({})".format(
291+
_placeholders, _args
292+
)
293+
)
255294

256295
# Escape values
257296
for i, index in enumerate(placeholders.keys()):
258297
tokens[index] = self._escape(args[i])
259298

260299
# pyformat
261300
elif paramstyle == "pyformat":
262-
263301
# Escape values
264302
for index, name in placeholders.items():
265303
if name not in kwargs:
266-
raise RuntimeError("missing value for placeholder (%{}s)".format(name))
304+
raise RuntimeError(
305+
"missing value for placeholder (%{}s)".format(name)
306+
)
267307
tokens[index] = self._escape(kwargs[name])
268308

269309
# Check if any keys unused
270310
keys = kwargs.keys() - placeholders.values()
271311
if keys:
272-
raise RuntimeError("unused {} ({})".format(
273-
"value" if len(keys) == 1 else "values",
274-
", ".join(keys)))
312+
raise RuntimeError(
313+
"unused {} ({})".format(
314+
"value" if len(keys) == 1 else "values", ", ".join(keys)
315+
)
316+
)
275317

276318
# For SQL statements where a colon is required verbatim, as within an inline string, use a backslash to escape
277319
# https://docs.sqlalchemy.org/en/13/core/sqlelement.html?highlight=text#sqlalchemy.sql.expression.text
278320
for index, token in enumerate(tokens):
279-
280321
# In string literal
281322
# https://www.sqlite.org/lang_keywords.html
282-
if token.ttype in [sqlparse.tokens.Literal.String, sqlparse.tokens.Literal.String.Single]:
323+
if token.ttype in [
324+
sqlparse.tokens.Literal.String,
325+
sqlparse.tokens.Literal.String.Single,
326+
]:
283327
token.value = re.sub("(^'|\s+):", r"\1\:", token.value)
284328

285329
# In identifier
286330
# https://www.sqlite.org/lang_keywords.html
287331
elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
288-
token.value = re.sub("(^\"|\s+):", r"\1\:", token.value)
332+
token.value = re.sub('(^"|\s+):', r"\1\:", token.value)
289333

290334
# Join tokens into statement
291335
statement = "".join([str(token) for token in tokens])
292336

293337
# If no connection yet
294338
if not hasattr(_data, self._name()):
295-
296339
# Connect to database
297340
setattr(_data, self._name(), self._engine.connect())
298341

@@ -302,25 +345,33 @@ def execute(self, sql, *args, **kwargs):
302345
# Disconnect if/when a Flask app is torn down
303346
try:
304347
import flask
348+
305349
assert flask.current_app
350+
306351
def teardown_appcontext(exception):
307352
self._disconnect()
353+
308354
if teardown_appcontext not in flask.current_app.teardown_appcontext_funcs:
309355
flask.current_app.teardown_appcontext(teardown_appcontext)
310356
except (ModuleNotFoundError, AssertionError):
311357
pass
312358

313359
# Catch SQLAlchemy warnings
314360
with warnings.catch_warnings():
315-
316361
# Raise exceptions for warnings
317362
warnings.simplefilter("error")
318363

319364
# Prepare, execute statement
320365
try:
321-
322366
# Join tokens into statement, abbreviating binary data as <class 'bytes'>
323-
_statement = "".join([str(bytes) if token.ttype == sqlparse.tokens.Other else str(token) for token in tokens])
367+
_statement = "".join(
368+
[
369+
str(bytes)
370+
if token.ttype == sqlparse.tokens.Other
371+
else str(token)
372+
for token in tokens
373+
]
374+
)
324375

325376
# Check for start of transaction
326377
if command in ["BEGIN", "START"]:
@@ -342,12 +393,10 @@ def teardown_appcontext(exception):
342393

343394
# If SELECT, return result set as list of dict objects
344395
if command == "SELECT":
345-
346396
# Coerce types
347397
rows = [dict(row) for row in result.mappings().all()]
348398
for row in rows:
349399
for column in row:
350-
351400
# Coerce decimal.Decimal objects to float objects
352401
# https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ
353402
if isinstance(row[column], decimal.Decimal):
@@ -362,15 +411,15 @@ def teardown_appcontext(exception):
362411

363412
# If INSERT, return primary key value for a newly inserted row (or None if none)
364413
elif command == "INSERT":
365-
366414
# If PostgreSQL
367415
if self._engine.url.get_backend_name() == "postgresql":
368-
369416
# Return LASTVAL() or NULL, avoiding
370417
# "(psycopg2.errors.ObjectNotInPrerequisiteState) lastval is not yet defined in this session",
371418
# a la https://stackoverflow.com/a/24186770/5156190;
372419
# cf. https://www.psycopg.org/docs/errors.html re 55000
373-
result = connection.execute(sqlalchemy.text("""
420+
result = connection.execute(
421+
sqlalchemy.text(
422+
"""
374423
CREATE OR REPLACE FUNCTION _LASTVAL()
375424
RETURNS integer LANGUAGE plpgsql
376425
AS $$
@@ -382,7 +431,9 @@ def teardown_appcontext(exception):
382431
END;
383432
END $$;
384433
SELECT _LASTVAL();
385-
"""))
434+
"""
435+
)
436+
)
386437
ret = result.first()[0]
387438

388439
# If not PostgreSQL
@@ -407,7 +458,10 @@ def teardown_appcontext(exception):
407458
raise e
408459

409460
# If user error
410-
except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as e:
461+
except (
462+
sqlalchemy.exc.OperationalError,
463+
sqlalchemy.exc.ProgrammingError,
464+
) as e:
411465
self._disconnect()
412466
self._logger.error(termcolor.colored(_statement, "red"))
413467
e = RuntimeError(e.orig)
@@ -432,7 +486,6 @@ def _escape(self, value):
432486
import sqlparse
433487

434488
def __escape(value):
435-
436489
# Lazily import
437490
import datetime
438491
import sqlalchemy
@@ -441,66 +494,91 @@ def __escape(value):
441494
if isinstance(value, bool):
442495
return sqlparse.sql.Token(
443496
sqlparse.tokens.Number,
444-
sqlalchemy.types.Boolean().literal_processor(self._engine.dialect)(value))
497+
sqlalchemy.types.Boolean().literal_processor(self._engine.dialect)(
498+
value
499+
),
500+
)
445501

446502
# bytes
447503
elif isinstance(value, bytes):
448504
if self._engine.url.get_backend_name() in ["mysql", "sqlite"]:
449-
return sqlparse.sql.Token(sqlparse.tokens.Other, f"x'{value.hex()}'") # https://dev.mysql.com/doc/refman/8.0/en/hexadecimal-literals.html
505+
return sqlparse.sql.Token(
506+
sqlparse.tokens.Other, f"x'{value.hex()}'"
507+
) # https://dev.mysql.com/doc/refman/8.0/en/hexadecimal-literals.html
450508
elif self._engine.url.get_backend_name() == "postgresql":
451-
return sqlparse.sql.Token(sqlparse.tokens.Other, f"'\\x{value.hex()}'") # https://dba.stackexchange.com/a/203359
509+
return sqlparse.sql.Token(
510+
sqlparse.tokens.Other, f"'\\x{value.hex()}'"
511+
) # https://dba.stackexchange.com/a/203359
452512
else:
453513
raise RuntimeError("unsupported value: {}".format(value))
454514

455515
# datetime.datetime
456516
elif isinstance(value, datetime.datetime):
457517
return sqlparse.sql.Token(
458518
sqlparse.tokens.String,
459-
sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%Y-%m-%d %H:%M:%S")))
519+
sqlalchemy.types.String().literal_processor(self._engine.dialect)(
520+
value.strftime("%Y-%m-%d %H:%M:%S")
521+
),
522+
)
460523

461524
# datetime.date
462525
elif isinstance(value, datetime.date):
463526
return sqlparse.sql.Token(
464527
sqlparse.tokens.String,
465-
sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%Y-%m-%d")))
528+
sqlalchemy.types.String().literal_processor(self._engine.dialect)(
529+
value.strftime("%Y-%m-%d")
530+
),
531+
)
466532

467533
# datetime.time
468534
elif isinstance(value, datetime.time):
469535
return sqlparse.sql.Token(
470536
sqlparse.tokens.String,
471-
sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%H:%M:%S")))
537+
sqlalchemy.types.String().literal_processor(self._engine.dialect)(
538+
value.strftime("%H:%M:%S")
539+
),
540+
)
472541

473542
# float
474543
elif isinstance(value, float):
475544
return sqlparse.sql.Token(
476545
sqlparse.tokens.Number,
477-
sqlalchemy.types.Float().literal_processor(self._engine.dialect)(value))
546+
sqlalchemy.types.Float().literal_processor(self._engine.dialect)(
547+
value
548+
),
549+
)
478550

479551
# int
480552
elif isinstance(value, int):
481553
return sqlparse.sql.Token(
482554
sqlparse.tokens.Number,
483-
sqlalchemy.types.Integer().literal_processor(self._engine.dialect)(value))
555+
sqlalchemy.types.Integer().literal_processor(self._engine.dialect)(
556+
value
557+
),
558+
)
484559

485560
# str
486561
elif isinstance(value, str):
487562
return sqlparse.sql.Token(
488563
sqlparse.tokens.String,
489-
sqlalchemy.types.String().literal_processor(self._engine.dialect)(value))
564+
sqlalchemy.types.String().literal_processor(self._engine.dialect)(
565+
value
566+
),
567+
)
490568

491569
# None
492570
elif value is None:
493-
return sqlparse.sql.Token(
494-
sqlparse.tokens.Keyword,
495-
sqlalchemy.null())
571+
return sqlparse.sql.Token(sqlparse.tokens.Keyword, sqlalchemy.null())
496572

497573
# Unsupported value
498574
else:
499575
raise RuntimeError("unsupported value: {}".format(value))
500576

501577
# Escape value(s), separating with commas as needed
502578
if isinstance(value, (list, tuple)):
503-
return sqlparse.sql.TokenList(sqlparse.parse(", ".join([str(__escape(v)) for v in value])))
579+
return sqlparse.sql.TokenList(
580+
sqlparse.parse(", ".join([str(__escape(v)) for v in value]))
581+
)
504582
else:
505583
return __escape(value)
506584

@@ -512,7 +590,9 @@ def _parse_exception(e):
512590
import re
513591

514592
# MySQL
515-
matches = re.search(r"^\(_mysql_exceptions\.OperationalError\) \(\d+, \"(.+)\"\)$", str(e))
593+
matches = re.search(
594+
r"^\(_mysql_exceptions\.OperationalError\) \(\d+, \"(.+)\"\)$", str(e)
595+
)
516596
if matches:
517597
return matches.group(1)
518598

@@ -538,7 +618,10 @@ def _parse_placeholder(token):
538618
import sqlparse
539619

540620
# Validate token
541-
if not isinstance(token, sqlparse.sql.Token) or token.ttype != sqlparse.tokens.Name.Placeholder:
621+
if (
622+
not isinstance(token, sqlparse.sql.Token)
623+
or token.ttype != sqlparse.tokens.Name.Placeholder
624+
):
542625
raise TypeError()
543626

544627
# qmark

0 commit comments

Comments
 (0)
Please sign in to comment.