Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adds support for expandable parameters #19

Merged
merged 4 commits into from
May 21, 2017
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 77 additions & 43 deletions cs50/sql.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import re
import datetime
import sqlalchemy
import sqlparse
import sys

class SQL(object):
"""Wrap SQLAlchemy to provide a simple SQL API."""
@@ -16,58 +16,91 @@ def __init__(self, url):
try:
self.engine = sqlalchemy.create_engine(url)
except Exception as e:
e.__context__ = None
raise RuntimeError(e)

def execute(self, text, *multiparams, **params):
def execute(self, text, **params):
"""
Execute a SQL statement.
"""

# parse text
parsed = sqlparse.parse(text)
if len(parsed) == 0:
raise RuntimeError("missing statement")
elif len(parsed) > 1:
raise RuntimeError("too many statements")
statement = parsed[0]
if statement.get_type() == "UNKNOWN":
raise RuntimeError("unknown type of statement")

# infer paramstyle
# https://www.python.org/dev/peps/pep-0249/#paramstyle
paramstyle = None
for token in statement.flatten():
if sqlparse.utils.imt(token.ttype, t=sqlparse.tokens.Token.Name.Placeholder):
_paramstyle = None
if re.search(r"^\?$", token.value):
_paramstyle = "qmark"
elif re.search(r"^:\d+$", token.value):
_paramstyle = "numeric"
elif re.search(r"^:\w+$", token.value):
_paramstyle = "named"
elif re.search(r"^%s$", token.value):
_paramstyle = "format"
elif re.search(r"^%\(\w+\)s$", token.value):
_paramstyle = "pyformat"
else:
raise RuntimeError("unknown paramstyle")
if paramstyle and paramstyle != _paramstyle:
raise RuntimeError("inconsistent paramstyle")
paramstyle = _paramstyle
class UserDefinedType(sqlalchemy.TypeDecorator):
"""
Add support for expandable values, a la https://bitbucket.org/zzzeek/sqlalchemy/issues/3953/expanding-parameter.
"""
impl = sqlalchemy.types.UserDefinedType
def process_literal_param(self, value, dialect):
"""Receive a literal parameter value to be rendered inline within a statement."""
def process(value):
"""Render a literal value, escaping as needed."""

# bool
if isinstance(value, bool):
return sqlalchemy.types.Boolean().literal_processor(dialect)(value)

# datetime.date
elif isinstance(value, datetime.date):
return sqlalchemy.types.String().literal_processor(dialect)(value.strftime("%Y-%m-%d"))

# datetime.datetime
elif isinstance(value, datetime.datetime):
return sqlalchemy.types.String().literal_processor(dialect)(value.strftime("%Y-%m-%d %H:%M:%S"))

# datetime.time
elif isinstance(value, datetime.time):
return sqlalchemy.types.String().literal_processor(dialect)(value.strftime("%H:%M:%S"))

# float
elif isinstance(value, float):
return sqlalchemy.types.Float().literal_processor(dialect)(value)

# int
elif isinstance(value, int):
return sqlalchemy.types.Integer().literal_processor(dialect)(value)

# long
elif sys.version_info.major != 3 and isinstance(value, long):
return sqlalchemy.types.Integer().literal_processor(dialect)(value)

# str
elif isinstance(value, str):
return sqlalchemy.types.String().literal_processor(dialect)(value)

# None
elif isinstance(value, sqlalchemy.sql.elements.Null):
return sqlalchemy.types.NullType().literal_processor(dialect)(value)

# unsupported value
raise RuntimeError("unsupported value")

# process value(s), separating with commas as needed
if type(value) is list:
return ", ".join([process(v) for v in value])
else:
return process(value)

try:

parsed = sqlparse.split("SELECT * FROM cs50 WHERE id IN (SELECT id FROM cs50); SELECT 1; CREATE TABLE foo")
print(parsed)
return 0
# construct a new TextClause clause
statement = sqlalchemy.text(text)

# iterate over parameters
for key, value in params.items():

# bind parameters before statement reaches database, so that bound parameters appear in exceptions
# http://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.text
# https://groups.google.com/forum/#!topic/sqlalchemy/FfLwKT1yQlg
# http://docs.sqlalchemy.org/en/latest/core/connections.html#sqlalchemy.engine.Engine.execute
# translate None to NULL
if value is None:
value = sqlalchemy.sql.null()

# bind parameters before statement reaches database, so that bound parameters appear in exceptions
# http://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.text
statement = statement.bindparams(sqlalchemy.bindparam(key, value=value, type_=UserDefinedType()))

# stringify bound parameters
# http://docs.sqlalchemy.org/en/latest/faq/sqlexpressions.html#how-do-i-render-sql-expressions-as-strings-possibly-with-bound-parameters-inlined
statement = sqlalchemy.text(text).bindparams(*multiparams, **params)
result = self.engine.execute(str(statement.compile(compile_kwargs={"literal_binds": True})))
self.statement = str(statement.compile(compile_kwargs={"literal_binds": True}))

# execute statement
result = self.engine.execute(self.statement)

# if SELECT (or INSERT with RETURNING), return result set as list of dict objects
if result.returns_rows:
@@ -88,4 +121,5 @@ def execute(self, text, *multiparams, **params):

# else raise error
except Exception as e:
e.__context__ = None
raise RuntimeError(e)