Skip to content

Commit 1b671e2

Browse files
author
Kareem Zidane
committedApr 9, 2021
refactor, fix scoped session
1 parent 599d968 commit 1b671e2

File tree

8 files changed

+684
-638
lines changed

8 files changed

+684
-638
lines changed
 

‎setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,5 @@
1616
package_dir={"": "src"},
1717
packages=["cs50"],
1818
url="https://github.com/cs50/python-cs50",
19-
version="6.0.4"
19+
version="7.0.0"
2020
)

‎src/cs50/__init__.py

+3-17
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,6 @@
1-
import logging
2-
import os
3-
import sys
1+
from ._logger import _setup_logger
2+
_setup_logger()
43

5-
6-
# Disable cs50 logger by default
7-
logging.getLogger("cs50").disabled = True
8-
9-
# Import cs50_*
10-
from .cs50 import get_char, get_float, get_int, get_string
11-
try:
12-
from .cs50 import get_long
13-
except ImportError:
14-
pass
15-
16-
# Hook into flask importing
4+
from .cs50 import get_float, get_int, get_string
175
from . import flask
18-
19-
# Wrap SQLAlchemy
206
from .sql import SQL

‎src/cs50/_logger.py

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import logging
2+
import os.path
3+
import re
4+
import sys
5+
import traceback
6+
7+
import termcolor
8+
9+
10+
def _setup_logger():
11+
_logger = logging.getLogger("cs50")
12+
_logger.disabled = True
13+
_logger.setLevel(logging.DEBUG)
14+
15+
# Log messages once
16+
_logger.propagate = False
17+
18+
handler = logging.StreamHandler()
19+
handler.setLevel(logging.DEBUG)
20+
21+
formatter = logging.Formatter("%(levelname)s: %(message)s")
22+
formatter.formatException = lambda exc_info: _formatException(*exc_info)
23+
handler.setFormatter(formatter)
24+
_logger.addHandler(handler)
25+
26+
27+
def _formatException(type, value, tb):
28+
"""
29+
Format traceback, darkening entries from global site-packages directories
30+
and user-specific site-packages directory.
31+
https://stackoverflow.com/a/46071447/5156190
32+
"""
33+
34+
# Absolute paths to site-packages
35+
packages = tuple(os.path.join(os.path.abspath(p), "") for p in sys.path[1:])
36+
37+
# Highlight lines not referring to files in site-packages
38+
lines = []
39+
for line in traceback.format_exception(type, value, tb):
40+
matches = re.search(r"^ File \"([^\"]+)\", line \d+, in .+", line)
41+
if matches and matches.group(1).startswith(packages):
42+
lines += line
43+
else:
44+
matches = re.search(r"^(\s*)(.*?)(\s*)$", line, re.DOTALL)
45+
lines.append(matches.group(1) + termcolor.colored(matches.group(2), "yellow") + matches.group(3))
46+
return "".join(lines).rstrip()
47+
48+

‎src/cs50/_session.py

