Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adds support for expandable parameters #19

Merged
merged 4 commits into from
May 21, 2017
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 80 additions & 7 deletions cs50/sql.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import datetime
import sqlalchemy
import sys

class SQL(object):
"""Wrap SQLAlchemy to provide a simple SQL API."""
@@ -14,21 +16,91 @@ def __init__(self, url):
try:
self.engine = sqlalchemy.create_engine(url)
except Exception as e:
e.__context__ = None
raise RuntimeError(e)

def execute(self, text, *multiparams, **params):
def execute(self, text, **params):
"""
Execute a SQL statement.
"""

class UserDefinedType(sqlalchemy.TypeDecorator):
"""
Add support for expandable values, a la https://bitbucket.org/zzzeek/sqlalchemy/issues/3953/expanding-parameter.
"""
impl = sqlalchemy.types.UserDefinedType
def process_literal_param(self, value, dialect):
"""Receive a literal parameter value to be rendered inline within a statement."""
def process(value):
"""Render a literal value, escaping as needed."""

# bool
if isinstance(value, bool):
return sqlalchemy.types.Boolean().literal_processor(dialect)(value)

# datetime.date
elif isinstance(value, datetime.date):
return sqlalchemy.types.String().literal_processor(dialect)(value.strftime("%Y-%m-%d"))

# datetime.datetime
elif isinstance(value, datetime.datetime):
return sqlalchemy.types.String().literal_processor(dialect)(value.strftime("%Y-%m-%d %H:%M:%S"))

# datetime.time
elif isinstance(value, datetime.time):
return sqlalchemy.types.String().literal_processor(dialect)(value.strftime("%H:%M:%S"))

# float
elif isinstance(value, float):
return sqlalchemy.types.Float().literal_processor(dialect)(value)

# int
elif isinstance(value, int):
return sqlalchemy.types.Integer().literal_processor(dialect)(value)

# long
elif sys.version_info.major != 3 and isinstance(value, long):
return sqlalchemy.types.Integer().literal_processor(dialect)(value)

# str
elif isinstance(value, str):
return sqlalchemy.types.String().literal_processor(dialect)(value)

# None
elif isinstance(value, sqlalchemy.sql.elements.Null):
return sqlalchemy.types.NullType().literal_processor(dialect)(value)

# unsupported value
raise RuntimeError("unsupported value")

# process value(s), separating with commas as needed
if type(value) is list:
return ", ".join([process(v) for v in value])
else:
return process(value)

try:

# bind parameters before statement reaches database, so that bound parameters appear in exceptions
# http://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.text
# https://groups.google.com/forum/#!topic/sqlalchemy/FfLwKT1yQlg
# http://docs.sqlalchemy.org/en/latest/core/connections.html#sqlalchemy.engine.Engine.execute
# construct a new TextClause clause
statement = sqlalchemy.text(text)

# iterate over parameters
for key, value in params.items():

# translate None to NULL
if value is None:
value = sqlalchemy.sql.null()

# bind parameters before statement reaches database, so that bound parameters appear in exceptions
# http://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.text
statement = statement.bindparams(sqlalchemy.bindparam(key, value=value, type_=UserDefinedType()))

# stringify bound parameters
# http://docs.sqlalchemy.org/en/latest/faq/sqlexpressions.html#how-do-i-render-sql-expressions-as-strings-possibly-with-bound-parameters-inlined
statement = sqlalchemy.text(text).bindparams(*multiparams, **params)
result = self.engine.execute(str(statement.compile(compile_kwargs={"literal_binds": True})))
self.statement = str(statement.compile(compile_kwargs={"literal_binds": True}))

# execute statement
result = self.engine.execute(self.statement)

# if SELECT (or INSERT with RETURNING), return result set as list of dict objects
if result.returns_rows:
@@ -49,4 +121,5 @@ def execute(self, text, *multiparams, **params):

# else raise error
except Exception as e:
e.__context__ = None
raise RuntimeError(e)