Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 98840ee

Browse files
committedDec 8, 2021
porting thread-local storage to v6
1 parent 959abb5 commit 98840ee

26 files changed

+906
-1082
lines changed
 

‎.gitignore

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11
.*
2-
!.github
2+
!/.github/
33
!.gitignore
4-
!.travis.yml
54
*.db
65
*.egg-info/
76
*.pyc
8-
build/
9-
dist/
10-
test.db

‎setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@
1010
"Topic :: Software Development :: Libraries :: Python Modules"
1111
],
1212
description="CS50 library for Python",
13-
install_requires=["Flask>=1.0", "SQLAlchemy<2", "sqlparse", "termcolor"],
13+
install_requires=["Flask>=1.0", "SQLAlchemy", "sqlparse", "termcolor"],
1414
keywords="cs50",
1515
name="cs50",
1616
package_dir={"": "src"},
1717
packages=["cs50"],
1818
url="https://github.com/cs50/python-cs50",
19-
version="7.0.2"
19+
version="7.1.0"
2020
)

‎src/cs50/__init__.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,20 @@
1-
from .cs50 import get_float, get_int, get_string
2-
from .sql import SQL
3-
from ._logger import _setup_logger
1+
import logging
2+
import os
3+
import sys
4+
5+
6+
# Disable cs50 logger by default
7+
logging.getLogger("cs50").disabled = True
48

5-
_setup_logger()
9+
# Import cs50_*
10+
from .cs50 import get_char, get_float, get_int, get_string
11+
try:
12+
from .cs50 import get_long
13+
except ImportError:
14+
pass
15+
16+
# Hook into flask importing
17+
from . import flask
18+
19+
# Wrap SQLAlchemy
20+
from .sql import SQL

‎src/cs50/_engine.py

Lines changed: 0 additions & 88 deletions
This file was deleted.

‎src/cs50/_engine_util.py

Lines changed: 0 additions & 43 deletions
This file was deleted.

‎src/cs50/_logger.py

Lines changed: 0 additions & 98 deletions
This file was deleted.

‎src/cs50/_sql_sanitizer.py

Lines changed: 0 additions & 93 deletions
This file was deleted.

‎src/cs50/_sql_util.py

Lines changed: 0 additions & 51 deletions
This file was deleted.

‎src/cs50/_statement.py

Lines changed: 0 additions & 247 deletions
This file was deleted.

‎src/cs50/_statement_util.py

Lines changed: 0 additions & 101 deletions
This file was deleted.

‎src/cs50/cs50.py

Lines changed: 105 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,104 +1,143 @@
1-
"""Exposes simple API for getting and validating user input"""
1+
from __future__ import print_function
22

3+
import inspect
4+
import logging
5+
import os
36
import re
47
import sys
58

9+
from distutils.sysconfig import get_python_lib
10+
from os.path import abspath, join
11+
from termcolor import colored
12+
from traceback import format_exception
613

7-
def get_float(prompt):
8-
"""Reads a line of text from standard input and returns the equivalent float as precisely as
9-
possible; if text does not represent a float, user is prompted to retry. If line can't be read,
10-
returns None.
1114

12-
:type prompt: str
15+
# Configure default logging handler and formatter
16+
# Prevent flask, werkzeug, etc from adding default handler
17+
logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.DEBUG)
1318

14-
"""
19+
try:
20+
# Patch formatException
21+
logging.root.handlers[0].formatter.formatException = lambda exc_info: _formatException(*exc_info)
22+
except IndexError:
23+
pass
1524

16-
while True:
17-
try:
18-
return _get_float(prompt)
19-
except (OverflowError, ValueError):
20-
pass
25+
# Configure cs50 logger
26+
_logger = logging.getLogger("cs50")
27+
_logger.setLevel(logging.DEBUG)
2128

29+
# Log messages once
30+
_logger.propagate = False
2231

23-
def _get_float(prompt):
24-
user_input = get_string(prompt)
25-
if user_input is None:
26-
return None
27-
28-
if len(user_input) > 0 and re.search(r"^[+-]?\d*(?:\.\d*)?$", user_input):
29-
return float(user_input)
32+
handler = logging.StreamHandler()
33+
handler.setLevel(logging.DEBUG)
3034

