Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixed support for PostgreSQL #20

Merged
merged 1 commit into from
May 21, 2017
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 28 additions & 13 deletions cs50/sql.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
import datetime
import re
import sqlalchemy
import sys
import warnings

class SQL(object):
"""Wrap SQLAlchemy to provide a simple SQL API."""

def __init__(self, url):
def __init__(self, url, **kwargs):
"""
Create instance of sqlalchemy.engine.Engine.
URL should be a string that indicates database dialect and connection arguments.
http://docs.sqlalchemy.org/en/latest/core/engines.html#sqlalchemy.create_engine
http://docs.sqlalchemy.org/en/latest/dialects/index.html
"""
try:
self.engine = sqlalchemy.create_engine(url)
self.engine = sqlalchemy.create_engine(url, **kwargs)
except Exception as e:
e.__context__ = None
e.__cause__ = None
raise RuntimeError(e)

def execute(self, text, **params):
@@ -79,6 +82,10 @@ def process(value):
else:
return process(value)

# raise exceptions for warnings
warnings.filterwarnings("error")

# prepare, execute statement
try:

# construct a new TextClause clause
@@ -97,29 +104,37 @@ def process(value):

# stringify bound parameters
# http://docs.sqlalchemy.org/en/latest/faq/sqlexpressions.html#how-do-i-render-sql-expressions-as-strings-possibly-with-bound-parameters-inlined
self.statement = str(statement.compile(compile_kwargs={"literal_binds": True}))
statement = str(statement.compile(compile_kwargs={"literal_binds": True}))

# execute statement
result = self.engine.execute(self.statement)
result = self.engine.execute(statement)

# if SELECT (or INSERT with RETURNING), return result set as list of dict objects
if result.returns_rows:
if re.search(r"^\s*SELECT\s+", statement, re.I):
rows = result.fetchall()
return [dict(row) for row in rows]

# if INSERT, return primary key value for a newly inserted row
elif result.lastrowid is not None:
return result.lastrowid
elif re.search(r"^\s*INSERT\s+", statement, re.I):
if self.engine.url.get_backend_name() == "postgresql":
result = self.engine.execute(sqlalchemy.text("SELECT LASTVAL()"))
return result.first()[0]
else:
return result.lastrowid

# if DELETE or UPDATE (or INSERT without RETURNING), return number of rows matched
else:
# if DELETE or UPDATE, return number of rows matched
elif re.search(r"^\s*(?:DELETE|UPDATE)\s+", statement, re.I):
return result.rowcount

# if some other statement, return True unless exception
return True

# if constraint violated, return None
except sqlalchemy.exc.IntegrityError:
return None

# else raise error
# else raise exception
except Exception as e:
e.__context__ = None
raise RuntimeError(e)
_e = RuntimeError(e) # else Python 3 prints warnings' tracebacks
_e.__cause__ = None
raise _e
129 changes: 129 additions & 0 deletions test/sqltests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import unittest
from cs50.sql import SQL

class SQLTests(unittest.TestCase):
def test_delete_returns_affected_rows(self):
rows = [
{"id": 1, "val": "foo"},
{"id": 2, "val": "bar"},
{"id": 3, "val": "baz"}
]
for row in rows:
self.db.execute("INSERT INTO cs50(val) VALUES(:val);", val=row["val"])

print(self.db.execute("DELETE FROM cs50 WHERE id = :id", id=rows[0]["id"]))
print(self.db.execute("SELECT * FROM cs50"))
return

self.assertEqual(self.db.execute("DELETE FROM cs50 WHERE id = :id", id=rows[0]["id"]), 1)
self.assertEqual(self.db.execute("DELETE FROM cs50 WHERE id = :a or id = :b", a=rows[1]["id"], b=rows[2]["id"]), 2)
self.assertEqual(self.db.execute("DELETE FROM cs50 WHERE id = -50"), 0)

def test_insert_returns_last_row_id(self):
self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('foo')"), 1)
self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('bar')"), 2)

def test_select_all(self):
self.assertEqual(self.db.execute("SELECT * FROM cs50"), [])

rows = [
{"id": 1, "val": "foo"},
{"id": 2, "val": "bar"},
{"id": 3, "val": "baz"}
]
for row in rows:
self.db.execute("INSERT INTO cs50(val) VALUES(:val)", val=row["val"])

self.assertEqual(self.db.execute("SELECT * FROM cs50"), rows)

def test_select_cols(self):
rows = [
{"val": "foo"},
{"val": "bar"},
{"val": "baz"}
]
for row in rows:
self.db.execute("INSERT INTO cs50(val) VALUES(:val)", val=row["val"])

self.assertEqual(self.db.execute("SELECT val FROM cs50"), rows)

def test_select_where(self):
rows = [
{"id": 1, "val": "foo"},
{"id": 2, "val": "bar"},
{"id": 3, "val": "baz"}
]
for row in rows:
self.db.execute("INSERT INTO cs50(val) VALUES(:val)", val=row["val"])

self.assertEqual(self.db.execute("SELECT * FROM cs50 WHERE id = :id OR val = :val", id=rows[1]["id"], val=rows[2]["val"]), rows[1:3])

def test_update_returns_affected_rows(self):
rows = [
{"id": 1, "val": "foo"},
{"id": 2, "val": "bar"},
{"id": 3, "val": "baz"}
]
for row in rows:
self.db.execute("INSERT INTO cs50(val) VALUES(:val)", val=row["val"])

self.assertEqual(self.db.execute("UPDATE cs50 SET val = 'foo' WHERE id > 1"), 2)
self.assertEqual(self.db.execute("UPDATE cs50 SET val = 'foo' WHERE id = -50"), 0)

class MySQLTests(SQLTests):
@classmethod
def setUpClass(self):
self.db = SQL("mysql://root@localhost/cs50_sql_tests")

def setUp(self):
self.db.execute("CREATE TABLE cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), PRIMARY KEY (id))")

def tearDown(self):
self.db.execute("DROP TABLE cs50")

@classmethod
def tearDownClass(self):
self.db.execute("DROP TABLE IF EXISTS cs50")

class PostgresTests(SQLTests):
@classmethod
def setUpClass(self):
self.db = SQL("postgresql://postgres:postgres@localhost/cs50_sql_tests")

def setUp(self):
self.db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16))")

def tearDown(self):
self.db.execute("DROP TABLE cs50")

@classmethod
def tearDownClass(self):
self.db.execute("DROP TABLE IF EXISTS cs50")

def test_insert_returns_last_row_id(self):
self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('foo')"), 1)
self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('bar')"), 2)

class SQLiteTests(SQLTests):
@classmethod
def setUpClass(self):
self.db = SQL("sqlite:///cs50_sql_tests.db")

def setUp(self):
self.db.execute("CREATE TABLE cs50(id INTEGER PRIMARY KEY, val TEXT)")

def tearDown(self):
self.db.execute("DROP TABLE cs50")

@classmethod
def tearDownClass(self):
self.db.execute("DROP TABLE IF EXISTS cs50")

if __name__ == "__main__":
suite = unittest.TestSuite([
unittest.TestLoader().loadTestsFromTestCase(SQLiteTests),
unittest.TestLoader().loadTestsFromTestCase(MySQLTests),
unittest.TestLoader().loadTestsFromTestCase(PostgresTests)
])

unittest.TextTestRunner(verbosity=2).run(suite)