+80
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import os
2+
3+
import sqlalchemy
4+
import sqlalchemy.orm
5+
import sqlite3
6+
7+
class Session:
8+
def __init__(self, url, **engine_kwargs):
9+
self._url = url
10+
if _is_sqlite_url(self._url):
11+
_assert_sqlite_file_exists(self._url)
12+
13+
self._engine = _create_engine(self._url, **engine_kwargs)
14+
self._is_postgres = self._engine.url.get_backend_name() in {"postgres", "postgresql"}
15+
_setup_on_connect(self._engine)
16+
self._session = _create_scoped_session(self._engine)
17+
18+
19+
def is_postgres(self):
20+
return self._is_postgres
21+
22+
23+
def execute(self, statement):
24+
return self._session.execute(sqlalchemy.text(str(statement)))
25+
26+
27+
def __getattr__(self, attr):
28+
return getattr(self._session, attr)
29+
30+
31+
def _is_sqlite_url(url):
32+
return url.startswith("sqlite:///")
33+
34+
35+
def _assert_sqlite_file_exists(url):
36+
path = url[len("sqlite:///"):]
37+
if not os.path.exists(path):
38+
raise RuntimeError(f"does not exist: {path}")
39+
if not os.path.isfile(path):
40+
raise RuntimeError(f"not a file: {path}")
41+
42+
43+
def _create_engine(url, **kwargs):
44+
try:
45+
engine = sqlalchemy.create_engine(url, **kwargs)
46+
except sqlalchemy.exc.ArgumentError:
47+
raise RuntimeError(f"invalid URL: {url}") from None
48+
49+
engine.execution_options(autocommit=False)
50+
return engine
51+
52+
53+
def _setup_on_connect(engine):
54+
def connect(dbapi_connection, _):
55+
_disable_auto_begin_commit(dbapi_connection)
56+
if _is_sqlite_connection(dbapi_connection):
57+
_enable_sqlite_foreign_key_constraints(dbapi_connection)
58+
59+
sqlalchemy.event.listen(engine, "connect", connect)
60+
61+
62+
def _create_scoped_session(engine):
63+
session_factory = sqlalchemy.orm.sessionmaker(bind=engine)
64+
return sqlalchemy.orm.scoping.scoped_session(session_factory)
65+
66+
67+
def _disable_auto_begin_commit(dbapi_connection):
68+
# Disable underlying API's own emitting of BEGIN and COMMIT so we can ourselves
69+
# https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl
70+
dbapi_connection.isolation_level = None
71+
72+
73+
def _is_sqlite_connection(dbapi_connection):
74+
return isinstance(dbapi_connection, sqlite3.Connection)
75+
76+
77+
def _enable_sqlite_foreign_key_constraints(dbapi_connection):
78+
cursor = dbapi_connection.cursor()
79+
cursor.execute("PRAGMA foreign_keys=ON")
80+
cursor.close()

‎src/cs50/_statement.py

