@@ -14,7 +14,6 @@ def _enable_logging(f):
14
14
15
15
@functools .wraps (f )
16
16
def decorator (* args , ** kwargs ):
17
-
18
17
# Infer whether Flask is installed
19
18
try :
20
19
import flask
@@ -71,17 +70,20 @@ def __init__(self, url, **kwargs):
71
70
# Create engine, disabling SQLAlchemy's own autocommit mode raising exception if back end's module not installed;
72
71
# without isolation_level, PostgreSQL warns with "there is already a transaction in progress" for our own BEGIN and
73
72
# "there is no transaction in progress" for our own COMMIT
74
- self ._engine = sqlalchemy .create_engine (url , ** kwargs ).execution_options (autocommit = False , isolation_level = "AUTOCOMMIT" )
73
+ self ._engine = sqlalchemy .create_engine (url , ** kwargs ).execution_options (
74
+ autocommit = False , isolation_level = "AUTOCOMMIT"
75
+ )
75
76
76
77
# Get logger
77
78
self ._logger = logging .getLogger ("cs50" )
78
79
79
80
# Listener for connections
80
81
def connect (dbapi_connection , connection_record ):
81
-
82
82
# Enable foreign key constraints
83
83
try :
84
- if isinstance (dbapi_connection , sqlite3 .Connection ): # If back end is sqlite
84
+ if isinstance (
85
+ dbapi_connection , sqlite3 .Connection
86
+ ): # If back end is sqlite
85
87
cursor = dbapi_connection .cursor ()
86
88
cursor .execute ("PRAGMA foreign_keys=ON" )
87
89
cursor .close ()
@@ -150,14 +152,33 @@ def execute(self, sql, *args, **kwargs):
150
152
raise RuntimeError ("cannot pass both positional and named parameters" )
151
153
152
154
# Infer command from flattened statement to a single string separated by spaces
153
- full_statement = ' ' .join (str (token ) for token in statements [0 ].tokens if token .ttype in [sqlparse .tokens .Keyword , sqlparse .tokens .Keyword .DDL , sqlparse .tokens .Keyword .DML ])
155
+ full_statement = " " .join (
156
+ str (token )
157
+ for token in statements [0 ].tokens
158
+ if token .ttype
159
+ in [
160
+ sqlparse .tokens .Keyword ,
161
+ sqlparse .tokens .Keyword .DDL ,
162
+ sqlparse .tokens .Keyword .DML ,
163
+ ]
164
+ )
154
165
full_statement = full_statement .upper ()
155
166
156
167
# Set of possible commands
157
- commands = {"BEGIN" , "CREATE VIEW" , "DELETE" , "INSERT" , "SELECT" , "START" , "UPDATE" }
168
+ commands = {
169
+ "BEGIN" ,
170
+ "CREATE VIEW" ,
171
+ "DELETE" ,
172
+ "INSERT" ,
173
+ "SELECT" ,
174
+ "START" ,
175
+ "UPDATE" ,
176
+ }
158
177
159
178
# Check if the full_statement starts with any command
160
- command = next ((cmd for cmd in commands if full_statement .startswith (cmd )), None )
179
+ command = next (
180
+ (cmd for cmd in commands if full_statement .startswith (cmd )), None
181
+ )
161
182
162
183
# Flatten statement
163
184
tokens = list (statements [0 ].flatten ())
@@ -166,10 +187,8 @@ def execute(self, sql, *args, **kwargs):
166
187
placeholders = {}
167
188
paramstyle = None
168
189
for index , token in enumerate (tokens ):
169
-
170
190
# If token is a placeholder
171
191
if token .ttype == sqlparse .tokens .Name .Placeholder :
172
-
173
192
# Determine paramstyle, name
174
193
_paramstyle , name = _parse_placeholder (token )
175
194
@@ -186,7 +205,6 @@ def execute(self, sql, *args, **kwargs):
186
205
187
206
# If no placeholders
188
207
if not paramstyle :
189
-
190
208
# Error-check like qmark if args
191
209
if args :
192
210
paramstyle = "qmark"
@@ -201,41 +219,55 @@ def execute(self, sql, *args, **kwargs):
201
219
202
220
# qmark
203
221
if paramstyle == "qmark" :
204
-
205
222
# Validate number of placeholders
206
223
if len (placeholders ) != len (args ):
207
224
if len (placeholders ) < len (args ):
208
- raise RuntimeError ("fewer placeholders ({}) than values ({})" .format (_placeholders , _args ))
225
+ raise RuntimeError (
226
+ "fewer placeholders ({}) than values ({})" .format (
227
+ _placeholders , _args
228
+ )
229
+ )
209
230
else :
210
- raise RuntimeError ("more placeholders ({}) than values ({})" .format (_placeholders , _args ))
231
+ raise RuntimeError (
232
+ "more placeholders ({}) than values ({})" .format (
233
+ _placeholders , _args
234
+ )
235
+ )
211
236
212
237
# Escape values
213
238
for i , index in enumerate (placeholders .keys ()):
214
239
tokens [index ] = self ._escape (args [i ])
215
240
216
241
# numeric
217
242
elif paramstyle == "numeric" :
218
-
219
243
# Escape values
220
244
for index , i in placeholders .items ():
221
245
if i >= len (args ):
222
- raise RuntimeError ("missing value for placeholder (:{})" .format (i + 1 , len (args )))
246
+ raise RuntimeError (
247
+ "missing value for placeholder (:{})" .format (i + 1 , len (args ))
248
+ )
223
249
tokens [index ] = self ._escape (args [i ])
224
250
225
251
# Check if any values unused
226
252
indices = set (range (len (args ))) - set (placeholders .values ())
227
253
if indices :
228
- raise RuntimeError ("unused {} ({})" .format (
229
- "value" if len (indices ) == 1 else "values" ,
230
- ", " .join ([str (self ._escape (args [index ])) for index in indices ])))
254
+ raise RuntimeError (
255
+ "unused {} ({})" .format (
256
+ "value" if len (indices ) == 1 else "values" ,
257
+ ", " .join (
258
+ [str (self ._escape (args [index ])) for index in indices ]
259
+ ),
260
+ )
261
+ )
231
262
232
263
# named
233
264
elif paramstyle == "named" :
234
-
235
265
# Escape values
236
266
for index , name in placeholders .items ():
237
267
if name not in kwargs :
238
- raise RuntimeError ("missing value for placeholder (:{})" .format (name ))
268
+ raise RuntimeError (
269
+ "missing value for placeholder (:{})" .format (name )
270
+ )
239
271
tokens [index ] = self ._escape (kwargs [name ])
240
272
241
273
# Check if any keys unused
@@ -245,54 +277,65 @@ def execute(self, sql, *args, **kwargs):
245
277
246
278
# format
247
279
elif paramstyle == "format" :
248
-
249
280
# Validate number of placeholders
250
281
if len (placeholders ) != len (args ):
251
282
if len (placeholders ) < len (args ):
252
- raise RuntimeError ("fewer placeholders ({}) than values ({})" .format (_placeholders , _args ))
283
+ raise RuntimeError (
284
+ "fewer placeholders ({}) than values ({})" .format (
285
+ _placeholders , _args
286
+ )
287
+ )
253
288
else :
254
- raise RuntimeError ("more placeholders ({}) than values ({})" .format (_placeholders , _args ))
289
+ raise RuntimeError (
290
+ "more placeholders ({}) than values ({})" .format (
291
+ _placeholders , _args
292
+ )
293
+ )
255
294
256
295
# Escape values
257
296
for i , index in enumerate (placeholders .keys ()):
258
297
tokens [index ] = self ._escape (args [i ])
259
298
260
299
# pyformat
261
300
elif paramstyle == "pyformat" :
262
-
263
301
# Escape values
264
302
for index , name in placeholders .items ():
265
303
if name not in kwargs :
266
- raise RuntimeError ("missing value for placeholder (%{}s)" .format (name ))
304
+ raise RuntimeError (
305
+ "missing value for placeholder (%{}s)" .format (name )
306
+ )
267
307
tokens [index ] = self ._escape (kwargs [name ])
268
308
269
309
# Check if any keys unused
270
310
keys = kwargs .keys () - placeholders .values ()
271
311
if keys :
272
- raise RuntimeError ("unused {} ({})" .format (
273
- "value" if len (keys ) == 1 else "values" ,
274
- ", " .join (keys )))
312
+ raise RuntimeError (
313
+ "unused {} ({})" .format (
314
+ "value" if len (keys ) == 1 else "values" , ", " .join (keys )
315
+ )
316
+ )
275
317
276
318
# For SQL statements where a colon is required verbatim, as within an inline string, use a backslash to escape
277
319
# https://docs.sqlalchemy.org/en/13/core/sqlelement.html?highlight=text#sqlalchemy.sql.expression.text
278
320
for index , token in enumerate (tokens ):
279
-
280
321
# In string literal
281
322
# https://www.sqlite.org/lang_keywords.html
282
- if token .ttype in [sqlparse .tokens .Literal .String , sqlparse .tokens .Literal .String .Single ]:
323
+ if token .ttype in [
324
+ sqlparse .tokens .Literal .String ,
325
+ sqlparse .tokens .Literal .String .Single ,
326
+ ]:
283
327
token .value = re .sub ("(^'|\s+):" , r"\1\:" , token .value )
284
328
285
329
# In identifier
286
330
# https://www.sqlite.org/lang_keywords.html
287
331
elif token .ttype == sqlparse .tokens .Literal .String .Symbol :
288
- token .value = re .sub ("(^ \ " |\s+):" , r"\1\:" , token .value )
332
+ token .value = re .sub ('(^ "|\s+):' , r"\1\:" , token .value )
289
333
290
334
# Join tokens into statement
291
335
statement = "" .join ([str (token ) for token in tokens ])
292
336
293
337
# If no connection yet
294
338
if not hasattr (_data , self ._name ()):
295
-
296
339
# Connect to database
297
340
setattr (_data , self ._name (), self ._engine .connect ())
298
341
@@ -302,25 +345,33 @@ def execute(self, sql, *args, **kwargs):
302
345
# Disconnect if/when a Flask app is torn down
303
346
try :
304
347
import flask
348
+
305
349
assert flask .current_app
350
+
306
351
def teardown_appcontext (exception ):
307
352
self ._disconnect ()
353
+
308
354
if teardown_appcontext not in flask .current_app .teardown_appcontext_funcs :
309
355
flask .current_app .teardown_appcontext (teardown_appcontext )
310
356
except (ModuleNotFoundError , AssertionError ):
311
357
pass
312
358
313
359
# Catch SQLAlchemy warnings
314
360
with warnings .catch_warnings ():
315
-
316
361
# Raise exceptions for warnings
317
362
warnings .simplefilter ("error" )
318
363
319
364
# Prepare, execute statement
320
365
try :
321
-
322
366
# Join tokens into statement, abbreviating binary data as <class 'bytes'>
323
- _statement = "" .join ([str (bytes ) if token .ttype == sqlparse .tokens .Other else str (token ) for token in tokens ])
367
+ _statement = "" .join (
368
+ [
369
+ str (bytes )
370
+ if token .ttype == sqlparse .tokens .Other
371
+ else str (token )
372
+ for token in tokens
373
+ ]
374
+ )
324
375
325
376
# Check for start of transaction
326
377
if command in ["BEGIN" , "START" ]:
@@ -342,12 +393,10 @@ def teardown_appcontext(exception):
342
393
343
394
# If SELECT, return result set as list of dict objects
344
395
if command == "SELECT" :
345
-
346
396
# Coerce types
347
397
rows = [dict (row ) for row in result .mappings ().all ()]
348
398
for row in rows :
349
399
for column in row :
350
-
351
400
# Coerce decimal.Decimal objects to float objects
352
401
# https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ
353
402
if isinstance (row [column ], decimal .Decimal ):
@@ -362,15 +411,15 @@ def teardown_appcontext(exception):
362
411
363
412
# If INSERT, return primary key value for a newly inserted row (or None if none)
364
413
elif command == "INSERT" :
365
-
366
414
# If PostgreSQL
367
415
if self ._engine .url .get_backend_name () == "postgresql" :
368
-
369
416
# Return LASTVAL() or NULL, avoiding
370
417
# "(psycopg2.errors.ObjectNotInPrerequisiteState) lastval is not yet defined in this session",
371
418
# a la https://stackoverflow.com/a/24186770/5156190;
372
419
# cf. https://www.psycopg.org/docs/errors.html re 55000
373
- result = connection .execute (sqlalchemy .text ("""
420
+ result = connection .execute (
421
+ sqlalchemy .text (
422
+ """
374
423
CREATE OR REPLACE FUNCTION _LASTVAL()
375
424
RETURNS integer LANGUAGE plpgsql
376
425
AS $$
@@ -382,7 +431,9 @@ def teardown_appcontext(exception):
382
431
END;
383
432
END $$;
384
433
SELECT _LASTVAL();
385
- """ ))
434
+ """
435
+ )
436
+ )
386
437
ret = result .first ()[0 ]
387
438
388
439
# If not PostgreSQL
@@ -407,7 +458,10 @@ def teardown_appcontext(exception):
407
458
raise e
408
459
409
460
# If user error
410
- except (sqlalchemy .exc .OperationalError , sqlalchemy .exc .ProgrammingError ) as e :
461
+ except (
462
+ sqlalchemy .exc .OperationalError ,
463
+ sqlalchemy .exc .ProgrammingError ,
464
+ ) as e :
411
465
self ._disconnect ()
412
466
self ._logger .error (termcolor .colored (_statement , "red" ))
413
467
e = RuntimeError (e .orig )
@@ -432,7 +486,6 @@ def _escape(self, value):
432
486
import sqlparse
433
487
434
488
def __escape (value ):
435
-
436
489
# Lazily import
437
490
import datetime
438
491
import sqlalchemy
@@ -441,66 +494,91 @@ def __escape(value):
441
494
if isinstance (value , bool ):
442
495
return sqlparse .sql .Token (
443
496
sqlparse .tokens .Number ,
444
- sqlalchemy .types .Boolean ().literal_processor (self ._engine .dialect )(value ))
497
+ sqlalchemy .types .Boolean ().literal_processor (self ._engine .dialect )(
498
+ value
499
+ ),
500
+ )
445
501
446
502
# bytes
447
503
elif isinstance (value , bytes ):
448
504
if self ._engine .url .get_backend_name () in ["mysql" , "sqlite" ]:
449
- return sqlparse .sql .Token (sqlparse .tokens .Other , f"x'{ value .hex ()} '" ) # https://dev.mysql.com/doc/refman/8.0/en/hexadecimal-literals.html
505
+ return sqlparse .sql .Token (
506
+ sqlparse .tokens .Other , f"x'{ value .hex ()} '"
507
+ ) # https://dev.mysql.com/doc/refman/8.0/en/hexadecimal-literals.html
450
508
elif self ._engine .url .get_backend_name () == "postgresql" :
451
- return sqlparse .sql .Token (sqlparse .tokens .Other , f"'\\ x{ value .hex ()} '" ) # https://dba.stackexchange.com/a/203359
509
+ return sqlparse .sql .Token (
510
+ sqlparse .tokens .Other , f"'\\ x{ value .hex ()} '"
511
+ ) # https://dba.stackexchange.com/a/203359
452
512
else :
453
513
raise RuntimeError ("unsupported value: {}" .format (value ))
454
514
455
515
# datetime.datetime
456
516
elif isinstance (value , datetime .datetime ):
457
517
return sqlparse .sql .Token (
458
518
sqlparse .tokens .String ,
459
- sqlalchemy .types .String ().literal_processor (self ._engine .dialect )(value .strftime ("%Y-%m-%d %H:%M:%S" )))
519
+ sqlalchemy .types .String ().literal_processor (self ._engine .dialect )(
520
+ value .strftime ("%Y-%m-%d %H:%M:%S" )
521
+ ),
522
+ )
460
523
461
524
# datetime.date
462
525
elif isinstance (value , datetime .date ):
463
526
return sqlparse .sql .Token (
464
527
sqlparse .tokens .String ,
465
- sqlalchemy .types .String ().literal_processor (self ._engine .dialect )(value .strftime ("%Y-%m-%d" )))
528
+ sqlalchemy .types .String ().literal_processor (self ._engine .dialect )(
529
+ value .strftime ("%Y-%m-%d" )
530
+ ),
531
+ )
466
532
467
533
# datetime.time
468
534
elif isinstance (value , datetime .time ):
469
535
return sqlparse .sql .Token (
470
536
sqlparse .tokens .String ,
471
- sqlalchemy .types .String ().literal_processor (self ._engine .dialect )(value .strftime ("%H:%M:%S" )))
537
+ sqlalchemy .types .String ().literal_processor (self ._engine .dialect )(
538
+ value .strftime ("%H:%M:%S" )
539
+ ),
540
+ )
472
541
473
542
# float
474
543
elif isinstance (value , float ):
475
544
return sqlparse .sql .Token (
476
545
sqlparse .tokens .Number ,
477
- sqlalchemy .types .Float ().literal_processor (self ._engine .dialect )(value ))
546
+ sqlalchemy .types .Float ().literal_processor (self ._engine .dialect )(
547
+ value
548
+ ),
549
+ )
478
550
479
551
# int
480
552
elif isinstance (value , int ):
481
553
return sqlparse .sql .Token (
482
554
sqlparse .tokens .Number ,
483
- sqlalchemy .types .Integer ().literal_processor (self ._engine .dialect )(value ))
555
+ sqlalchemy .types .Integer ().literal_processor (self ._engine .dialect )(
556
+ value
557
+ ),
558
+ )
484
559
485
560
# str
486
561
elif isinstance (value , str ):
487
562
return sqlparse .sql .Token (
488
563
sqlparse .tokens .String ,
489
- sqlalchemy .types .String ().literal_processor (self ._engine .dialect )(value ))
564
+ sqlalchemy .types .String ().literal_processor (self ._engine .dialect )(
565
+ value
566
+ ),
567
+ )
490
568
491
569
# None
492
570
elif value is None :
493
- return sqlparse .sql .Token (
494
- sqlparse .tokens .Keyword ,
495
- sqlalchemy .null ())
571
+ return sqlparse .sql .Token (sqlparse .tokens .Keyword , sqlalchemy .null ())
496
572
497
573
# Unsupported value
498
574
else :
499
575
raise RuntimeError ("unsupported value: {}" .format (value ))
500
576
501
577
# Escape value(s), separating with commas as needed
502
578
if isinstance (value , (list , tuple )):
503
- return sqlparse .sql .TokenList (sqlparse .parse (", " .join ([str (__escape (v )) for v in value ])))
579
+ return sqlparse .sql .TokenList (
580
+ sqlparse .parse (", " .join ([str (__escape (v )) for v in value ]))
581
+ )
504
582
else :
505
583
return __escape (value )
506
584
@@ -512,7 +590,9 @@ def _parse_exception(e):
512
590
import re
513
591
514
592
# MySQL
515
- matches = re .search (r"^\(_mysql_exceptions\.OperationalError\) \(\d+, \"(.+)\"\)$" , str (e ))
593
+ matches = re .search (
594
+ r"^\(_mysql_exceptions\.OperationalError\) \(\d+, \"(.+)\"\)$" , str (e )
595
+ )
516
596
if matches :
517
597
return matches .group (1 )
518
598
@@ -538,7 +618,10 @@ def _parse_placeholder(token):
538
618
import sqlparse
539
619
540
620
# Validate token
541
- if not isinstance (token , sqlparse .sql .Token ) or token .ttype != sqlparse .tokens .Name .Placeholder :
621
+ if (
622
+ not isinstance (token , sqlparse .sql .Token )
623
+ or token .ttype != sqlparse .tokens .Name .Placeholder
624
+ ):
542
625
raise TypeError ()
543
626
544
627
# qmark
0 commit comments