@@ -59,34 +59,13 @@ def connect(dbapi_connection, connection_record):
59
59
# Log statements to standard error
60
60
logging .basicConfig (level = logging .DEBUG )
61
61
62
- def parse (self , e ):
63
- """Parses an exception, returns its message."""
64
-
65
- # MySQL
66
- matches = re .search (r"^\(_mysql_exceptions\.OperationalError\) \(\d+, \"(.+)\"\)$" , str (e ))
67
- if matches :
68
- return matches .group (1 )
69
-
70
- # PostgreSQL
71
- matches = re .search (r"^\(psycopg2\.OperationalError\) (.+)$" , str (e ))
72
- if matches :
73
- return matches .group (1 )
74
-
75
- # SQLite
76
- matches = re .search (r"^\(sqlite3\.OperationalError\) (.+)$" , str (e ))
77
- if matches :
78
- return matches .group (1 )
79
-
80
- # Default
81
- return str (e )
82
-
83
62
# Test database
84
63
try :
85
64
disabled = self ._logger .disabled
86
65
self ._logger .disabled = True
87
66
self .execute ("SELECT 1" )
88
67
except sqlalchemy .exc .OperationalError as e :
89
- e = RuntimeError (parse (e ))
68
+ e = RuntimeError (self . _parse_exception (e ))
90
69
e .__cause__ = None
91
70
raise e
92
71
else :
@@ -126,19 +105,8 @@ def execute(self, sql, *args, **kwargs):
126
105
# If token is a placeholder
127
106
if token .ttype == sqlparse .tokens .Name .Placeholder :
128
107
129
- # Determine paramstyle
130
- if token .value == "?" :
131
- _paramstyle = "qmark"
132
- elif re .search (r"^:[1-9]\d*$" , token .value ):
133
- _paramstyle = "numeric"
134
- elif re .search (r"^:[a-zA-Z]\w*$" , token .value ):
135
- _paramstyle = "named"
136
- elif re .search (r"^TODO$" , token .value ): # TODO
137
- _paramstyle = "named"
138
- elif re .search (r"%\([a-zA-Z]\w*\)s$" , token .value ): # TODO
139
- _paramstyle = "pyformat"
140
- else :
141
- raise RuntimeError ("{}: invalid placeholder" .format (token .value ))
108
+ # Determine paramstyle, name
109
+ _paramstyle , name = self ._parse_placeholder (token )
142
110
143
111
# Ensure paramstyle is consistent
144
112
if paramstyle is not None and _paramstyle != paramstyle :
@@ -148,10 +116,15 @@ def execute(self, sql, *args, **kwargs):
148
116
if paramstyle is None :
149
117
paramstyle = _paramstyle
150
118
151
- # Remember placeholder
152
- placeholders [index ] = token . value
119
+ # Remember placeholder's index, name
120
+ placeholders [index ] = name
153
121
154
122
def escape (value ):
123
+ """
124
+ Escapes value using engine's conversion function.
125
+
126
+ https://docs.sqlalchemy.org/en/latest/core/type_api.html#sqlalchemy.types.TypeEngine.literal_processor
127
+ """
155
128
156
129
# bool
157
130
if type (value ) is bool :
@@ -221,18 +194,39 @@ def escape(value):
221
194
elif paramstyle == "numeric" :
222
195
223
196
# Escape values
224
- for index , value in placeholders .items ():
225
- i = int (re . sub ( r"^:" , "" , value ) ) - 1
226
- if i >= len (args ):
197
+ for index , name in placeholders .items ():
198
+ i = int (name ) - 1
199
+ if i < 0 or i >= len (args ):
227
200
raise RuntimeError ("placeholder out of range" )
228
201
tokens [index ] = escape (args [i ])
229
202
230
203
# named
231
204
elif paramstyle == "named" :
232
205
233
206
# Escape values
234
- for index , value in placeholders .items ():
235
- name = re .sub (r"^:" , "" , value )
207
+ for index , name in placeholders .items ():
208
+ if name not in kwargs :
209
+ raise RuntimeError ("missing value for placeholder" )
210
+ tokens [index ] = escape (kwargs [name ])
211
+
212
+ # format
213
+ elif paramstyle == "format" :
214
+
215
+ # Validate number of placeholders
216
+ if len (placeholders ) < len (args ):
217
+ raise RuntimeError ("too few placeholders" )
218
+ elif len (placeholders ) > len (args ):
219
+ raise RuntimeError ("too many placeholders" )
220
+
221
+ # Escape values
222
+ for i , index in enumerate (placeholders .keys ()):
223
+ tokens [index ] = escape (args [i ])
224
+
225
+ # pyformat
226
+ elif paramstyle == "pyformat" :
227
+
228
+ # Escape values
229
+ for index , name in placeholders .items ():
236
230
if name not in kwargs :
237
231
raise RuntimeError ("missing value for placeholder" )
238
232
tokens [index ] = escape (kwargs [name ])
@@ -285,11 +279,65 @@ def escape(value):
285
279
# If user errror
286
280
except sqlalchemy .exc .OperationalError as e :
287
281
self ._logger .debug (termcolor .colored (statement , "red" ))
288
- e = RuntimeError (self ._parse (e ))
282
+ e = RuntimeError (self ._parse_exception (e ))
289
283
e .__cause__ = None
290
284
raise e
291
285
292
286
# Return value
293
287
else :
294
288
self ._logger .debug (termcolor .colored (statement , "green" ))
295
289
return ret
290
+
291
+ def _parse_exception (self , e ):
292
+ """Parses an exception, returns its message."""
293
+
294
+ # MySQL
295
+ matches = re .search (r"^\(_mysql_exceptions\.OperationalError\) \(\d+, \"(.+)\"\)$" , str (e ))
296
+ if matches :
297
+ return matches .group (1 )
298
+
299
+ # PostgreSQL
300
+ matches = re .search (r"^\(psycopg2\.OperationalError\) (.+)$" , str (e ))
301
+ if matches :
302
+ return matches .group (1 )
303
+
304
+ # SQLite
305
+ matches = re .search (r"^\(sqlite3\.OperationalError\) (.+)$" , str (e ))
306
+ if matches :
307
+ return matches .group (1 )
308
+
309
+ # Default
310
+ return str (e )
311
+
312
+ def _parse_placeholder (self , token ):
313
+ """Infers paramstyle, name from sqlparse.tokens.Name.Placeholder."""
314
+
315
+ # Validate token
316
+ if not isinstance (token , sqlparse .sql .Token ) or token .ttype != sqlparse .tokens .Name .Placeholder :
317
+ raise TypeError ()
318
+
319
+ # qmark
320
+ if token .value == "?" :
321
+ return "qmark" , None
322
+
323
+ # numeric
324
+ matches = re .search (r"^:(\d+)$" , token .value )
325
+ if matches :
326
+ return "numeric" , matches .group (1 )
327
+
328
+ # named
329
+ matches = re .search (r"^:([a-zA-Z]\w*)$" , token .value )
330
+ if matches :
331
+ return "named" , matches .group (1 )
332
+
333
+ # format
334
+ if token .value == "%s" :
335
+ return "format" , None
336
+
337
+ # pyformat
338
+ matches = re .search (r"%\((\w+)\)s$" , token .value )
339
+ if matches :
340
+ return "pyformat" , matches .group (1 )
341
+
342
+ # Invalid
343
+ raise RuntimeError ("{}: invalid placeholder" .format (token .value ))
0 commit comments