Skip to content

Commit 71dea16

Browse files
committedMay 21, 2017
fixed support for PostgreSQL
1 parent d626967 commit 71dea16

File tree

2 files changed

+157
-13
lines changed

2 files changed

+157
-13
lines changed
 

‎cs50/sql.py

+28-13
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,25 @@
11
import datetime
2+
import re
23
import sqlalchemy
34
import sys
5+
import warnings
46

57
class SQL(object):
68
"""Wrap SQLAlchemy to provide a simple SQL API."""
79

8-
def __init__(self, url):
10+
def __init__(self, url, **kwargs):
911
"""
1012
Create instance of sqlalchemy.engine.Engine.
1113
1214
URL should be a string that indicates database dialect and connection arguments.
1315
1416
http://docs.sqlalchemy.org/en/latest/core/engines.html#sqlalchemy.create_engine
17+
http://docs.sqlalchemy.org/en/latest/dialects/index.html
1518
"""
1619
try:
17-
self.engine = sqlalchemy.create_engine(url)
20+
self.engine = sqlalchemy.create_engine(url, **kwargs)
1821
except Exception as e:
19-
e.__context__ = None
22+
e.__cause__ = None
2023
raise RuntimeError(e)
2124

2225
def execute(self, text, **params):
@@ -79,6 +82,10 @@ def process(value):
7982
else:
8083
return process(value)
8184

85+
# raise exceptions for warnings
86+
warnings.filterwarnings("error")
87+
88+
# prepare, execute statement
8289
try:
8390

8491
# construct a new TextClause clause
@@ -97,29 +104,37 @@ def process(value):
97104

98105
# stringify bound parameters
99106
# http://docs.sqlalchemy.org/en/latest/faq/sqlexpressions.html#how-do-i-render-sql-expressions-as-strings-possibly-with-bound-parameters-inlined
100-
self.statement = str(statement.compile(compile_kwargs={"literal_binds": True}))
107+
statement = str(statement.compile(compile_kwargs={"literal_binds": True}))
101108

102109
# execute statement
103-
result = self.engine.execute(self.statement)
110+
result = self.engine.execute(statement)
104111

105112
# if SELECT (or INSERT with RETURNING), return result set as list of dict objects
106-
if result.returns_rows:
113+
if re.search(r"^\s*SELECT\s+", statement, re.I):
107114
rows = result.fetchall()
108115
return [dict(row) for row in rows]
109116

110117
# if INSERT, return primary key value for a newly inserted row
111-
elif result.lastrowid is not None:
112-
return result.lastrowid
118+
elif re.search(r"^\s*INSERT\s+", statement, re.I):
119+
if self.engine.url.get_backend_name() == "postgresql":
120+
result = self.engine.execute(sqlalchemy.text("SELECT LASTVAL()"))
121+
return result.first()[0]
122+
else:
123+
return result.lastrowid
113124

114-
# if DELETE or UPDATE (or INSERT without RETURNING), return number of rows matched
115-
else:
125+
# if DELETE or UPDATE, return number of rows matched
126+
elif re.search(r"^\s*(?:DELETE|UPDATE)\s+", statement, re.I):
116127
return result.rowcount
117128

129+
# if some other statement, return True unless exception
130+
return True
131+
118132
# if constraint violated, return None
119133
except sqlalchemy.exc.IntegrityError:
120134
return None
121135

122-
# else raise error
136+
# else raise exception
123137
except Exception as e:
124-
e.__context__ = None
125-
raise RuntimeError(e)
138+
_e = RuntimeError(e) # else Python 3 prints warnings' tracebacks
139+
_e.__cause__ = None
140+
raise _e

‎test/sqltests.py