+269
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,269 @@
1+
import collections
2+
import datetime
3+
import enum
4+
import re
5+
6+
import sqlalchemy
7+
import sqlparse
8+
9+
10+
class Statement:
11+
def __init__(self, dialect, sql, *args, **kwargs):
12+
if len(args) > 0 and len(kwargs) > 0:
13+
raise RuntimeError("cannot pass both positional and named parameters")
14+
15+
self._dialect = dialect
16+
self._sql = sql
17+
self._args = args
18+
self._kwargs = kwargs
19+
20+
self._statement = self._parse()
21+
self._command = self._parse_command()
22+
self._tokens = self._bind_params()
23+
24+
def _parse(self):
25+
formatted_statements = sqlparse.format(self._sql, strip_comments=True).strip()
26+
parsed_statements = sqlparse.parse(formatted_statements)
27+
num_of_statements = len(parsed_statements)
28+
if num_of_statements == 0:
29+
raise RuntimeError("missing statement")
30+
elif num_of_statements > 1:
31+
raise RuntimeError("too many statements at once")
32+
33+
return parsed_statements[0]
34+
35+
36+
def _parse_command(self):
37+
for token in self._statement:
38+
if token.ttype in [sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML]:
39+
token_value = token.value.upper()
40+
if token_value in ["BEGIN", "DELETE", "INSERT", "SELECT", "START", "UPDATE"]:
41+
command = token_value
42+
break
43+
else:
44+
command = None
45+
46+
return command
47+
48+
49+
def _bind_params(self):
50+
tokens = self._tokenize()
51+
paramstyle, placeholders = self._parse_placeholders(tokens)
52+
if paramstyle in [Paramstyle.FORMAT, Paramstyle.QMARK]:
53+
tokens = self._bind_format_or_qmark(placeholders, tokens)
54+
elif paramstyle == Paramstyle.NUMERIC:
55+
tokens = self._bind_numeric(placeholders, tokens)
56+
if paramstyle in [Paramstyle.NAMED, Paramstyle.PYFORMAT]:
57+
tokens = self._bind_named_or_pyformat(placeholders, tokens)
58+
59+
tokens = _escape_verbatim_colons(tokens)
60+
return tokens
61+
62+
63+
def _tokenize(self):
64+
return list(self._statement.flatten())
65+
66+
67+
def _parse_placeholders(self, tokens):
68+
paramstyle = None
69+
placeholders = collections.OrderedDict()
70+
for index, token in enumerate(tokens):
71+
if _is_placeholder(token):
72+
_paramstyle, name = _parse_placeholder(token)
73+
if paramstyle is None:
74+
paramstyle = _paramstyle
75+
elif _paramstyle != paramstyle:
76+
raise RuntimeError("inconsistent paramstyle")
77+
78+
placeholders[index] = name
79+
80+
if paramstyle is None:
81+
paramstyle = self._default_paramstyle()
82+
83+
return paramstyle, placeholders
84+
85+
86+
def _default_paramstyle(self):
87+
paramstyle = None
88+
if self._args:
89+
paramstyle = Paramstyle.QMARK
90+
elif self._kwargs:
91+
paramstyle = Paramstyle.NAMED
92+
93+
return paramstyle
94+
95+
96+
def _bind_format_or_qmark(self, placeholders, tokens):
97+
if len(placeholders) != len(self._args):
98+
_placeholders = ", ".join([str(token) for token in placeholders.values()])
99+
_args = ", ".join([str(self._escape(arg)) for arg in self._args])
100+
if len(placeholders) < len(self._args):
101+
raise RuntimeError(f"fewer placeholders ({_placeholders}) than values ({_args})")
102+
103+
raise RuntimeError(f"more placeholders ({_placeholders}) than values ({_args})")
104+
105+
for arg_index, token_index in enumerate(placeholders.keys()):
106+
tokens[token_index] = self._escape(self._args[arg_index])
107+
108+
return tokens
109+
110+
111+
def _bind_numeric(self, placeholders, tokens):
112+
unused_arg_indices = set(range(len(self._args)))
113+
for token_index, num in placeholders.items():
114+
if num >= len(self._args):
115+
raise RuntimeError(f"missing value for placeholder ({num + 1})")
116+
117+
tokens[token_index] = self._escape(self._args[num])
118+
unused_arg_indices.remove(num)
119+
120+
if len(unused_arg_indices) > 0:
121+
unused_args = ", ".join([str(self._escape(self._args[i])) for i in sorted(unused_arg_indices)])
122+
raise RuntimeError(f"unused value{'' if len(unused_arg_indices) == 1 else 's'} ({unused_args})")
123+
124+
return tokens
125+
126+
127+
def _bind_named_or_pyformat(self, placeholders, tokens):
128+
unused_params = set(self._kwargs.keys())
129+
for token_index, param_name in placeholders.items():
130+
if param_name not in self._kwargs:
131+
raise RuntimeError(f"missing value for placeholder ({param_name})")
132+
133+
tokens[token_index] = self._escape(self._kwargs[param_name])
134+
unused_params.remove(param_name)
135+
136+
if len(unused_params) > 0:
137+
raise RuntimeError("unused value{'' if len(unused_params) == 1 else 's'} ({', '.join(sorted(unused_params))})")
138+
139+
return tokens
140+
141+
142+
def _escape(self, value):
143+
"""
144+
Escapes value using engine's conversion function.
145+
https://docs.sqlalchemy.org/en/latest/core/type_api.html#sqlalchemy.types.TypeEngine.literal_processor
146+
"""
147+
148+
if isinstance(value, (list, tuple)):
149+
return self._escape_iterable(value)
150+
151+
if isinstance(value, bool):
152+
return sqlparse.sql.Token(
153+
sqlparse.tokens.Number,
154+
sqlalchemy.types.Boolean().literal_processor(self._dialect)(value))
155+
156+
if isinstance(value, bytes):
157+
if self._dialect.name in ["mysql", "sqlite"]:
158+
# https://dev.mysql.com/doc/refman/8.0/en/hexadecimal-literals.html
159+
return sqlparse.sql.Token(sqlparse.tokens.Other, f"x'{value.hex()}'")
160+
if self._dialect.name in ["postgres", "postgresql"]:
161+
# https://dba.stackexchange.com/a/203359
162+
return sqlparse.sql.Token(sqlparse.tokens.Other, f"'\\x{value.hex()}'")
163+
164+
raise RuntimeError(f"unsupported value: {value}")
165+
166+
if isinstance(value, datetime.date):
167+
return sqlparse.sql.Token(
168+
sqlparse.tokens.String,
169+
sqlalchemy.types.String().literal_processor(self._dialect)(value.strftime("%Y-%m-%d")))
170+
171+
if isinstance(value, datetime.datetime):
172+
return sqlparse.sql.Token(
173+
sqlparse.tokens.String,
174+
sqlalchemy.types.String().literal_processor(self._dialect)(value.strftime("%Y-%m-%d %H:%M:%S")))
175+
176+
if isinstance(value, datetime.time):
177+
return sqlparse.sql.Token(
178+
sqlparse.tokens.String,
179+
sqlalchemy.types.String().literal_processor(self._dialect)(value.strftime("%H:%M:%S")))
180+
181+
if isinstance(value, float):
182+
return sqlparse.sql.Token(
183+
sqlparse.tokens.Number,
184+
sqlalchemy.types.Float().literal_processor(self._dialect)(value))
185+
186+
if isinstance(value, int):
187+
return sqlparse.sql.Token(
188+
sqlparse.tokens.Number,
189+
sqlalchemy.types.Integer().literal_processor(self._dialect)(value))
190+
191+
if isinstance(value, str):
192+
return sqlparse.sql.Token(
193+
sqlparse.tokens.String,
194+
sqlalchemy.types.String().literal_processor(self._dialect)(value))
195+
196+
if value is None:
197+
return sqlparse.sql.Token(
198+
sqlparse.tokens.Keyword,
199+
sqlalchemy.types.NullType().literal_processor(self._dialect)(value))
200+
201+
raise RuntimeError(f"unsupported value: {value}")
202+
203+
204+
def _escape_iterable(self, iterable):
205+
return sqlparse.sql.TokenList(
206+
sqlparse.parse(", ".join([str(self._escape(v)) for v in iterable])))
207+
208+
209+
def get_command(self):
210+
return self._command
211+
212+
213+
def __str__(self):
214+
return "".join([str(token) for token in self._tokens])
215+
216+
217+
def _is_placeholder(token):
218+
return token.ttype == sqlparse.tokens.Name.Placeholder
219+
220+
221+
def _parse_placeholder(token):
222+
if token.value == "?":
223+
return Paramstyle.QMARK, None
224+
225+
# E.g., :1
226+
matches = re.search(r"^:([1-9]\d*)$", token.value)
227+
if matches:
228+
return Paramstyle.NUMERIC, int(matches.group(1)) - 1
229+
230+
# E.g., :foo
231+
matches = re.search(r"^:([a-zA-Z]\w*)$", token.value)
232+
if matches:
233+
return Paramstyle.NAMED, matches.group(1)
234+
235+
if token.value == "%s":
236+
return Paramstyle.FORMAT, None
237+
238+
# E.g., %(foo)
239+
matches = re.search(r"%\((\w+)\)s$", token.value)
240+
if matches:
241+
return Paramstyle.PYFORMAT, matches.group(1)
242+
243+
raise RuntimeError(f"{token.value}: invalid placeholder")
244+
245+
246+
def _escape_verbatim_colons(tokens):
247+
for token in tokens:
248+
if _is_string_literal(token):
249+
token.value = re.sub("(^'|\s+):", r"\1\:", token.value)
250+
elif _is_identifier(token):
251+
token.value = re.sub("(^\"|\s+):", r"\1\:", token.value)
252+
253+
return tokens
254+
255+
256+
def _is_string_literal(token):
257+
return token.ttype in [sqlparse.tokens.Literal.String, sqlparse.tokens.Literal.String.Single]
258+
259+
260+
def _is_identifier(token):
261+
return token.ttype == sqlparse.tokens.Literal.String.Symbol
262+
263+
264+
class Paramstyle(enum.Enum):
265+
FORMAT = enum.auto()
266+
NAMED = enum.auto()
267+
NUMERIC = enum.auto()
268+
PYFORMAT = enum.auto()
269+
QMARK = enum.auto()

