2
2
import decimal
3
3
import importlib
4
4
import logging
5
+ import os
5
6
import re
6
7
import sqlalchemy
7
8
import sqlparse
8
9
import sys
10
+ import termcolor
9
11
import warnings
10
12
11
13
@@ -22,12 +24,52 @@ def __init__(self, url, **kwargs):
22
24
http://docs.sqlalchemy.org/en/latest/dialects/index.html
23
25
"""
24
26
25
- # log statements to standard error
27
+ # Require that file already exist for SQLite
28
+ matches = re .search (r"^sqlite:///(.+)$" , url )
29
+ if matches :
30
+ if not os .path .exists (matches .group (1 )):
31
+ raise RuntimeError ("does not exist: {}" .format (matches .group (1 )))
32
+ if not os .path .isfile (matches .group (1 )):
33
+ raise RuntimeError ("not a file: {}" .format (matches .group (1 )))
34
+
35
+ # Create engine, raising exception if back end's module not installed
36
+ self .engine = sqlalchemy .create_engine (url , ** kwargs )
37
+
38
+ # Log statements to standard error
26
39
logging .basicConfig (level = logging .DEBUG )
27
40
self .logger = logging .getLogger ("cs50" )
28
41
29
- # create engine, raising exception if back end's module not installed
30
- self .engine = sqlalchemy .create_engine (url , ** kwargs )
42
+ # Test database
43
+ try :
44
+ self .logger .disabled = True
45
+ self .execute ("SELECT 1" )
46
+ except sqlalchemy .exc .OperationalError as e :
47
+ e = RuntimeError (self ._parse (e ))
48
+ e .__cause__ = None
49
+ raise e
50
+ else :
51
+ self .logger .disabled = False
52
+
53
+ def _parse (self , e ):
54
+ """Parses an exception, returns its message."""
55
+
56
+ # MySQL
57
+ matches = re .search (r"^\(_mysql_exceptions\.OperationalError\) \(\d+, \"(.+)\"\)$" , str (e ))
58
+ if matches :
59
+ return matches .group (1 )
60
+
61
+ # PostgreSQL
62
+ matches = re .search (r"^\(psycopg2\.OperationalError\) (.+)$" , str (e ))
63
+ if matches :
64
+ return matches .group (1 )
65
+
66
+ # SQLite
67
+ matches = re .search (r"^\(sqlite3\.OperationalError\) (.+)$" , str (e ))
68
+ if matches :
69
+ return matches .group (1 )
70
+
71
+ # Default
72
+ return str (e )
31
73
32
74
def execute (self , text , ** params ):
33
75
"""
@@ -81,77 +123,91 @@ def process(value):
81
123
elif isinstance (value , sqlalchemy .sql .elements .Null ):
82
124
return sqlalchemy .types .NullType ().literal_processor (dialect )(value )
83
125
84
- # unsupported value
126
+ # Unsupported value
85
127
raise RuntimeError ("unsupported value" )
86
128
87
- # process value(s), separating with commas as needed
129
+ # Process value(s), separating with commas as needed
88
130
if type (value ) is list :
89
131
return ", " .join ([process (v ) for v in value ])
90
132
else :
91
133
return process (value )
92
134
93
- # allow only one statement at a time
135
+ # Allow only one statement at a time
94
136
if len (sqlparse .split (text )) > 1 :
95
137
raise RuntimeError ("too many statements at once" )
96
138
97
- # raise exceptions for warnings
139
+ # Raise exceptions for warnings
98
140
warnings .filterwarnings ("error" )
99
141
100
- # prepare , execute statement
142
+ # Prepare , execute statement
101
143
try :
102
144
103
- # construct a new TextClause clause
145
+ # Construct a new TextClause clause
104
146
statement = sqlalchemy .text (text )
105
147
106
- # iterate over parameters
148
+ # Iterate over parameters
107
149
for key , value in params .items ():
108
150
109
- # translate None to NULL
151
+ # Translate None to NULL
110
152
if value is None :
111
153
value = sqlalchemy .sql .null ()
112
154
113
- # bind parameters before statement reaches database, so that bound parameters appear in exceptions
155
+ # Bind parameters before statement reaches database, so that bound parameters appear in exceptions
114
156
# http://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.text
115
157
statement = statement .bindparams (sqlalchemy .bindparam (
116
158
key , value = value , type_ = UserDefinedType ()))
117
159
118
- # stringify bound parameters
160
+ # Stringify bound parameters
119
161
# http://docs.sqlalchemy.org/en/latest/faq/sqlexpressions.html#how-do-i-render-sql-expressions-as-strings-possibly-with-bound-parameters-inlined
120
162
statement = str (statement .compile (compile_kwargs = {"literal_binds" : True }))
121
163
122
- # execute statement
123
- result = self . engine . execute (statement )
164
+ # Statement for logging
165
+ log = re . sub ( r"\n\s*" , " " , sqlparse . format (statement , reindent = True ) )
124
166
125
- # log statement
126
- self .logger . debug ( re . sub ( r"\n\s*" , " " , sqlparse . format ( statement , reindent = True )) )
167
+ # Execute statement
168
+ result = self .engine . execute ( statement )
127
169
128
- # if SELECT (or INSERT with RETURNING), return result set as list of dict objects
170
+ # If SELECT (or INSERT with RETURNING), return result set as list of dict objects
129
171
if re .search (r"^\s*SELECT" , statement , re .I ):
130
172
131
- # coerce any decimal.Decimal objects to float objects
173
+ # Coerce any decimal.Decimal objects to float objects
132
174
# https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ
133
175
rows = [dict (row ) for row in result .fetchall ()]
134
176
for row in rows :
135
177
for column in row :
136
178
if isinstance (row [column ], decimal .Decimal ):
137
179
row [column ] = float (row [column ])
138
- return rows
180
+ ret = rows
139
181
140
- # if INSERT, return primary key value for a newly inserted row
182
+ # If INSERT, return primary key value for a newly inserted row
141
183
elif re .search (r"^\s*INSERT" , statement , re .I ):
142
184
if self .engine .url .get_backend_name () in ["postgres" , "postgresql" ]:
143
185
result = self .engine .execute (sqlalchemy .text ("SELECT LASTVAL()" ))
144
- return result .first ()[0 ]
186
+ ret = result .first ()[0 ]
145
187
else :
146
- return result .lastrowid
188
+ ret = result .lastrowid
147
189
148
- # if DELETE or UPDATE, return number of rows matched
190
+ # If DELETE or UPDATE, return number of rows matched
149
191
elif re .search (r"^\s*(?:DELETE|UPDATE)" , statement , re .I ):
150
- return result .rowcount
192
+ ret = result .rowcount
151
193
152
- # if some other statement, return True unless exception
153
- return True
194
+ # If some other statement, return True unless exception
195
+ else :
196
+ ret = True
154
197
155
- # if constraint violated, return None
198
+ # If constraint violated, return None
156
199
except sqlalchemy .exc .IntegrityError :
200
+ self .logger .debug (termcolor .colored (log , "yellow" ))
157
201
return None
202
+
203
+ # If user errror
204
+ except sqlalchemy .exc .OperationalError as e :
205
+ self .logger .debug (termcolor .colored (log , "red" ))
206
+ e = RuntimeError (self ._parse (e ))
207
+ e .__cause__ = None
208
+ raise e
209
+
210
+ # Return value
211
+ else :
212
+ self .logger .debug (termcolor .colored (log , "green" ))
213
+ return ret
0 commit comments