31-
raise ValueError(f"invalid float literal: {user_input}")
35+
formatter = logging.Formatter("%(levelname)s: %(message)s")
36+
formatter.formatException = lambda exc_info: _formatException(*exc_info)
37+
handler.setFormatter(formatter)
38+
_logger.addHandler(handler)
3239

3340

34-
def get_int(prompt):
35-
"""Reads a line of text from standard input and return the equivalent int; if text does not
36-
represent an int, user is prompted to retry. If line can't be read, returns None.
41+
class _flushfile():
42+
"""
43+
Disable buffering for standard output and standard error.
3744
38-
:type prompt: str
45+
http://stackoverflow.com/a/231216
3946
"""
4047

41-
while True:
42-
try:
43-
return _get_int(prompt)
44-
except (MemoryError, ValueError):
45-
pass
48+
def __init__(self, f):
49+
self.f = f
4650

51+
def __getattr__(self, name):
52+
return object.__getattribute__(self.f, name)
4753

48-
def _get_int(prompt):
49-
user_input = get_string(prompt)
50-
if user_input is None:
51-
return None
54+
def write(self, x):
55+
self.f.write(x)
56+
self.f.flush()
5257

53-
if re.search(r"^[+-]?\d+$", user_input):
54-
return int(user_input, 10)
5558

56-
raise ValueError(f"invalid int literal for base 10: {user_input}")
59+
sys.stderr = _flushfile(sys.stderr)
60+
sys.stdout = _flushfile(sys.stdout)
5761

5862

59-
def get_string(prompt):
60-
"""Reads a line of text from standard input and returns it as a string, sans trailing line
61-
ending. Supports CR (\r), LF (\n), and CRLF (\r\n) as line endings. If user inputs only a line
62-
ending, returns "", not None. Returns None upon error or no input whatsoever (i.e., just EOF).
63+
def _formatException(type, value, tb):
64+
"""
65+
Format traceback, darkening entries from global site-packages directories
66+
and user-specific site-packages directory.
6367
64-
:type prompt: str
68+
https://stackoverflow.com/a/46071447/5156190
6569
"""
6670

67-
if not isinstance(prompt, str):
68-
raise TypeError("prompt must be of type str")
71+
# Absolute paths to site-packages
72+
packages = tuple(join(abspath(p), "") for p in sys.path[1:])
6973

70-
try:
71-
return _get_input(prompt)
72-
except EOFError:
73-
return None
74+
# Highlight lines not referring to files in site-packages
75+
lines = []
76+
for line in format_exception(type, value, tb):
77+
matches = re.search(r"^ File \"([^\"]+)\", line \d+, in .+", line)
78+
if matches and matches.group(1).startswith(packages):
79+
lines += line
80+
else:
81+
matches = re.search(r"^(\s*)(.*?)(\s*)$", line, re.DOTALL)
82+
lines.append(matches.group(1) + colored(matches.group(2), "yellow") + matches.group(3))
83+
return "".join(lines).rstrip()
7484

7585

76-
def _get_input(prompt):
77-
return input(prompt)
86+
sys.excepthook = lambda type, value, tb: print(_formatException(type, value, tb), file=sys.stderr)
7887

7988

80-
class _flushfile():
81-
""" Disable buffering for standard output and standard error.
82-
http://stackoverflow.com/a/231216
83-
"""
89+
def eprint(*args, **kwargs):
90+
raise RuntimeError("The CS50 Library for Python no longer supports eprint, but you can use print instead!")
8491

85-
def __init__(self, stream):
86-
self.stream = stream
8792

88-
def __getattr__(self, name):
89-
return object.__getattribute__(self.stream, name)
93+
def get_char(prompt):
94+
raise RuntimeError("The CS50 Library for Python no longer supports get_char, but you can use get_string instead!")
9095

91-
def write(self, data):
92-
"""Writes data to stream"""
93-
self.stream.write(data)
94-
self.stream.flush()
9596