‎src/cs50/cs50.py

+60-110
Original file line numberDiff line numberDiff line change
@@ -1,98 +1,6 @@
1-
from __future__ import print_function
2-
3-
import inspect
4-
import logging
5-
import os
61
import re
72
import sys
83

9-
from distutils.sysconfig import get_python_lib
10-
from os.path import abspath, join
11-
from termcolor import colored
12-
from traceback import format_exception
13-
14-
15-
# Configure default logging handler and formatter
16-
# Prevent flask, werkzeug, etc from adding default handler
17-
logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.DEBUG)
18-
19-
try:
20-
# Patch formatException
21-
logging.root.handlers[0].formatter.formatException = lambda exc_info: _formatException(*exc_info)
22-
except IndexError:
23-
pass
24-
25-
# Configure cs50 logger
26-
_logger = logging.getLogger("cs50")
27-
_logger.setLevel(logging.DEBUG)
28-
29-
# Log messages once
30-
_logger.propagate = False
31-
32-
handler = logging.StreamHandler()
33-
handler.setLevel(logging.DEBUG)
34-
35-
formatter = logging.Formatter("%(levelname)s: %(message)s")
36-
formatter.formatException = lambda exc_info: _formatException(*exc_info)
37-
handler.setFormatter(formatter)
38-
_logger.addHandler(handler)
39-
40-
41-
class _flushfile():
42-
"""
43-
Disable buffering for standard output and standard error.
44-
45-
http://stackoverflow.com/a/231216
46-
"""
47-
48-
def __init__(self, f):
49-
self.f = f
50-
51-
def __getattr__(self, name):
52-
return object.__getattribute__(self.f, name)
53-
54-
def write(self, x):
55-
self.f.write(x)
56-
self.f.flush()
57-
58-
59-
sys.stderr = _flushfile(sys.stderr)
60-
sys.stdout = _flushfile(sys.stdout)
61-
62-
63-
def _formatException(type, value, tb):
64-
"""
65-
Format traceback, darkening entries from global site-packages directories
66-
and user-specific site-packages directory.
67-
68-
https://stackoverflow.com/a/46071447/5156190
69-
"""
70-
71-
# Absolute paths to site-packages
72-
packages = tuple(join(abspath(p), "") for p in sys.path[1:])
73-
74-
# Highlight lines not referring to files in site-packages
75-
lines = []
76-
for line in format_exception(type, value, tb):
77-
matches = re.search(r"^ File \"([^\"]+)\", line \d+, in .+", line)
78-
if matches and matches.group(1).startswith(packages):
79-
lines += line
80-
else:
81-
matches = re.search(r"^(\s*)(.*?)(\s*)$", line, re.DOTALL)
82-
lines.append(matches.group(1) + colored(matches.group(2), "yellow") + matches.group(3))
83-
return "".join(lines).rstrip()
84-
85-
86-
sys.excepthook = lambda type, value, tb: print(_formatException(type, value, tb), file=sys.stderr)
87-
88-
89-
def eprint(*args, **kwargs):
90-
raise RuntimeError("The CS50 Library for Python no longer supports eprint, but you can use print instead!")
91-
92-
93-
def get_char(prompt):
94-
raise RuntimeError("The CS50 Library for Python no longer supports get_char, but you can use get_string instead!")
95-
964

