Skip to content

Commit f2732cd

Browse files
authoredMay 21, 2017
Merge pull request #21 from cs50/logging
added logging
2 parents d626967 + b63f521 commit f2732cd

File tree

2 files changed

+163
-13
lines changed

2 files changed

+163
-13
lines changed
 

‎cs50/sql.py

+34-13
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,28 @@
11
import datetime
2+
import logging
3+
import re
24
import sqlalchemy
35
import sys
6+
import warnings
47

58
class SQL(object):
69
"""Wrap SQLAlchemy to provide a simple SQL API."""
710

8-
def __init__(self, url):
11+
def __init__(self, url, **kwargs):
912
"""
1013
Create instance of sqlalchemy.engine.Engine.
1114
1215
URL should be a string that indicates database dialect and connection arguments.
1316
1417
http://docs.sqlalchemy.org/en/latest/core/engines.html#sqlalchemy.create_engine
18+
http://docs.sqlalchemy.org/en/latest/dialects/index.html
1519
"""
20+
logging.basicConfig(level=logging.DEBUG)
21+
self.logger = logging.getLogger(__name__)
1622
try:
17-
self.engine = sqlalchemy.create_engine(url)
23+
self.engine = sqlalchemy.create_engine(url, **kwargs)
1824
except Exception as e:
19-
e.__context__ = None
25+
e.__cause__ = None
2026
raise RuntimeError(e)
2127

2228
def execute(self, text, **params):
@@ -79,6 +85,10 @@ def process(value):
7985
else:
8086
return process(value)
8187

88+
# raise exceptions for warnings
89+
warnings.filterwarnings("error")
90+
91+
# prepare, execute statement
8292
try:
8393

8494
# construct a new TextClause clause
@@ -97,29 +107,40 @@ def process(value):
97107

98108
# stringify bound parameters
99109
# http://docs.sqlalchemy.org/en/latest/faq/sqlexpressions.html#how-do-i-render-sql-expressions-as-strings-possibly-with-bound-parameters-inlined
100-
self.statement = str(statement.compile(compile_kwargs={"literal_binds": True}))
110+
statement = str(statement.compile(compile_kwargs={"literal_binds": True}))
101111

102112
# execute statement
103-
result = self.engine.execute(self.statement)
113+
result = self.engine.execute(statement)
114+
115+
# log statement
116+
self.logger.debug(statement)
104117

105118
# if SELECT (or INSERT with RETURNING), return result set as list of dict objects
106-
if result.returns_rows:
119+
if re.search(r"^\s*SELECT\s+", statement, re.I):
107120
rows = result.fetchall()
108121
return [dict(row) for row in rows]
109122

110123
# if INSERT, return primary key value for a newly inserted row
111-
elif result.lastrowid is not None:
112-
return result.lastrowid
124+
elif re.search(r"^\s*INSERT\s+", statement, re.I):
125+
if self.engine.url.get_backend_name() == "postgresql":
126+
result = self.engine.execute(sqlalchemy.text("SELECT LASTVAL()"))
127+
return result.first()[0]
128+
else:
129+
return result.lastrowid
113130

114-
# if DELETE or UPDATE (or INSERT without RETURNING), return number of rows matched
115-
else:
131+
# if DELETE or UPDATE, return number of rows matched
132+
elif re.search(r"^\s*(?:DELETE|UPDATE)\s+", statement, re.I):
116133
return result.rowcount
117134

135+
# if some other statement, return True unless exception
136+
return True
137+
118138
# if constraint violated, return None
119139
except sqlalchemy.exc.IntegrityError:
120140
return None
121141

122-
# else raise error
142+
# else raise exception
123143
except Exception as e:
124-
e.__context__ = None
125-
raise RuntimeError(e)
144+
_e = RuntimeError(e) # else Python 3 prints warnings' tracebacks
145+
_e.__cause__ = None
146+
raise _e

‎test/sqltests.py