97+
def get_float(prompt):
98+
"""
99+
Read a line of text from standard input and return the equivalent float
100+
as precisely as possible; if text does not represent a double, user is
101+
prompted to retry. If line can't be read, return None.
102+
"""
103+
while True:
104+
s = get_string(prompt)
105+
if s is None:
106+
return None
107+
if len(s) > 0 and re.search(r"^[+-]?\d*(?:\.\d*)?$", s):
108+
try:
109+
return float(s)
110+
except (OverflowError, ValueError):
111+
pass
96112

97-
def disable_output_buffering():
98-
"""Disables output buffering to prevent prompts from being buffered.
113+
114+
def get_int(prompt):
99115
"""
100-
sys.stderr = _flushfile(sys.stderr)
101-
sys.stdout = _flushfile(sys.stdout)
116+
Read a line of text from standard input and return the equivalent int;
117+
if text does not represent an int, user is prompted to retry. If line
118+
can't be read, return None.
119+
"""
120+
while True:
121+
s = get_string(prompt)
122+
if s is None:
123+
return None
124+
if re.search(r"^[+-]?\d+$", s):
125+
try:
126+
return int(s, 10)
127+
except ValueError:
128+
pass
102129

103130

104-
disable_output_buffering()
131+
def get_string(prompt):
132+
"""
133+
Read a line of text from standard input and return it as a string,
134+
sans trailing line ending. Supports CR (\r), LF (\n), and CRLF (\r\n)
135+
as line endings. If user inputs only a line ending, returns "", not None.
136+
Returns None upon error or no input whatsoever (i.e., just EOF).
137+
"""
138+
if type(prompt) is not str:
139+
raise TypeError("prompt must be of type str")
140+
try:
141+
return input(prompt)
142+
except EOFError:
143+
return None

‎src/cs50/flask.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import os
2+
import pkgutil
3+
import sys
4+
5+
def _wrap_flask(f):
6+
if f is None:
7+
return
8+
9+
from distutils.version import StrictVersion
10+
from .cs50 import _formatException
11+
12+
if f.__version__ < StrictVersion("1.0"):
13+
return
14+
15+
if os.getenv("CS50_IDE_TYPE") == "online":
16+
from werkzeug.middleware.proxy_fix import ProxyFix
17+
_flask_init_before = f.Flask.__init__
18+
def _flask_init_after(self, *args, **kwargs):
19+
_flask_init_before(self, *args, **kwargs)
20+
self.wsgi_app = ProxyFix(self.wsgi_app, x_proto=1) # For HTTPS-to-HTTP proxy
21+
f.Flask.__init__ = _flask_init_after
22+
23+
24+
# If Flask was imported before cs50
25+
if "flask" in sys.modules:
26+
_wrap_flask(sys.modules["flask"])
27+
28+
# If Flask wasn't imported
29+
else:
30+
flask_loader = pkgutil.get_loader('flask')
31+
if flask_loader:
32+
_exec_module_before = flask_loader.exec_module
33+
34+
def _exec_module_after(*args, **kwargs):
35+
_exec_module_before(*args, **kwargs)
36+
_wrap_flask(sys.modules["flask"])
37+
38+
flask_loader.exec_module = _exec_module_after

‎src/cs50/sql.py

Lines changed: 503 additions & 83 deletions
Large diffs are not rendered by default.

‎tests/flask/application.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import requests
2+
import sys
3+
from flask import Flask, render_template
4+
5+
sys.path.insert(0, "../../src")
6+
7+
import cs50
8+
import cs50.flask
9+
10+
app = Flask(__name__)
11+
12+
db = cs50.SQL("sqlite:///../sqlite.db")
13+
14+
@app.route("/")
15+
def index():
16+
db.execute("SELECT 1")
17+
"""
18+
def f():
19+
res = requests.get("cs50.harvard.edu")
20+
f()
21+
"""
22+
return render_template("index.html")

‎tests/flask/requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
cs50
2+
Flask

‎tests/flask/templates/error.html

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
<!DOCTYPE html>
2+
3+
<html>
4+
<head>
5+
<title>error</title>
6+
</head>
7+
<body>
8+
error
9+
</body>
10+
</html>

‎tests/flask/templates/index.html

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
<!DOCTYPE html>
2+
3+
<html>
4+
<head>
5+
<title>flask</title>
6+
</head>
7+
<body>
8+
flask
9+
</body>
10+
</html>

‎tests/foo.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import logging
2+
import sys
3+
4+
sys.path.insert(0, "../src")
5+
6+
import cs50
7+
8+
"""
9+
db = cs50.SQL("sqlite:///foo.db")
10+
11+
logging.getLogger("cs50").disabled = False
12+
13+
#db.execute("SELECT ? FROM ? ORDER BY ?", "a", "tbl", "c")
14+
db.execute("CREATE TABLE IF NOT EXISTS bar (firstname STRING)")
15+
16+
db.execute("INSERT INTO bar VALUES (?)", "baz")
17+
db.execute("INSERT INTO bar VALUES (?)", "qux")
18+
db.execute("SELECT * FROM bar WHERE firstname IN (?)", ("baz", "qux"))
19+
db.execute("DELETE FROM bar")
20+
"""
21+
22+
db = cs50.SQL("postgresql://postgres@localhost/test")
23+
24+
"""
25+
print(db.execute("DROP TABLE IF EXISTS cs50"))
26+
print(db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)"))
27+
print(db.execute("INSERT INTO cs50 (val) VALUES('foo')"))
28+
print(db.execute("SELECT * FROM cs50"))
29+
30+
print(db.execute("DROP TABLE IF EXISTS cs50"))
31+
print(db.execute("CREATE TABLE cs50 (val VARCHAR(16), bin BYTEA)"))
32+
print(db.execute("INSERT INTO cs50 (val) VALUES('foo')"))
33+
print(db.execute("SELECT * FROM cs50"))
34+
"""
35+
36+
print(db.execute("DROP TABLE IF EXISTS cs50"))
37+
print(db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)"))
38+
print(db.execute("INSERT INTO cs50 (val) VALUES('foo')"))
39+
print(db.execute("INSERT INTO cs50 (val) VALUES('bar')"))
40+
print(db.execute("INSERT INTO cs50 (val) VALUES('baz')"))
41+
print(db.execute("SELECT * FROM cs50"))
42+
try:
43+
print(db.execute("INSERT INTO cs50 (id, val) VALUES(1, 'bar')"))
44+
except Exception as e:
45+
print(e)
46+
pass
47+
print(db.execute("INSERT INTO cs50 (val) VALUES('qux')"))
48+
#print(db.execute("DELETE FROM cs50"))

‎tests/mysql.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import sys
2+
3+
sys.path.insert(0, "../src")
4+
5+
from cs50 import SQL
6+
7+
db = SQL("mysql://root@localhost/test")
8+
db.execute("SELECT 1")

‎tests/python.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import sys
2+
3+
sys.path.insert(0, "../src")
4+
5+
import cs50
6+
7+
i = cs50.get_int("Input: ")
8+
print(f"Output: {i}")

‎tests/redirect/application.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import cs50
2+
from flask import Flask, redirect, render_template
3+
4+
app = Flask(__name__)
5+
6+
@app.route("/")
7+
def index():
8+
return redirect("/foo")
9+
10+
@app.route("/foo")
11+
def foo():
12+
return render_template("foo.html")

‎tests/redirect/templates/foo.html

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
foo

‎tests/sql.py

Lines changed: 63 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010

1111
class SQLTests(unittest.TestCase):
12+
1213
def test_multiple_statements(self):
1314
self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO cs50(val) VALUES('baz'); INSERT INTO cs50(val) VALUES('qux')")
1415

@@ -27,7 +28,6 @@ def test_delete_returns_affected_rows(self):
2728
def test_insert_returns_last_row_id(self):
2829
self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('foo')"), 1)
2930
self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('bar')"), 2)
30-
self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('qux')"), 3)
3131

3232
def test_select_all(self):
3333
self.assertEqual(self.db.execute("SELECT * FROM cs50"), [])
@@ -132,13 +132,64 @@ def test_rollback(self):
132132
def test_identifier_case(self):
133133
self.assertIn("count", self.db.execute("SELECT 1 AS count")[0])
134134

135-
def test_none(self):
136-
self.db.execute("CREATE TABLE foo (val INTEGER)")
137-
self.db.execute("SELECT * FROM foo WHERE val = ?", None)
135+
def tearDown(self):
136+
self.db.execute("DROP TABLE cs50")
137+
self.db.execute("DROP TABLE IF EXISTS foo")
138+
self.db.execute("DROP TABLE IF EXISTS bar")
139+
140+
@classmethod
141+
def tearDownClass(self):
142+
try:
143+
self.db.execute("DROP TABLE IF EXISTS cs50")
144+
except Warning as e:
145+
# suppress "unknown table"
146+
if not str(e).startswith("(1051"):
147+
raise e
148+
149+
150+
class MySQLTests(SQLTests):
151+
@classmethod
152+
def setUpClass(self):
153+
self.db = SQL("mysql://root@localhost/test")
154+
155+
def setUp(self):
156+
self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), bin BLOB, PRIMARY KEY (id))")
157+
self.db.execute("DELETE FROM cs50")
158+
159+
160+
class PostgresTests(SQLTests):
161+
@classmethod
162+
def setUpClass(self):
163+
self.db = SQL("postgresql://postgres@localhost/test")
164+
165+
def setUp(self):
166+
self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)")
167+
self.db.execute("DELETE FROM cs50")
168+
169+
def test_cte(self):
170+
self.assertEqual(self.db.execute("WITH foo AS ( SELECT 1 AS bar ) SELECT bar FROM foo"), [{"bar": 1}])
171+
172+
173+
class SQLiteTests(SQLTests):
174+
175+
@classmethod
176+
def setUpClass(self):
177+
open("test.db", "w").close()
178+
self.db = SQL("sqlite:///test.db")
179+
180+
def setUp(self):
181+
self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id INTEGER PRIMARY KEY, val TEXT, bin BLOB)")
182+
self.db.execute("DELETE FROM cs50")
183+
184+
def test_lastrowid(self):
185+
self.db.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY AUTOINCREMENT, firstname TEXT, lastname TEXT)")
186+
self.assertEqual(self.db.execute("INSERT INTO foo (firstname, lastname) VALUES('firstname', 'lastname')"), 1)
187+
self.assertRaises(ValueError, self.db.execute, "INSERT INTO foo (id, firstname, lastname) VALUES(1, 'firstname', 'lastname')")
188+
self.assertEqual(self.db.execute("INSERT OR IGNORE INTO foo (id, firstname, lastname) VALUES(1, 'firstname', 'lastname')"), None)
138189

139190
def test_integrity_constraints(self):
140191
self.db.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY)")
141-
self.db.execute("INSERT INTO foo VALUES(1)")
192+
self.assertEqual(self.db.execute("INSERT INTO foo VALUES(1)"), 1)
142193
self.assertRaises(ValueError, self.db.execute, "INSERT INTO foo VALUES(1)")
143194

144195
def test_foreign_key_support(self):
@@ -147,7 +198,7 @@ def test_foreign_key_support(self):
147198
self.assertRaises(ValueError, self.db.execute, "INSERT INTO bar VALUES(50)")
148199

149200
def test_qmark(self):
150-
self.db.execute("CREATE TABLE foo (firstname VARCHAR(255), lastname VARCHAR(255))")
201+
self.db.execute("CREATE TABLE foo (firstname STRING, lastname STRING)")
151202

152203
self.db.execute("INSERT INTO foo VALUES (?, 'bar')", "baz")
153204
self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "baz", "lastname": "bar"}])
@@ -177,7 +228,7 @@ def test_qmark(self):
177228
self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}])
178229
self.db.execute("DELETE FROM foo")
179230

180-
self.db.execute("CREATE TABLE bar (firstname VARCHAR(255))")
231+
self.db.execute("CREATE TABLE bar (firstname STRING)")
181232

182233
self.db.execute("INSERT INTO bar VALUES (?)", "baz")
183234
self.assertEqual(self.db.execute("SELECT * FROM bar"), [{"firstname": "baz"}])
@@ -201,7 +252,7 @@ def test_qmark(self):
201252
self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (?, ?)", 'bar', baz='baz')
202253

203254
def test_named(self):
204-
self.db.execute("CREATE TABLE foo (firstname VARCHAR(255), lastname VARCHAR(255))")
255+
self.db.execute("CREATE TABLE foo (firstname STRING, lastname STRING)")
205256

206257
self.db.execute("INSERT INTO foo VALUES (:baz, 'bar')", baz="baz")
207258
self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "baz", "lastname": "bar"}])
@@ -223,11 +274,7 @@ def test_named(self):
223274
self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}])
224275
self.db.execute("DELETE FROM foo")
225276

226-
self.db.execute("INSERT INTO foo VALUES (:baz, :baz)", baz="baz")
227-
self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "baz", "lastname": "baz"}])
228-
self.db.execute("DELETE FROM foo")
229-
230-
self.db.execute("CREATE TABLE bar (firstname VARCHAR(255))")
277+
self.db.execute("CREATE TABLE bar (firstname STRING)")
231278
self.db.execute("INSERT INTO bar VALUES (:baz)", baz="baz")
232279
self.assertEqual(self.db.execute("SELECT * FROM bar"), [{"firstname": "baz"}])
233280

@@ -236,8 +283,9 @@ def test_named(self):
236283
self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (:bar, :baz)", bar='bar', baz='baz', qux='qux')
237284
self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (:bar, :baz)", 'baz', bar='bar')
238285

286+
239287
def test_numeric(self):
240-
self.db.execute("CREATE TABLE foo (firstname VARCHAR(255), lastname VARCHAR(255))")
288+
self.db.execute("CREATE TABLE foo (firstname STRING, lastname STRING)")
241289

242290
self.db.execute("INSERT INTO foo VALUES (:1, 'bar')", "baz")
243291
self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "baz", "lastname": "bar"}])
@@ -259,7 +307,7 @@ def test_numeric(self):
259307
self.assertEqual(self.db.execute("SELECT * FROM foo"), [{"firstname": "bar", "lastname": "baz"}])
260308
self.db.execute("DELETE FROM foo")
261309

262-
self.db.execute("CREATE TABLE bar (firstname VARCHAR(255))")
310+
self.db.execute("CREATE TABLE bar (firstname STRING)")
263311
self.db.execute("INSERT INTO bar VALUES (:1)", "baz")
264312
self.assertEqual(self.db.execute("SELECT * FROM bar"), [{"firstname": "baz"}])
265313

@@ -271,51 +319,6 @@ def test_numeric(self):
271319
def test_cte(self):
272320
self.assertEqual(self.db.execute("WITH foo AS ( SELECT 1 AS bar ) SELECT bar FROM foo"), [{"bar": 1}])
273321

274-
def tearDown(self):
275-
self.db.execute("DROP TABLE IF EXISTS cs50")
276-
self.db.execute("DROP TABLE IF EXISTS bar")
277-
self.db.execute("DROP TABLE IF EXISTS foo")
278-
279-
class MySQLTests(SQLTests):
280-
@classmethod
281-
def setUpClass(self):
282-
self.db = SQL("mysql://root@127.0.0.1/test")
283-
284-
def setUp(self):
285-
self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), bin BLOB, PRIMARY KEY (id))")
286-
self.db.execute("DELETE FROM cs50")
287-
288-
class PostgresTests(SQLTests):
289-
@classmethod
290-
def setUpClass(self):
291-
self.db = SQL("postgresql://postgres:postgres@127.0.0.1/test")
292-
293-
def setUp(self):
294-
self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)")
295-
self.db.execute("DELETE FROM cs50")
296-
297-
def test_cte(self):
298-
self.assertEqual(self.db.execute("WITH foo AS ( SELECT 1 AS bar ) SELECT bar FROM foo"), [{"bar": 1}])
299-
300-
def test_postgres_scheme(self):
301-
db = SQL("postgres://postgres:postgres@127.0.0.1/test")
302-
db.execute("SELECT 1")
303-
304-
class SQLiteTests(SQLTests):
305-
@classmethod
306-
def setUpClass(self):
307-
open("test.db", "w").close()
308-
self.db = SQL("sqlite:///test.db")
309-
310-
def setUp(self):
311-
self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id INTEGER PRIMARY KEY, val TEXT, bin BLOB)")
312-
self.db.execute("DELETE FROM cs50")
313-
314-
def test_lastrowid(self):
315-
self.db.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY AUTOINCREMENT, firstname TEXT, lastname TEXT)")
316-
self.assertEqual(self.db.execute("INSERT INTO foo (firstname, lastname) VALUES('firstname', 'lastname')"), 1)
317-
self.assertRaises(ValueError, self.db.execute, "INSERT INTO foo (id, firstname, lastname) VALUES(1, 'firstname', 'lastname')")
318-
self.assertEqual(self.db.execute("INSERT OR IGNORE INTO foo (id, firstname, lastname) VALUES(1, 'firstname', 'lastname')"), None)
319322

320323
if __name__ == "__main__":
321324
suite = unittest.TestSuite([

‎tests/sqlite.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import logging
2+
import sys
3+
4+
sys.path.insert(0, "../src")
5+
6+
from cs50 import SQL
7+
8+
logging.getLogger("cs50").disabled = False
9+
10+
db = SQL("sqlite:///sqlite.db")
11+
db.execute("SELECT 1")
12+
13+
# TODO
14+
#db.execute("SELECT * FROM Employee WHERE FirstName = ?", b'\x00')
15+
16+
db.execute("SELECT * FROM Employee WHERE FirstName = ?", "' OR 1 = 1")
17+
18+
db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", "Andrew")
19+
db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", ["Andrew"])
20+
db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", ("Andrew",))
21+
db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", ["Andrew", "Nancy"])
22+
db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", ("Andrew", "Nancy"))
23+
db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", [])
24+
db.execute("SELECT * FROM Employee WHERE FirstName IN (?)", ())
25+
26+
db.execute("SELECT * FROM Employee WHERE FirstName = ? AND LastName = ?", "Andrew", "Adams")
27+
db.execute("SELECT * FROM Employee WHERE FirstName = ? AND LastName = ?", ["Andrew", "Adams"])
28+
db.execute("SELECT * FROM Employee WHERE FirstName = ? AND LastName = ?", ("Andrew", "Adams"))
29+
30+
db.execute("SELECT * FROM Employee WHERE FirstName = :1 AND LastName = :2", "Andrew", "Adams")
31+
db.execute("SELECT * FROM Employee WHERE FirstName = :1 AND LastName = :2", ["Andrew", "Adams"])
32+
db.execute("SELECT * FROM Employee WHERE FirstName = :1 AND LastName = :2", ("Andrew", "Adams"))
33+
34+
db.execute("SELECT * FROM Employee WHERE FirstName = ':Andrew :Adams'")
35+
36+
db.execute("SELECT * FROM Employee WHERE FirstName = :first AND LastName = :last", first="Andrew", last="Adams")
37+
db.execute("SELECT * FROM Employee WHERE FirstName = :first AND LastName = :last", {"first": "Andrew", "last": "Adams"})
38+
39+
db.execute("SELECT * FROM Employee WHERE FirstName = %s AND LastName = %s", "Andrew", "Adams")
40+
db.execute("SELECT * FROM Employee WHERE FirstName = %s AND LastName = %s", ["Andrew", "Adams"])
41+
db.execute("SELECT * FROM Employee WHERE FirstName = %s AND LastName = %s", ("Andrew", "Adams"))
42+
43+
db.execute("SELECT * FROM Employee WHERE FirstName = %(first)s AND LastName = %(last)s", first="Andrew", last="Adams")
44+
db.execute("SELECT * FROM Employee WHERE FirstName = %(first)s AND LastName = %(last)s", {"first": "Andrew", "last": "Adams"})

‎tests/tb.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import sys
2+
3+
sys.path.insert(0, "../src")
4+
5+
import cs50
6+
import requests
7+
8+
def f():
9+
res = requests.get("cs50.harvard.edu")
10+
f()

‎tests/test_cs50.py

Lines changed: 0 additions & 141 deletions
This file was deleted.

0 commit comments

Comments
 (0)
Please sign in to comment.