Skip to content

Commit d626967

Browse files
authoredMay 21, 2017
Merge pull request #19 from cs50/expandable-parameters
adds support for expandable parameters
2 parents 7b69487 + 64c1a04 commit d626967

File tree

1 file changed

+80
-7
lines changed

1 file changed

+80
-7
lines changed
 

‎cs50/sql.py

+80-7
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import datetime
12
import sqlalchemy
3+
import sys
24

35
class SQL(object):
46
"""Wrap SQLAlchemy to provide a simple SQL API."""
@@ -14,21 +16,91 @@ def __init__(self, url):
1416
try:
1517
self.engine = sqlalchemy.create_engine(url)
1618
except Exception as e:
19+
e.__context__ = None
1720
raise RuntimeError(e)
1821

19-
def execute(self, text, *multiparams, **params):
22+
def execute(self, text, **params):
2023
"""
2124
Execute a SQL statement.
2225
"""
26+
27+
class UserDefinedType(sqlalchemy.TypeDecorator):
28+
"""
29+
Add support for expandable values, a la https://bitbucket.org/zzzeek/sqlalchemy/issues/3953/expanding-parameter.
30+
"""
31+
impl = sqlalchemy.types.UserDefinedType
32+
def process_literal_param(self, value, dialect):
33+
"""Receive a literal parameter value to be rendered inline within a statement."""
34+
def process(value):
35+
"""Render a literal value, escaping as needed."""
36+
37+
# bool
38+
if isinstance(value, bool):
39+
return sqlalchemy.types.Boolean().literal_processor(dialect)(value)
40+
41+
# datetime.date
42+
elif isinstance(value, datetime.date):
43+
return sqlalchemy.types.String().literal_processor(dialect)(value.strftime("%Y-%m-%d"))
44+
45+
# datetime.datetime
46+
elif isinstance(value, datetime.datetime):
47+
return sqlalchemy.types.String().literal_processor(dialect)(value.strftime("%Y-%m-%d %H:%M:%S"))
48+
49+
# datetime.time
50+
elif isinstance(value, datetime.time):
51+
return sqlalchemy.types.String().literal_processor(dialect)(value.strftime("%H:%M:%S"))
52+
53+
# float
54+
elif isinstance(value, float):
55+
return sqlalchemy.types.Float().literal_processor(dialect)(value)
56+
57+
# int
58+
elif isinstance(value, int):
59+
return sqlalchemy.types.Integer().literal_processor(dialect)(value)
60+
61+
# long
62+
elif sys.version_info.major != 3 and isinstance(value, long):
63+
return sqlalchemy.types.Integer().literal_processor(dialect)(value)
64+
65+
# str
66+
elif isinstance(value, str):
67+
return sqlalchemy.types.String().literal_processor(dialect)(value)
68+
69+
# None
70+
elif isinstance(value, sqlalchemy.sql.elements.Null):
71+
return sqlalchemy.types.NullType().literal_processor(dialect)(value)
72+
73+
# unsupported value
74+
raise RuntimeError("unsupported value")
75+
76+
# process value(s), separating with commas as needed
77+
if type(value) is list:
78+
return ", ".join([process(v) for v in value])
79+
else:
80+
return process(value)
81+
2382
try:
2483

25-
# bind parameters before statement reaches database, so that bound parameters appear in exceptions
26-
# http://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.text
27-
# https://groups.google.com/forum/#!topic/sqlalchemy/FfLwKT1yQlg
28-
# http://docs.sqlalchemy.org/en/latest/core/connections.html#sqlalchemy.engine.Engine.execute
84+
# construct a new TextClause clause
85+
statement = sqlalchemy.text(text)
86+
87+
# iterate over parameters
88+
for key, value in params.items():
89+
90+
# translate None to NULL
91+
if value is None:
92+
value = sqlalchemy.sql.null()
93+
94+
# bind parameters before statement reaches database, so that bound parameters appear in exceptions
95+
# http://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.text
96+
statement = statement.bindparams(sqlalchemy.bindparam(key, value=value, type_=UserDefinedType()))
97+
98+
# stringify bound parameters
2999
# http://docs.sqlalchemy.org/en/latest/faq/sqlexpressions.html#how-do-i-render-sql-expressions-as-strings-possibly-with-bound-parameters-inlined
30-
statement = sqlalchemy.text(text).bindparams(*multiparams, **params)
31-
result = self.engine.execute(str(statement.compile(compile_kwargs={"literal_binds": True})))
100+
self.statement = str(statement.compile(compile_kwargs={"literal_binds": True}))
101+
102+
# execute statement
103+
result = self.engine.execute(self.statement)
32104

33105
# if SELECT (or INSERT with RETURNING), return result set as list of dict objects
34106
if result.returns_rows:
@@ -49,4 +121,5 @@ def execute(self, text, *multiparams, **params):
49121

50122
# else raise error
51123
except Exception as e:
124+
e.__context__ = None
52125
raise RuntimeError(e)

0 commit comments

Comments
 (0)
Please sign in to comment.