@@ -14,7 +14,6 @@ def _enable_logging(f):
14
14
15
15
@functools .wraps (f )
16
16
def decorator (* args , ** kwargs ):
17
-
18
17
# Infer whether Flask is installed
19
18
try :
20
19
import flask
@@ -71,17 +70,20 @@ def __init__(self, url, **kwargs):
71
70
# Create engine, disabling SQLAlchemy's own autocommit mode raising exception if back end's module not installed;
72
71
# without isolation_level, PostgreSQL warns with "there is already a transaction in progress" for our own BEGIN and
73
72
# "there is no transaction in progress" for our own COMMIT
74
- self ._engine = sqlalchemy .create_engine (url , ** kwargs ).execution_options (autocommit = False , isolation_level = "AUTOCOMMIT" )
73
+ self ._engine = sqlalchemy .create_engine (url , ** kwargs ).execution_options (
74
+ autocommit = False , isolation_level = "AUTOCOMMIT"
75
+ )
75
76
76
77
# Get logger
77
78
self ._logger = logging .getLogger ("cs50" )
78
79
79
80
# Listener for connections
80
81
def connect (dbapi_connection , connection_record ):
81
-
82
82
# Enable foreign key constraints
83
83
try :
84
- if isinstance (dbapi_connection , sqlite3 .Connection ): # If back end is sqlite
84
+ if isinstance (
85
+ dbapi_connection , sqlite3 .Connection
86
+ ): # If back end is sqlite
85
87
cursor = dbapi_connection .cursor ()
86
88
cursor .execute ("PRAGMA foreign_keys=ON" )
87
89
cursor .close ()
@@ -150,14 +152,33 @@ def execute(self, sql, *args, **kwargs):
150
152
raise RuntimeError ("cannot pass both positional and named parameters" )
151
153
152
154
# Infer command from flattened statement to a single string separated by spaces
153
- full_statement = ' ' .join (str (token ) for token in statements [0 ].tokens if token .ttype in [sqlparse .tokens .Keyword , sqlparse .tokens .Keyword .DDL , sqlparse .tokens .Keyword .DML ])
155
+ full_statement = " " .join (
156
+ str (token )
157
+ for token in statements [0 ].tokens
158
+ if token .ttype
159
+ in [
160
+ sqlparse .tokens .Keyword ,
161
+ sqlparse .tokens .Keyword .DDL ,
162
+ sqlparse .tokens .Keyword .DML ,
163
+ ]
164
+ )
154
165
full_statement = full_statement .upper ()
155
166
156
167
# Set of possible commands
157
- commands = {"BEGIN" , "CREATE VIEW" , "DELETE" , "INSERT" , "SELECT" , "START" , "UPDATE" }
168
+ commands = {
169
+ "BEGIN" ,
170
+ "CREATE VIEW" ,
171
+ "DELETE" ,
172
+ "INSERT" ,
173
+ "SELECT" ,
174
+ "START" ,
175
+ "UPDATE" ,
176
+ }
158
177
159
178
# Check if the full_statement starts with any command
160
- command = next ((cmd for cmd in commands if full_statement .startswith (cmd )), None )
179
+ command = next (
180
+ (cmd for cmd in commands if full_statement .startswith (cmd )), None
181
+ )
161
182
162
183
# Flatten statement
163
184
tokens = list (statements [0 ].flatten ())
@@ -166,10 +187,8 @@ def execute(self, sql, *args, **kwargs):
166
187
placeholders = {}
167
188
paramstyle = None
168
189
for index , token in enumerate (tokens ):
169
-
170
190
# If token is a placeholder
171
191
if token .ttype == sqlparse .tokens .Name .Placeholder :
172
-
173
192
# Determine paramstyle, name
174
193
_paramstyle , name = _parse_placeholder (token )
175
194
@@ -186,7 +205,6 @@ def execute(self, sql, *args, **kwargs):
186
205
187
206
# If no placeholders
188
207
if not paramstyle :
189
-
190
208
# Error-check like qmark if args
191
209
if args :
192
210
paramstyle = "qmark"
@@ -201,41 +219,55 @@ def execute(self, sql, *args, **kwargs):
201
219
202
220
# qmark
203
221
if paramstyle == "qmark" :
204
-
205
222
# Validate number of placeholders
206
223
if len (placeholders ) != len (args ):
207
224
if len (placeholders ) < len (args ):
208
- raise RuntimeError ("fewer placeholders ({}) than values ({})" .format (_placeholders , _args ))
225
+ raise RuntimeError (
226
+ "fewer placeholders ({}) than values ({})" .format (
227
+ _placeholders , _args
228
+ )
229
+ )
209
230
else :
210
- raise RuntimeError ("more placeholders ({}) than values ({})" .format (_placeholders , _args ))
231
+ raise RuntimeError (
232
+ "more placeholders ({}) than values ({})" .format (
233
+ _placeholders , _args
234
+ )
235
+ )
211
236
212
237
# Escape values
213
238
for i , index in enumerate (placeholders .keys ()):
214
239
tokens [index ] = self ._escape (args [i ])
215
240
216
241
# numeric
217
242
elif paramstyle == "numeric" :
218
-
219
243
# Escape values
220
244
for index , i in placeholders .items ():
221
245
if i >= len (args ):
222
- raise RuntimeError ("missing value for placeholder (:{})" .format (i + 1 , len (args )))
246
+ raise RuntimeError (
247
+ "missing value for placeholder (:{})" .format (i + 1 , len (args ))
248
+ )
223
249
tokens [index ] = self ._escape (args [i ])
224
250
225
251
# Check if any values unused
226
252
indices = set (range (len (args ))) - set (placeholders .values ())
227
253
if indices :
228
- raise RuntimeError ("unused {} ({})" .format (
229
- "value" if len (indices ) == 1 else "values" ,
230
- ", " .join ([str (self ._escape (args [index ])) for index in indices ])))
254
+ raise RuntimeError (
255
+ "unused {} ({})" .format (
256
+ "value" if len (indices ) == 1 else "values" ,
257
+ ", " .join (
258
+ [str (self ._escape (args [index ])) for index in indices ]
259
+ ),
260
+ )
261
+ )
231
262
232
263
# named
233
264
elif paramstyle == "named" :
234
-
235
265
# Escape values
236
266
for index , name in placeholders .items ():
237
267
if name not in kwargs :
238
- raise RuntimeError ("missing value for placeholder (:{})" .format (name ))
268
+ raise RuntimeError (
269
+ "missing value for placeholder (:{})" .format (name )
270
+ )
239
271
tokens [index ] = self ._escape (kwargs [name ])
240
272
241
273
# Check if any keys unused
@@ -245,54 +277,65 @@ def execute(self, sql, *args, **kwargs):
245
277
246
278
# format
247
279
elif paramstyle == "format" :
248
-
249
280
# Validate number of placeholders
250
281
if len (placeholders ) != len (args ):
251
282
if len (placeholders ) < len (args ):
252
- raise RuntimeError ("fewer placeholders ({}) than values ({})" .format (_placeholders , _args ))
283
+ raise RuntimeError (
284
+ "fewer placeholders ({}) than values ({})" .format (
285
+ _placeholders , _args
286
+ )
287
+ )
253
288
else :
254
- raise RuntimeError ("more placeholders ({}) than values ({})" .format (_placeholders , _args ))
289
+ raise RuntimeError (
290
+ "more placeholders ({}) than values ({})" .format (
291
+ _placeholders , _args
292
+ )
293
+ )
255
294
256
295
# Escape values
257
296
for i , index in enumerate (placeholders .keys ()):
258
297
tokens [index ] = self ._escape (args [i ])
259
298
260
299
# pyformat
261
300
elif paramstyle == "pyformat" :
262
-
263
301
# Escape values
264
302
for index , name in placeholders .items ():
265
303
if name not in kwargs :
266
- raise RuntimeError ("missing value for placeholder (%{}s)" .format (name ))
304
+ raise RuntimeError (
305
+ "missing value for placeholder (%{}s)" .format (name )
306
+ )
267
307
tokens [index ] = self ._escape (kwargs [name ])
268
308
269
309
# Check if any keys unused
270
310
keys = kwargs .keys () - placeholders .values ()
271
311
if keys :
272
- raise RuntimeError ("unused {} ({})" .format (
273
- "value" if len (keys ) == 1 else "values" ,
274
- ", " .join (keys )))
312
+ raise RuntimeError (
313
+ "unused {} ({})" .format (
314
+ "value" if len (keys ) == 1 else "values" , ", " .join (keys )
315
+ )
316
+ )
275
317
276
318
# For SQL statements where a colon is required verbatim, as within an inline string, use a backslash to escape
277
319
# https://docs.sqlalchemy.org/en/13/core/sqlelement.html?highlight=text#sqlalchemy.sql.expression.text
278
320
for index , token in enumerate (tokens ):
279
-
280
321
# In string literal
281
322
# https://www.sqlite.org/lang_keywords.html
282
- if token .ttype in [sqlparse .tokens .Literal .String , sqlparse .tokens .Literal .String .Single ]:
323
+ if token .ttype in [
324
+ sqlparse .tokens .Literal .String ,
325
+ sqlparse .tokens .Literal .String .Single ,
326
+ ]:
283
327
token .value = re .sub ("(^'|\s+):" , r"\1\:" , token .value )
284
328
285
329
# In identifier
286
330
# https://www.sqlite.org/lang_keywords.html
287
331
elif token .ttype == sqlparse .tokens .Literal .String .Symbol :
288
- token .value = re .sub ("(^ \ " |\s+):" , r"\1\:" , token .value )
332
+ token .value = re .sub ('(^ "|\s+):' , r"\1\:" , token .value )
289
333
290
334
# Join tokens into statement
291
335
statement = "" .join ([str (token ) for token in tokens ])
292
336
293
337
# If no connection yet
294
338
if not hasattr (_data , self ._name ()):
295
-
296
339
# Connect to database
297
340
setattr (_data , self ._name (), self ._engine .connect ())
298
341
@@ -302,25 +345,33 @@ def execute(self, sql, *args, **kwargs):
302
345
# Disconnect if/when a Flask app is torn down
303
346
try :
304
347
import flask
348
+
305
349
assert flask .current_app
350
+
306
351
def teardown_appcontext (exception ):
307
352
self ._disconnect ()
353
+
308
354
if teardown_appcontext not in flask .current_app .teardown_appcontext_funcs :
309
355
flask .current_app .teardown_appcontext (teardown_appcontext )
310
356
except (ModuleNotFoundError , AssertionError ):
311
357
pass
312
358
313
359
# Catch SQLAlchemy warnings
314
360
with warnings .catch_warnings ():
315
-
316
361
# Raise exceptions for warnings
317
362
warnings .simplefilter ("error" )
318
363
319
364
# Prepare, execute statement
320
365
try :
321
-
322
366
# Join tokens into statement, abbreviating binary data as <class 'bytes'>
323
- _statement = "" .join ([str (bytes ) if token .ttype == sqlparse .tokens .Other else str (token ) for token in tokens ])
367
+ _statement = "" .join (
368
+ [
369
+ str (bytes )
370
+ if token .ttype == sqlparse .tokens .Other
371
+ else str (token )
372
+ for token in tokens
373
+ ]
374
+ )
324
375
325
376
# Check for start of transaction
326
377
if command in ["BEGIN" , "START" ]:
@@ -342,12 +393,10 @@ def teardown_appcontext(exception):
342
393
343
394
# If SELECT, return result set as list of dict objects
344
395
if command == "SELECT" :
345
-
346
396
# Coerce types
347
397
rows = [dict (row ) for row in result .mappings ().all ()]
348
398
for row in rows :
349
399
for column in row :
350
-
351
400
# Coerce decimal.Decimal objects to float objects
352
401
# https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ
353
402
if isinstance (row [column ], decimal .Decimal ):
@@ -362,15 +411,15 @@ def teardown_appcontext(exception):
362
411
363
412
# If INSERT, return primary key value for a newly inserted row (or None if none)
364
413
elif command == "INSERT" :
365
-
366
414
# If PostgreSQL
367
415
if self ._engine .url .get_backend_name () == "postgresql" :
368
-
369
416
# Return LASTVAL() or NULL, avoiding
370
417
# "(psycopg2.errors.ObjectNotInPrerequisiteState) lastval is not yet defined in this session",
371
418
# a la https://stackoverflow.com/a/24186770/5156190;
372
419
# cf. https://www.psycopg.org/docs/errors.html re 55000
373
- result = connection .execute (sqlalchemy .text ("""
420
+ result = connection .execute (
421
+ sqlalchemy .text (
422
+ """
374
423
CREATE OR REPLACE FUNCTION _LASTVAL()
375
424
RETURNS integer LANGUAGE plpgsql
376
425
AS $$
@@ -382,7 +431,9 @@ def teardown_appcontext(exception):
382
431
END;
383
432
END $$;
384
433
SELECT _LASTVAL();
385
- """ ))
434
+ """
435
+ )
436
+ )
386
437
ret = result .first ()[0 ]
387
438
388
439
# If not PostgreSQL
@@ -405,7 +456,10 @@ def teardown_appcontext(exception):
405
456
raise e
406
457
407
458
# If user error
408
- except (sqlalchemy .exc .OperationalError , sqlalchemy .exc .ProgrammingError ) as e :
459
+ except (
460
+ sqlalchemy .exc .OperationalError ,
461
+ sqlalchemy .exc .ProgrammingError ,
462
+ ) as e :
409
463
self ._disconnect ()
410
464
self ._logger .error (termcolor .colored (_statement , "red" ))
411
465
e = RuntimeError (e .orig )
@@ -430,7 +484,6 @@ def _escape(self, value):
430
484
import sqlparse
431
485
432
486
def __escape (value ):
433
-
434
487
# Lazily import
435
488
import datetime
436
489
import sqlalchemy
@@ -439,66 +492,91 @@ def __escape(value):
439
492
if isinstance (value , bool ):
440
493
return sqlparse .sql .Token (
441
494
sqlparse .tokens .Number ,
442
- sqlalchemy .types .Boolean ().literal_processor (self ._engine .dialect )(value ))
495
+ sqlalchemy .types .Boolean ().literal_processor (self ._engine .dialect )(
496
+ value
497
+ ),
498
+ )
443
499
444
500
# bytes
445
501
elif isinstance (value , bytes ):
446
502
if self ._engine .url .get_backend_name () in ["mysql" , "sqlite" ]:
447
- return sqlparse .sql .Token (sqlparse .tokens .Other , f"x'{ value .hex ()} '" ) # https://dev.mysql.com/doc/refman/8.0/en/hexadecimal-literals.html
503
+ return sqlparse .sql .Token (
504
+ sqlparse .tokens .Other , f"x'{ value .hex ()} '"
505
+ ) # https://dev.mysql.com/doc/refman/8.0/en/hexadecimal-literals.html
448
506
elif self ._engine .url .get_backend_name () == "postgresql" :
449
- return sqlparse .sql .Token (sqlparse .tokens .Other , f"'\\ x{ value .hex ()} '" ) # https://dba.stackexchange.com/a/203359
507
+ return sqlparse .sql .Token (
508
+ sqlparse .tokens .Other , f"'\\ x{ value .hex ()} '"
509
+ ) # https://dba.stackexchange.com/a/203359
450
510
else :
451
511
raise RuntimeError ("unsupported value: {}" .format (value ))
452
512
453
513
# datetime.datetime
454
514
elif isinstance (value , datetime .datetime ):
455
515
return sqlparse .sql .Token (
456
516
sqlparse .tokens .String ,
457
- sqlalchemy .types .String ().literal_processor (self ._engine .dialect )(value .strftime ("%Y-%m-%d %H:%M:%S" )))
517
+ sqlalchemy .types .String ().literal_processor (self ._engine .dialect )(
518
+ value .strftime ("%Y-%m-%d %H:%M:%S" )
519
+ ),
520
+ )
458
521
459
522
# datetime.date
460
523
elif isinstance (value , datetime .date ):
461
524
return sqlparse .sql .Token (
462
525
sqlparse .tokens .String ,
463
- sqlalchemy .types .String ().literal_processor (self ._engine .dialect )(value .strftime ("%Y-%m-%d" )))
526
+ sqlalchemy .types .String ().literal_processor (self ._engine .dialect )(
527
+ value .strftime ("%Y-%m-%d" )
528
+ ),
529
+ )
464
530
465
531
# datetime.time
466
532
elif isinstance (value , datetime .time ):
467
533
return sqlparse .sql .Token (
468
534
sqlparse .tokens .String ,
469
- sqlalchemy .types .String ().literal_processor (self ._engine .dialect )(value .strftime ("%H:%M:%S" )))
535
+ sqlalchemy .types .String ().literal_processor (self ._engine .dialect )(
536
+ value .strftime ("%H:%M:%S" )
537
+ ),
538
+ )
470
539
471
540
# float
472
541
elif isinstance (value , float ):
473
542
return sqlparse .sql .Token (
474
543
sqlparse .tokens .Number ,
475
- sqlalchemy .types .Float ().literal_processor (self ._engine .dialect )(value ))
544
+ sqlalchemy .types .Float ().literal_processor (self ._engine .dialect )(
545
+ value
546
+ ),
547
+ )
476
548
477
549
# int
478
550
elif isinstance (value , int ):
479
551
return sqlparse .sql .Token (
480
552
sqlparse .tokens .Number ,
481
- sqlalchemy .types .Integer ().literal_processor (self ._engine .dialect )(value ))
553
+ sqlalchemy .types .Integer ().literal_processor (self ._engine .dialect )(
554
+ value
555
+ ),
556
+ )
482
557
483
558
# str
484
559
elif isinstance (value , str ):
485
560
return sqlparse .sql .Token (
486
561
sqlparse .tokens .String ,
487
- sqlalchemy .types .String ().literal_processor (self ._engine .dialect )(value ))
562
+ sqlalchemy .types .String ().literal_processor (self ._engine .dialect )(
563
+ value
564
+ ),
565
+ )
488
566
489
567
# None
490
568
elif value is None :
491
- return sqlparse .sql .Token (
492
- sqlparse .tokens .Keyword ,
493
- sqlalchemy .null ())
569
+ return sqlparse .sql .Token (sqlparse .tokens .Keyword , sqlalchemy .null ())
494
570
495
571
# Unsupported value
496
572
else :
497
573
raise RuntimeError ("unsupported value: {}" .format (value ))
498
574
499
575
# Escape value(s), separating with commas as needed
500
576
if isinstance (value , (list , tuple )):
501
- return sqlparse .sql .TokenList (sqlparse .parse (", " .join ([str (__escape (v )) for v in value ])))
577
+ return sqlparse .sql .TokenList (
578
+ sqlparse .parse (", " .join ([str (__escape (v )) for v in value ]))
579
+ )
502
580
else :
503
581
return __escape (value )
504
582
@@ -510,7 +588,9 @@ def _parse_exception(e):
510
588
import re
511
589
512
590
# MySQL
513
- matches = re .search (r"^\(_mysql_exceptions\.OperationalError\) \(\d+, \"(.+)\"\)$" , str (e ))
591
+ matches = re .search (
592
+ r"^\(_mysql_exceptions\.OperationalError\) \(\d+, \"(.+)\"\)$" , str (e )
593
+ )
514
594
if matches :
515
595
return matches .group (1 )
516
596
@@ -536,7 +616,10 @@ def _parse_placeholder(token):
536
616
import sqlparse
537
617
538
618
# Validate token
539
- if not isinstance (token , sqlparse .sql .Token ) or token .ttype != sqlparse .tokens .Name .Placeholder :
619
+ if (
620
+ not isinstance (token , sqlparse .sql .Token )
621
+ or token .ttype != sqlparse .tokens .Name .Placeholder
622
+ ):
540
623
raise TypeError ()
541
624
542
625
# qmark
0 commit comments