+129
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
import unittest
2+
from cs50.sql import SQL
3+
4+
class SQLTests(unittest.TestCase):
5+
def test_delete_returns_affected_rows(self):
6+
rows = [
7+
{"id": 1, "val": "foo"},
8+
{"id": 2, "val": "bar"},
9+
{"id": 3, "val": "baz"}
10+
]
11+
for row in rows:
12+
self.db.execute("INSERT INTO cs50(val) VALUES(:val);", val=row["val"])
13+
14+
print(self.db.execute("DELETE FROM cs50 WHERE id = :id", id=rows[0]["id"]))
15+
print(self.db.execute("SELECT * FROM cs50"))
16+
return
17+
18+
self.assertEqual(self.db.execute("DELETE FROM cs50 WHERE id = :id", id=rows[0]["id"]), 1)
19+
self.assertEqual(self.db.execute("DELETE FROM cs50 WHERE id = :a or id = :b", a=rows[1]["id"], b=rows[2]["id"]), 2)
20+
self.assertEqual(self.db.execute("DELETE FROM cs50 WHERE id = -50"), 0)
21+
22+
def test_insert_returns_last_row_id(self):
23+
self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('foo')"), 1)
24+
self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('bar')"), 2)
25+
26+
def test_select_all(self):
27+
self.assertEqual(self.db.execute("SELECT * FROM cs50"), [])
28+
29+
rows = [
30+
{"id": 1, "val": "foo"},
31+
{"id": 2, "val": "bar"},
32+
{"id": 3, "val": "baz"}
33+
]
34+
for row in rows:
35+
self.db.execute("INSERT INTO cs50(val) VALUES(:val)", val=row["val"])
36+
37+
self.assertEqual(self.db.execute("SELECT * FROM cs50"), rows)
38+
39+
def test_select_cols(self):
40+
rows = [
41+
{"val": "foo"},
42+
{"val": "bar"},
43+
{"val": "baz"}
44+
]
45+
for row in rows:
46+
self.db.execute("INSERT INTO cs50(val) VALUES(:val)", val=row["val"])
47+
48+
self.assertEqual(self.db.execute("SELECT val FROM cs50"), rows)
49+
50+
def test_select_where(self):
51+
rows = [
52+
{"id": 1, "val": "foo"},
53+
{"id": 2, "val": "bar"},
54+
{"id": 3, "val": "baz"}
55+
]
56+
for row in rows:
57+
self.db.execute("INSERT INTO cs50(val) VALUES(:val)", val=row["val"])
58+
59+
self.assertEqual(self.db.execute("SELECT * FROM cs50 WHERE id = :id OR val = :val", id=rows[1]["id"], val=rows[2]["val"]), rows[1:3])
60+
61+
def test_update_returns_affected_rows(self):
62+
rows = [
63+
{"id": 1, "val": "foo"},
64+
{"id": 2, "val": "bar"},
65+
{"id": 3, "val": "baz"}
66+
]
67+
for row in rows:
68+
self.db.execute("INSERT INTO cs50(val) VALUES(:val)", val=row["val"])
69+
70+
self.assertEqual(self.db.execute("UPDATE cs50 SET val = 'foo' WHERE id > 1"), 2)
71+
self.assertEqual(self.db.execute("UPDATE cs50 SET val = 'foo' WHERE id = -50"), 0)
72+
73+
class MySQLTests(SQLTests):
74+
@classmethod
75+
def setUpClass(self):
76+
self.db = SQL("mysql://root@localhost/cs50_sql_tests")
77+
78+
def setUp(self):
79+
self.db.execute("CREATE TABLE cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), PRIMARY KEY (id))")
80+
81+
def tearDown(self):
82+
self.db.execute("DROP TABLE cs50")
83+
84+
@classmethod
85+
def tearDownClass(self):
86+
self.db.execute("DROP TABLE IF EXISTS cs50")
87+
88+
class PostgresTests(SQLTests):
89+
@classmethod
90+
def setUpClass(self):
91+
self.db = SQL("postgresql://postgres:postgres@localhost/cs50_sql_tests")
92+
93+
def setUp(self):
94+
self.db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16))")
95+
96+
def tearDown(self):
97+
self.db.execute("DROP TABLE cs50")
98+
99+
@classmethod
100+
def tearDownClass(self):
101+
self.db.execute("DROP TABLE IF EXISTS cs50")
102+
103+
def test_insert_returns_last_row_id(self):
104+
self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('foo')"), 1)
105+
self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('bar')"), 2)
106+
107+
class SQLiteTests(SQLTests):
108+
@classmethod
109+
def setUpClass(self):
110+
self.db = SQL("sqlite:///cs50_sql_tests.db")
111+
112+
def setUp(self):
113+
self.db.execute("CREATE TABLE cs50(id INTEGER PRIMARY KEY, val TEXT)")
114+
115+
def tearDown(self):
116+
self.db.execute("DROP TABLE cs50")
117+
118+
@classmethod
119+
def tearDownClass(self):
120+
self.db.execute("DROP TABLE IF EXISTS cs50")
121+
122+
if __name__ == "__main__":
123+
suite = unittest.TestSuite([
124+
unittest.TestLoader().loadTestsFromTestCase(SQLiteTests),
125+
unittest.TestLoader().loadTestsFromTestCase(MySQLTests),
126+
unittest.TestLoader().loadTestsFromTestCase(PostgresTests)
127+
])
128+
129+
unittest.TextTestRunner(verbosity=2).run(suite)

0 commit comments

Comments
 (0)
Please sign in to comment.