975
def get_float(prompt):
986
"""
@@ -101,14 +9,21 @@ def get_float(prompt):
1019
prompted to retry. If line can't be read, return None.
10210
"""
10311
while True:
104-
s = get_string(prompt)
105-
if s is None:
106-
return None
107-
if len(s) > 0 and re.search(r"^[+-]?\d*(?:\.\d*)?$", s):
108-
try:
109-
return float(s)
110-
except (OverflowError, ValueError):
111-
pass
12+
try:
13+
return _get_float(prompt)
14+
except (OverflowError, ValueError):
15+
pass
16+
17+
18+
def _get_float(prompt):
19+
s = get_string(prompt)
20+
if s is None:
21+
return
22+
23+
if len(s) > 0 and re.search(r"^[+-]?\d*(?:\.\d*)?$", s):
24+
return float(s)
25+
26+
raise ValueError(f"invalid float literal: {s}")
11227

11328

11429
def get_int(prompt):
@@ -118,14 +33,21 @@ def get_int(prompt):
11833
can't be read, return None.
11934
"""
12035
while True:
121-
s = get_string(prompt)
122-
if s is None:
123-
return None
124-
if re.search(r"^[+-]?\d+$", s):
125-
try:
126-
return int(s, 10)
127-
except ValueError:
128-
pass
36+
try:
37+
return _get_int(prompt)
38+
except (MemoryError, ValueError):
39+
pass
40+
41+
42+
def _get_int(prompt):
43+
s = get_string(prompt)
44+
if s is None:
45+
return
46+
47+
if re.search(r"^[+-]?\d+$", s):
48+
return int(s, 10)
49+
50+
raise ValueError(f"invalid int literal for base 10: {s}")
12951

13052

13153
def get_string(prompt):
@@ -137,7 +59,35 @@ def get_string(prompt):
13759
"""
13860
if type(prompt) is not str:
13961
raise TypeError("prompt must be of type str")
62+
14063
try:
141-
return input(prompt)
64+
return _get_input(prompt)
14265
except EOFError:
143-
return None
66+
return
67+
68+
69+
def _get_input(prompt):
70+
return input(prompt)
71+
72+
73+
class _flushfile():
74+
"""
75+
Disable buffering for standard output and standard error.
76+
http://stackoverflow.com/a/231216
77+
"""
78+
79+
def __init__(self, f):
80+
self.f = f
81+
82+
def __getattr__(self, name):
83+
return object.__getattribute__(self.f, name)
84+
85+
def write(self, x):
86+
self.f.write(x)
87+
self.f.flush()
88+
89+
def disable_buffering():
90+
sys.stderr = _flushfile(sys.stderr)
91+
sys.stdout = _flushfile(sys.stdout)
92+
93+
disable_buffering()