+129
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
import unittest
2+
from cs50.sql import SQL
3+
4+
class SQLTests(unittest.TestCase):
5+
def test_delete_returns_affected_rows(self):
6+
rows = [
7+
{"id": 1, "val": "foo"},
8+
{"id": 2, "val": "bar"},
9+
{"id": 3, "val": "baz"}
10+
]
11+
for row in rows:
12+
self.db.execute("INSERT INTO cs50(val) VALUES(:val);", val=row["val"])
13+
14+
print(self.db.execute("DELETE FROM cs50 WHERE id = :id", id=rows[0]["id"]))
15+
print(self.db.execute("SELECT * FROM cs50"))
16+
return
17+
18+
self.assertEqual(self.db.execute("DELETE FROM cs50 WHERE id = :id", id=rows[0]["id"]), 1)
19+
self.assertEqual(self.db.execute("DELETE FROM cs50 WHERE id = :a or id = :b", a=rows[1]["id"], b=rows[2]["id"]), 2)
20+
self.assertEqual(self.db.execute("DELETE FROM cs50 WHERE id = -50"), 0)
21+
22+
def test_insert_returns_last_row_id(self):
23+
self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('foo')"), 1)
24+
self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('bar')"), 2)
25+
26+
def test_select_all(self):
27+
self.assertEqual(self.db.execute("SELECT * FROM cs50"), [])
28+
29+
rows = [
30+
{"id": 1, "val": "foo"},
31+
{"id": 2, "val": "bar"},
32+
{"id": 3, "val": "baz"}
33+
]
34+
for row in rows:
35+
self.db.execute("INSERT INTO cs50(val) VALUES(:val)", val=row["val"])
36+
37+
self.assertEqual(self.db.execute("SELECT * FROM cs50"), rows)
38+
39+
def test_select_cols(self):
40+
rows = [
41+
{"val": "foo"},
42+
{"val": "bar"},
43+
{"val": "baz"}
44+
]
45+
for row in rows:
46+
self.db.execute("INSERT INTO cs50(val) VALUES(:val)", val=row["val"])
47+
48+
self.assertEqual(self.db.execute("SELECT val FROM cs50"), rows)
49+
50+
def test_select_where(self):
51+
rows = [
52+
{"id": 1, "val": "foo"},
53+
{"id": 2, "val": "bar"},
54+
{"id": 3, "val": "baz"}
55+
]
56+
for row in rows:
57+
self.db.execute("INSERT INTO cs50(val) VALUES(:val)", val=row["val"])
58+
59+
self.assertEqual(self.db.execute("SELECT * FROM cs50 WHERE id = :id OR val = :val", id=rows[1]["id"], val=rows[2]["val"]), rows[1:3])
60+
61+
def test_update_returns_affected_rows(self):
62+
rows = [
63+
{"id": 1, "val": "foo"},
64+
{"id": 2, "val": "bar"},
65+
{"id": 3, "val": "baz"}
66+
]
67+
for row in rows:
68+
self.db.execute("INSERT INTO cs50(val) VALUES(:val)", val=row["val"])
69+
70+
self.assertEqual(self.db.execute("UPDATE cs50 SET val = 'foo' WHERE id > 1"), 2)
71+
self.assertEqual(self.db.execute("UPDATE cs50 SET val = 'foo' WHERE id = -50"), 0)
72+
73+
class MySQLTests(SQLTests):
74+
@classmethod
75+
def setUpClass(self):
76+
self.db = SQL("mysql://root@localhost/cs50_sql_tests")
77+
78+
def setUp(self):
79+
self.db.execute("CREATE TABLE cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), PRIMARY KEY (id))")
80+
81+
def tearDown(self):
82+
self.db.execute("DROP TABLE cs50")
83+
84+
@classmethod
85+
def tearDownClass(self):
86+
self.db.execute("DROP TABLE IF EXISTS cs50")
87+
88+
class PostgresTests(SQLTests):
89+
@classmethod
90+
def setUpClass(self):
91+
self.db = SQL("postgresql://postgres:postgres@localhost/cs50_sql_tests")
92+
93+
def setUp(self):
94+
self.db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16))")
95+
96+
def tearDown(self):
97+
self.db.execute("DROP TABLE cs50")
98+
99+
@classmethod
100+
def tearDownClass(self):
101+
self.db.execute("DROP TABLE IF EXISTS cs50")
102+
103+
def test_insert_returns_last_row_id(self):
104+
self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('foo')"), 1)
105+
self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('bar')"), 2)
106+
107+
class SQLiteTests(SQLTests):
108+
@classmethod
109+
def setUpClass(self):
110+
self.db = SQL("sqlite:///cs50_sql_tests.db")
111+
112+
def setUp(self):
113+
self.db.execute("CREATE TABLE cs50(id INTEGER PRIMARY KEY, val TEXT)")
114+
115+
def tearDown(self):
116+
self.db.execute("DROP TABLE cs50")
117+
118+
@classmethod
119+
def tearDownClass(self):
120+
self.db.execute("DROP TABLE IF EXISTS cs50")
121+
122+
if __name__ == "__main__":
123+
suite = unittest.TestSuite([
124+
unittest.TestLoader().loadTestsFromTestCase(SQLiteTests),
125+
unittest.TestLoader().loadTestsFromTestCase(MySQLTests),
126+
unittest.TestLoader().loadTestsFromTestCase(PostgresTests)
127+
])
128+
129+
unittest.TextTestRunner(verbosity=2).run(suite)

0 commit comments

Comments
 (0)
Please sign in to comment.