From 71dea16d78c8fa117162b937516079b4d44fe951 Mon Sep 17 00:00:00 2001 From: "David J. Malan" <malan@harvard.edu> Date: Sun, 21 May 2017 00:19:17 -0400 Subject: [PATCH] fixed support for PostgreSQL --- cs50/sql.py | 41 ++++++++++----- test/sqltests.py | 129 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 157 insertions(+), 13 deletions(-) create mode 100644 test/sqltests.py diff --git a/cs50/sql.py b/cs50/sql.py index fa9a864..036e17c 100644 --- a/cs50/sql.py +++ b/cs50/sql.py @@ -1,22 +1,25 @@ import datetime +import re import sqlalchemy import sys +import warnings class SQL(object): """Wrap SQLAlchemy to provide a simple SQL API.""" - def __init__(self, url): + def __init__(self, url, **kwargs): """ Create instance of sqlalchemy.engine.Engine. URL should be a string that indicates database dialect and connection arguments. http://docs.sqlalchemy.org/en/latest/core/engines.html#sqlalchemy.create_engine + http://docs.sqlalchemy.org/en/latest/dialects/index.html """ try: - self.engine = sqlalchemy.create_engine(url) + self.engine = sqlalchemy.create_engine(url, **kwargs) except Exception as e: - e.__context__ = None + e.__cause__ = None raise RuntimeError(e) def execute(self, text, **params): @@ -79,6 +82,10 @@ def process(value): else: return process(value) + # raise exceptions for warnings + warnings.filterwarnings("error") + + # prepare, execute statement try: # construct a new TextClause clause @@ -97,29 +104,37 @@ def process(value): # stringify bound parameters # http://docs.sqlalchemy.org/en/latest/faq/sqlexpressions.html#how-do-i-render-sql-expressions-as-strings-possibly-with-bound-parameters-inlined - self.statement = str(statement.compile(compile_kwargs={"literal_binds": True})) + statement = str(statement.compile(compile_kwargs={"literal_binds": True})) # execute statement - result = self.engine.execute(self.statement) + result = self.engine.execute(statement) # if SELECT (or INSERT with RETURNING), return result set as list of dict objects - if result.returns_rows: + if re.search(r"^\s*SELECT\s+", statement, re.I): rows = result.fetchall() return [dict(row) for row in rows] # if INSERT, return primary key value for a newly inserted row - elif result.lastrowid is not None: - return result.lastrowid + elif re.search(r"^\s*INSERT\s+", statement, re.I): + if self.engine.url.get_backend_name() == "postgresql": + result = self.engine.execute(sqlalchemy.text("SELECT LASTVAL()")) + return result.first()[0] + else: + return result.lastrowid - # if DELETE or UPDATE (or INSERT without RETURNING), return number of rows matched - else: + # if DELETE or UPDATE, return number of rows matched + elif re.search(r"^\s*(?:DELETE|UPDATE)\s+", statement, re.I): return result.rowcount + # if some other statement, return True unless exception + return True + # if constraint violated, return None except sqlalchemy.exc.IntegrityError: return None - # else raise error + # else raise exception except Exception as e: - e.__context__ = None - raise RuntimeError(e) + _e = RuntimeError(e) # else Python 3 prints warnings' tracebacks + _e.__cause__ = None + raise _e diff --git a/test/sqltests.py b/test/sqltests.py new file mode 100644 index 0000000..d2204a1 --- /dev/null +++ b/test/sqltests.py @@ -0,0 +1,129 @@ +import unittest +from cs50.sql import SQL + +class SQLTests(unittest.TestCase): + def test_delete_returns_affected_rows(self): + rows = [ + {"id": 1, "val": "foo"}, + {"id": 2, "val": "bar"}, + {"id": 3, "val": "baz"} + ] + for row in rows: + self.db.execute("INSERT INTO cs50(val) VALUES(:val);", val=row["val"]) + + print(self.db.execute("DELETE FROM cs50 WHERE id = :id", id=rows[0]["id"])) + print(self.db.execute("SELECT * FROM cs50")) + return + + self.assertEqual(self.db.execute("DELETE FROM cs50 WHERE id = :id", id=rows[0]["id"]), 1) + self.assertEqual(self.db.execute("DELETE FROM cs50 WHERE id = :a or id = :b", a=rows[1]["id"], b=rows[2]["id"]), 2) + self.assertEqual(self.db.execute("DELETE FROM cs50 WHERE id = -50"), 0) + + def test_insert_returns_last_row_id(self): + self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('foo')"), 1) + self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('bar')"), 2) + + def test_select_all(self): + self.assertEqual(self.db.execute("SELECT * FROM cs50"), []) + + rows = [ + {"id": 1, "val": "foo"}, + {"id": 2, "val": "bar"}, + {"id": 3, "val": "baz"} + ] + for row in rows: + self.db.execute("INSERT INTO cs50(val) VALUES(:val)", val=row["val"]) + + self.assertEqual(self.db.execute("SELECT * FROM cs50"), rows) + + def test_select_cols(self): + rows = [ + {"val": "foo"}, + {"val": "bar"}, + {"val": "baz"} + ] + for row in rows: + self.db.execute("INSERT INTO cs50(val) VALUES(:val)", val=row["val"]) + + self.assertEqual(self.db.execute("SELECT val FROM cs50"), rows) + + def test_select_where(self): + rows = [ + {"id": 1, "val": "foo"}, + {"id": 2, "val": "bar"}, + {"id": 3, "val": "baz"} + ] + for row in rows: + self.db.execute("INSERT INTO cs50(val) VALUES(:val)", val=row["val"]) + + self.assertEqual(self.db.execute("SELECT * FROM cs50 WHERE id = :id OR val = :val", id=rows[1]["id"], val=rows[2]["val"]), rows[1:3]) + + def test_update_returns_affected_rows(self): + rows = [ + {"id": 1, "val": "foo"}, + {"id": 2, "val": "bar"}, + {"id": 3, "val": "baz"} + ] + for row in rows: + self.db.execute("INSERT INTO cs50(val) VALUES(:val)", val=row["val"]) + + self.assertEqual(self.db.execute("UPDATE cs50 SET val = 'foo' WHERE id > 1"), 2) + self.assertEqual(self.db.execute("UPDATE cs50 SET val = 'foo' WHERE id = -50"), 0) + +class MySQLTests(SQLTests): + @classmethod + def setUpClass(self): + self.db = SQL("mysql://root@localhost/cs50_sql_tests") + + def setUp(self): + self.db.execute("CREATE TABLE cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), PRIMARY KEY (id))") + + def tearDown(self): + self.db.execute("DROP TABLE cs50") + + @classmethod + def tearDownClass(self): + self.db.execute("DROP TABLE IF EXISTS cs50") + +class PostgresTests(SQLTests): + @classmethod + def setUpClass(self): + self.db = SQL("postgresql://postgres:postgres@localhost/cs50_sql_tests") + + def setUp(self): + self.db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16))") + + def tearDown(self): + self.db.execute("DROP TABLE cs50") + + @classmethod + def tearDownClass(self): + self.db.execute("DROP TABLE IF EXISTS cs50") + + def test_insert_returns_last_row_id(self): + self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('foo')"), 1) + self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('bar')"), 2) + +class SQLiteTests(SQLTests): + @classmethod + def setUpClass(self): + self.db = SQL("sqlite:///cs50_sql_tests.db") + + def setUp(self): + self.db.execute("CREATE TABLE cs50(id INTEGER PRIMARY KEY, val TEXT)") + + def tearDown(self): + self.db.execute("DROP TABLE cs50") + + @classmethod + def tearDownClass(self): + self.db.execute("DROP TABLE IF EXISTS cs50") + +if __name__ == "__main__": + suite = unittest.TestSuite([ + unittest.TestLoader().loadTestsFromTestCase(SQLiteTests), + unittest.TestLoader().loadTestsFromTestCase(MySQLTests), + unittest.TestLoader().loadTestsFromTestCase(PostgresTests) + ]) + + unittest.TextTestRunner(verbosity=2).run(suite)