2
2
import decimal
3
3
import importlib
4
4
import logging
5
+ import os
5
6
import re
6
7
import sqlalchemy
7
8
import sqlparse
8
9
import sys
10
+ import termcolor
9
11
import warnings
10
12
11
13
@@ -22,12 +24,52 @@ def __init__(self, url, **kwargs):
22
24
http://docs.sqlalchemy.org/en/latest/dialects/index.html
23
25
"""
24
26
27
+ # Require that file already exist for SQLite
28
+ matches = re .search (r"^sqlite:///(.+)$" , url )
29
+ if matches :
30
+ if not os .path .exists (matches .group (1 )):
31
+ raise RuntimeError ("does not exist: {}" .format (matches .group (1 )))
32
+ if not os .path .isfile (matches .group (1 )):
33
+ raise RuntimeError ("not a file: {}" .format (matches .group (1 )))
34
+
35
+ # Create engine, raising exception if back end's module not installed
36
+ self .engine = sqlalchemy .create_engine (url , ** kwargs )
37
+
25
38
# Log statements to standard error
26
39
logging .basicConfig (level = logging .DEBUG )
27
40
self .logger = logging .getLogger ("cs50" )
28
41
29
- # Create engine, raising exception if back end's module not installed
30
- self .engine = sqlalchemy .create_engine (url , ** kwargs )
42
+ # Test database
43
+ try :
44
+ self .logger .disabled = True
45
+ self .execute ("SELECT 1" )
46
+ except sqlalchemy .exc .OperationalError as e :
47
+ e = RuntimeError (self ._parse (e ))
48
+ e .__cause__ = None
49
+ raise e
50
+ else :
51
+ self .logger .disabled = False
52
+
53
+ def _parse (self , e ):
54
+ """Parses an exception, returns its message."""
55
+
56
+ # MySQL
57
+ matches = re .search (r"^\(_mysql_exceptions\.OperationalError\) \(\d+, \"(.+)\"\)$" , str (e ))
58
+ if matches :
59
+ return matches .group (1 )
60
+
61
+ # PostgreSQL
62
+ matches = re .search (r"^\((psycopg2\.OperationalError)\) (.+)$" , str (e ))
63
+ if matches :
64
+ return matches .group (1 )
65
+
66
+ # SQLite
67
+ matches = re .search (r"^\(sqlite3\.OperationalError\) (.+)$" , str (e ))
68
+ if matches :
69
+ return matches .group (1 )
70
+
71
+ # Default
72
+ return str (e )
31
73
32
74
def execute (self , text , ** params ):
33
75
"""
@@ -119,12 +161,12 @@ def process(value):
119
161
# http://docs.sqlalchemy.org/en/latest/faq/sqlexpressions.html#how-do-i-render-sql-expressions-as-strings-possibly-with-bound-parameters-inlined
120
162
statement = str (statement .compile (compile_kwargs = {"literal_binds" : True }))
121
163
164
+ # Statement for logging
165
+ log = re .sub (r"\n\s*" , " " , sqlparse .format (statement , reindent = True ))
166
+
122
167
# Execute statement
123
168
result = self .engine .execute (statement )
124
169
125
- # Log statement
126
- self .logger .debug (re .sub (r"\n\s*" , " " , sqlparse .format (statement , reindent = True )))
127
-
128
170
# If SELECT (or INSERT with RETURNING), return result set as list of dict objects
129
171
if re .search (r"^\s*SELECT" , statement , re .I ):
130
172
@@ -135,23 +177,36 @@ def process(value):
135
177
for column in row :
136
178
if isinstance (row [column ], decimal .Decimal ):
137
179
row [column ] = float (row [column ])
138
- return rows
180
+ ret = rows
139
181
140
182
# If INSERT, return primary key value for a newly inserted row
141
183
elif re .search (r"^\s*INSERT" , statement , re .I ):
142
184
if self .engine .url .get_backend_name () in ["postgres" , "postgresql" ]:
143
185
result = self .engine .execute (sqlalchemy .text ("SELECT LASTVAL()" ))
144
- return result .first ()[0 ]
186
+ ret = result .first ()[0 ]
145
187
else :
146
- return result .lastrowid
188
+ ret = result .lastrowid
147
189
148
190
# If DELETE or UPDATE, return number of rows matched
149
191
elif re .search (r"^\s*(?:DELETE|UPDATE)" , statement , re .I ):
150
- return result .rowcount
192
+ ret = result .rowcount
151
193
152
194
# If some other statement, return True unless exception
153
- return True
195
+ ret = True
154
196
155
197
# If constraint violated, return None
156
198
except sqlalchemy .exc .IntegrityError :
199
+ self .logger .debug (termcolor .colored (log , "yellow" ))
157
200
return None
201
+
202
+ # If user errror
203
+ except sqlalchemy .exc .OperationalError as e :
204
+ self .logger .debug (termcolor .colored (log , "red" ))
205
+ e = RuntimeError (self ._parse (e ))
206
+ e .__cause__ = None
207
+ raise e
208
+
209
+ # Return value
210
+ else :
211
+ self .logger .debug (termcolor .colored (log , "green" ))
212
+ return ret
0 commit comments