From 71dea16d78c8fa117162b937516079b4d44fe951 Mon Sep 17 00:00:00 2001
From: "David J. Malan" <malan@harvard.edu>
Date: Sun, 21 May 2017 00:19:17 -0400
Subject: [PATCH] fixed support for PostgreSQL

---
 cs50/sql.py      |  41 ++++++++++-----
 test/sqltests.py | 129 +++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 157 insertions(+), 13 deletions(-)
 create mode 100644 test/sqltests.py

diff --git a/cs50/sql.py b/cs50/sql.py
index fa9a864..036e17c 100644
--- a/cs50/sql.py
+++ b/cs50/sql.py
@@ -1,22 +1,25 @@
 import datetime
+import re
 import sqlalchemy
 import sys
+import warnings
 
 class SQL(object):
     """Wrap SQLAlchemy to provide a simple SQL API."""
 
-    def __init__(self, url):
+    def __init__(self, url, **kwargs):
         """
         Create instance of sqlalchemy.engine.Engine.
 
         URL should be a string that indicates database dialect and connection arguments.
 
         http://docs.sqlalchemy.org/en/latest/core/engines.html#sqlalchemy.create_engine
+        http://docs.sqlalchemy.org/en/latest/dialects/index.html
         """
         try:
-            self.engine = sqlalchemy.create_engine(url)
+            self.engine = sqlalchemy.create_engine(url, **kwargs)
         except Exception as e:
-            e.__context__ = None
+            e.__cause__ = None
             raise RuntimeError(e)
 
     def execute(self, text, **params):
@@ -79,6 +82,10 @@ def process(value):
                 else:
                     return process(value)
 
+        # raise exceptions for warnings
+        warnings.filterwarnings("error")
+
+        # prepare, execute statement
         try:
 
             # construct a new TextClause clause
@@ -97,29 +104,37 @@ def process(value):
 
             # stringify bound parameters
             # http://docs.sqlalchemy.org/en/latest/faq/sqlexpressions.html#how-do-i-render-sql-expressions-as-strings-possibly-with-bound-parameters-inlined
-            self.statement = str(statement.compile(compile_kwargs={"literal_binds": True}))
+            statement = str(statement.compile(compile_kwargs={"literal_binds": True}))
 
             # execute statement
-            result = self.engine.execute(self.statement)
+            result = self.engine.execute(statement)
 
             # if SELECT (or INSERT with RETURNING), return result set as list of dict objects
-            if result.returns_rows:
+            if re.search(r"^\s*SELECT\s+", statement, re.I):
                 rows = result.fetchall()
                 return [dict(row) for row in rows]
 
             # if INSERT, return primary key value for a newly inserted row
-            elif result.lastrowid is not None:
-                return result.lastrowid
+            elif re.search(r"^\s*INSERT\s+", statement, re.I):
+                if self.engine.url.get_backend_name() == "postgresql":
+                    result = self.engine.execute(sqlalchemy.text("SELECT LASTVAL()"))
+                    return result.first()[0]
+                else:
+                    return result.lastrowid
 
-            # if DELETE or UPDATE (or INSERT without RETURNING), return number of rows matched
-            else:
+            # if DELETE or UPDATE, return number of rows matched
+            elif re.search(r"^\s*(?:DELETE|UPDATE)\s+", statement, re.I):
                 return result.rowcount
 
+            # if some other statement, return True unless exception
+            return True
+
         # if constraint violated, return None
         except sqlalchemy.exc.IntegrityError:
             return None
 
-        # else raise error
+        # else raise exception
         except Exception as e:
-            e.__context__ = None
-            raise RuntimeError(e)
+            _e = RuntimeError(e) # else Python 3 prints warnings' tracebacks
+            _e.__cause__ = None
+            raise _e
diff --git a/test/sqltests.py b/test/sqltests.py
new file mode 100644
index 0000000..d2204a1
--- /dev/null
+++ b/test/sqltests.py
@@ -0,0 +1,129 @@
+import unittest
+from cs50.sql import SQL
+
+class SQLTests(unittest.TestCase):
+    def test_delete_returns_affected_rows(self):
+        rows = [
+            {"id": 1, "val": "foo"},
+            {"id": 2, "val": "bar"},
+            {"id": 3, "val": "baz"}
+        ]
+        for row in rows:
+            self.db.execute("INSERT INTO cs50(val) VALUES(:val);", val=row["val"])
+
+        print(self.db.execute("DELETE FROM cs50 WHERE id = :id", id=rows[0]["id"]))
+        print(self.db.execute("SELECT * FROM cs50"))
+        return
+
+        self.assertEqual(self.db.execute("DELETE FROM cs50 WHERE id = :id", id=rows[0]["id"]), 1)
+        self.assertEqual(self.db.execute("DELETE FROM cs50 WHERE id = :a or id = :b", a=rows[1]["id"], b=rows[2]["id"]), 2)
+        self.assertEqual(self.db.execute("DELETE FROM cs50 WHERE id = -50"), 0)
+
+    def test_insert_returns_last_row_id(self):
+        self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('foo')"), 1)
+        self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('bar')"), 2)
+
+    def test_select_all(self):
+        self.assertEqual(self.db.execute("SELECT * FROM cs50"), [])
+
+        rows = [
+            {"id": 1, "val": "foo"},
+            {"id": 2, "val": "bar"},
+            {"id": 3, "val": "baz"}
+        ]
+        for row in rows:
+            self.db.execute("INSERT INTO cs50(val) VALUES(:val)", val=row["val"])
+
+        self.assertEqual(self.db.execute("SELECT * FROM cs50"), rows)
+
+    def test_select_cols(self):
+        rows = [
+            {"val": "foo"},
+            {"val": "bar"},
+            {"val": "baz"}
+        ]
+        for row in rows:
+            self.db.execute("INSERT INTO cs50(val) VALUES(:val)", val=row["val"])
+
+        self.assertEqual(self.db.execute("SELECT val FROM cs50"), rows)
+
+    def test_select_where(self):
+        rows = [
+            {"id": 1, "val": "foo"},
+            {"id": 2, "val": "bar"},
+            {"id": 3, "val": "baz"}
+        ]
+        for row in rows:
+            self.db.execute("INSERT INTO cs50(val) VALUES(:val)", val=row["val"])
+
+        self.assertEqual(self.db.execute("SELECT * FROM cs50 WHERE id = :id OR val = :val", id=rows[1]["id"], val=rows[2]["val"]), rows[1:3])
+
+    def test_update_returns_affected_rows(self):
+        rows = [
+            {"id": 1, "val": "foo"},
+            {"id": 2, "val": "bar"},
+            {"id": 3, "val": "baz"}
+        ]
+        for row in rows:
+            self.db.execute("INSERT INTO cs50(val) VALUES(:val)", val=row["val"])
+
+        self.assertEqual(self.db.execute("UPDATE cs50 SET val = 'foo' WHERE id > 1"), 2)
+        self.assertEqual(self.db.execute("UPDATE cs50 SET val = 'foo' WHERE id = -50"), 0)
+
+class MySQLTests(SQLTests):
+    @classmethod
+    def setUpClass(self):
+        self.db = SQL("mysql://root@localhost/cs50_sql_tests")
+
+    def setUp(self):
+        self.db.execute("CREATE TABLE cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), PRIMARY KEY (id))")
+
+    def tearDown(self):
+        self.db.execute("DROP TABLE cs50")
+
+    @classmethod
+    def tearDownClass(self):
+        self.db.execute("DROP TABLE IF EXISTS cs50")
+
+class PostgresTests(SQLTests):
+    @classmethod
+    def setUpClass(self):
+        self.db = SQL("postgresql://postgres:postgres@localhost/cs50_sql_tests")
+
+    def setUp(self):
+        self.db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16))")
+
+    def tearDown(self):
+        self.db.execute("DROP TABLE cs50")
+
+    @classmethod
+    def tearDownClass(self):
+        self.db.execute("DROP TABLE IF EXISTS cs50")
+
+    def test_insert_returns_last_row_id(self):
+        self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('foo')"), 1)
+        self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('bar')"), 2)
+
+class SQLiteTests(SQLTests):
+    @classmethod
+    def setUpClass(self):
+        self.db = SQL("sqlite:///cs50_sql_tests.db")
+
+    def setUp(self):
+        self.db.execute("CREATE TABLE cs50(id INTEGER PRIMARY KEY, val TEXT)")
+
+    def tearDown(self):
+        self.db.execute("DROP TABLE cs50")
+
+    @classmethod
+    def tearDownClass(self):
+        self.db.execute("DROP TABLE IF EXISTS cs50")
+
+if __name__ == "__main__":
+    suite = unittest.TestSuite([
+        unittest.TestLoader().loadTestsFromTestCase(SQLiteTests),
+        unittest.TestLoader().loadTestsFromTestCase(MySQLTests),
+        unittest.TestLoader().loadTestsFromTestCase(PostgresTests)
+    ])
+
+    unittest.TextTestRunner(verbosity=2).run(suite)