1
+ import datetime
1
2
import sqlalchemy
3
+ import sys
2
4
3
5
class SQL (object ):
4
6
"""Wrap SQLAlchemy to provide a simple SQL API."""
@@ -14,21 +16,91 @@ def __init__(self, url):
14
16
try :
15
17
self .engine = sqlalchemy .create_engine (url )
16
18
except Exception as e :
19
+ e .__context__ = None
17
20
raise RuntimeError (e )
18
21
19
- def execute (self , text , * multiparams , * *params ):
22
+ def execute (self , text , ** params ):
20
23
"""
21
24
Execute a SQL statement.
22
25
"""
26
+
27
+ class UserDefinedType (sqlalchemy .TypeDecorator ):
28
+ """
29
+ Add support for expandable values, a la https://bitbucket.org/zzzeek/sqlalchemy/issues/3953/expanding-parameter.
30
+ """
31
+ impl = sqlalchemy .types .UserDefinedType
32
+ def process_literal_param (self , value , dialect ):
33
+ """Receive a literal parameter value to be rendered inline within a statement."""
34
+ def process (value ):
35
+ """Render a literal value, escaping as needed."""
36
+
37
+ # bool
38
+ if isinstance (value , bool ):
39
+ return sqlalchemy .types .Boolean ().literal_processor (dialect )(value )
40
+
41
+ # datetime.date
42
+ elif isinstance (value , datetime .date ):
43
+ return sqlalchemy .types .String ().literal_processor (dialect )(value .strftime ("%Y-%m-%d" ))
44
+
45
+ # datetime.datetime
46
+ elif isinstance (value , datetime .datetime ):
47
+ return sqlalchemy .types .String ().literal_processor (dialect )(value .strftime ("%Y-%m-%d %H:%M:%S" ))
48
+
49
+ # datetime.time
50
+ elif isinstance (value , datetime .time ):
51
+ return sqlalchemy .types .String ().literal_processor (dialect )(value .strftime ("%H:%M:%S" ))
52
+
53
+ # float
54
+ elif isinstance (value , float ):
55
+ return sqlalchemy .types .Float ().literal_processor (dialect )(value )
56
+
57
+ # int
58
+ elif isinstance (value , int ):
59
+ return sqlalchemy .types .Integer ().literal_processor (dialect )(value )
60
+
61
+ # long
62
+ elif sys .version_info .major != 3 and isinstance (value , long ):
63
+ return sqlalchemy .types .Integer ().literal_processor (dialect )(value )
64
+
65
+ # str
66
+ elif isinstance (value , str ):
67
+ return sqlalchemy .types .String ().literal_processor (dialect )(value )
68
+
69
+ # None
70
+ elif isinstance (value , sqlalchemy .sql .elements .Null ):
71
+ return sqlalchemy .types .NullType ().literal_processor (dialect )(value )
72
+
73
+ # unsupported value
74
+ raise RuntimeError ("unsupported value" )
75
+
76
+ # process value(s), separating with commas as needed
77
+ if type (value ) is list :
78
+ return ", " .join ([process (v ) for v in value ])
79
+ else :
80
+ return process (value )
81
+
23
82
try :
24
83
25
- # bind parameters before statement reaches database, so that bound parameters appear in exceptions
26
- # http://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.text
27
- # https://groups.google.com/forum/#!topic/sqlalchemy/FfLwKT1yQlg
28
- # http://docs.sqlalchemy.org/en/latest/core/connections.html#sqlalchemy.engine.Engine.execute
84
+ # construct a new TextClause clause
85
+ statement = sqlalchemy .text (text )
86
+
87
+ # iterate over parameters
88
+ for key , value in params .items ():
89
+
90
+ # translate None to NULL
91
+ if value is None :
92
+ value = sqlalchemy .sql .null ()
93
+
94
+ # bind parameters before statement reaches database, so that bound parameters appear in exceptions
95
+ # http://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.text
96
+ statement = statement .bindparams (sqlalchemy .bindparam (key , value = value , type_ = UserDefinedType ()))
97
+
98
+ # stringify bound parameters
29
99
# http://docs.sqlalchemy.org/en/latest/faq/sqlexpressions.html#how-do-i-render-sql-expressions-as-strings-possibly-with-bound-parameters-inlined
30
- statement = sqlalchemy .text (text ).bindparams (* multiparams , ** params )
31
- result = self .engine .execute (str (statement .compile (compile_kwargs = {"literal_binds" : True })))
100
+ self .statement = str (statement .compile (compile_kwargs = {"literal_binds" : True }))
101
+
102
+ # execute statement
103
+ result = self .engine .execute (self .statement )
32
104
33
105
# if SELECT (or INSERT with RETURNING), return result set as list of dict objects
34
106
if result .returns_rows :
@@ -49,4 +121,5 @@ def execute(self, text, *multiparams, **params):
49
121
50
122
# else raise error
51
123
except Exception as e :
124
+ e .__context__ = None
52
125
raise RuntimeError (e )
0 commit comments