‎src/cs50/sql.py

+72-510
Large diffs are not rendered by default.

‎tests/test_cs50.py

+151
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
import math
2+
import sys
3+
import unittest
4+
5+
from unittest.mock import patch
6+
7+
from cs50.cs50 import get_string, _get_int, _get_float
8+
9+
10+
class TestCS50(unittest.TestCase):
11+
@patch("cs50.cs50._get_input", return_value="")
12+
def test_get_string_empty_input(self, mock_get_input):
13+
"""Returns empty string when input is empty"""
14+
self.assertEqual(get_string("Answer: "), "")
15+
mock_get_input.assert_called_with("Answer: ")
16+
17+
18+
@patch("cs50.cs50._get_input", return_value="test")
19+
def test_get_string_nonempty_input(self, mock_get_input):
20+
"""Returns the provided non-empty input"""
21+
self.assertEqual(get_string("Answer: "), "test")
22+
mock_get_input.assert_called_with("Answer: ")
23+
24+
25+
@patch("cs50.cs50._get_input", side_effect=EOFError)
26+
def test_get_string_eof(self, mock_get_input):
27+
"""Returns None on EOF"""
28+
self.assertIs(get_string("Answer: "), None)
29+
mock_get_input.assert_called_with("Answer: ")
30+
31+
32+
def test_get_string_invalid_prompt(self):
33+
"""Raises TypeError when prompt is not str"""
34+
with self.assertRaises(TypeError):
35+
get_string(1)
36+
37+
38+
@patch("cs50.cs50.get_string", return_value=None)
39+
def test_get_int_eof(self, mock_get_string):
40+
"""Returns None on EOF"""
41+
self.assertIs(_get_int("Answer: "), None)
42+
mock_get_string.assert_called_with("Answer: ")
43+
44+
45+
def test_get_int_valid_input(self):
46+
"""Returns the provided integer input"""
47+
48+
def assert_equal(return_value, expected_value):
49+
with patch("cs50.cs50.get_string", return_value=return_value) as mock_get_string:
50+
self.assertEqual(_get_int("Answer: "), expected_value)
51+
mock_get_string.assert_called_with("Answer: ")
52+
53+
values = [
54+
("0", 0),
55+
("50", 50),
56+
("+50", 50),
57+
("+42", 42),
58+
("-42", -42),
59+
("42", 42),
60+
]
61+
62+
for return_value, expected_value in values:
63+
assert_equal(return_value, expected_value)
64+
65+
66+
def test_get_int_invalid_input(self):
67+
"""Raises ValueError when input is invalid base-10 int"""
68+
69+
def assert_raises_valueerror(return_value):
70+
with patch("cs50.cs50.get_string", return_value=return_value) as mock_get_string:
71+
with self.assertRaises(ValueError):
72+
_get_int("Answer: ")
73+
74+
mock_get_string.assert_called_with("Answer: ")
75+
76+
return_values = [
77+
"++50",
78+
"--50",
79+
"50+",
80+
"50-",
81+
" 50",
82+
" +50",
83+
" -50",
84+
"50 ",
85+
"ab50",
86+
"50ab",
87+
"ab50ab",
88+
]
89+
90+
for return_value in return_values:
91+
assert_raises_valueerror(return_value)
92+
93+
94+
@patch("cs50.cs50.get_string", return_value=None)
95+
def test_get_float_eof(self, mock_get_string):
96+
"""Returns None on EOF"""
97+
self.assertIs(_get_float("Answer: "), None)
98+
mock_get_string.assert_called_with("Answer: ")
99+
100+
101+
def test_get_float_valid_input(self):
102+
"""Returns the provided integer input"""
103+
def assert_equal(return_value, expected_value):
104+
with patch("cs50.cs50.get_string", return_value=return_value) as mock_get_string:
105+
f = _get_float("Answer: ")
106+
self.assertTrue(math.isclose(f, expected_value))
107+
mock_get_string.assert_called_with("Answer: ")
108+
109+
values = [
110+
(".0", 0.0),
111+
("0.", 0.0),
112+
(".42", 0.42),
113+
("42.", 42.0),
114+
("50", 50.0),
115+
("+50", 50.0),
116+
("-50", -50.0),
117+
("+3.14", 3.14),
118+
("-3.14", -3.14),
119+
]
120+
121+
for return_value, expected_value in values:
122+
assert_equal(return_value, expected_value)
123+
124+
125+
def test_get_float_invalid_input(self):
126+
"""Raises ValueError when input is invalid float"""
127+
128+
def assert_raises_valueerror(return_value):
129+
with patch("cs50.cs50.get_string", return_value=return_value) as mock_get_string:
130+
with self.assertRaises(ValueError):
131+
_get_float("Answer: ")
132+
133+
mock_get_string.assert_called_with("Answer: ")
134+
135+
return_values = [
136+
".",
137+
"..5",
138+
"a.5",
139+
".5a"
140+
"0.5a",
141+
"a0.42",
142+
" .42",
143+
"3.14 ",
144+
"++3.14",
145+
"3.14+",
146+
"--3.14",
147+
"3.14--",
148+
]
149+
150+
for return_value in return_values:
151+
assert_raises_valueerror(return_value)

0 commit comments

Comments
 (0)
Please sign in to comment.