Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 1252ff7

Browse files
committedJan 7, 2018
improved error messages, added color-coding
1 parent 5ecefbb commit 1252ff7

File tree

1 file changed

+65
-10
lines changed

1 file changed

+65
-10
lines changed
 

‎src/cs50/sql.py

+65-10
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
import decimal
33
import importlib
44
import logging
5+
import os
56
import re
67
import sqlalchemy
78
import sqlparse
89
import sys
10+
import termcolor
911
import warnings
1012

1113

@@ -22,12 +24,52 @@ def __init__(self, url, **kwargs):
2224
http://docs.sqlalchemy.org/en/latest/dialects/index.html
2325
"""
2426

27+
# Require that file already exist for SQLite
28+
matches = re.search(r"^sqlite:///(.+)$", url)
29+
if matches:
30+
if not os.path.exists(matches.group(1)):
31+
raise RuntimeError("does not exist: {}".format(matches.group(1)))
32+
if not os.path.isfile(matches.group(1)):
33+
raise RuntimeError("not a file: {}".format(matches.group(1)))
34+
35+
# Create engine, raising exception if back end's module not installed
36+
self.engine = sqlalchemy.create_engine(url, **kwargs)
37+
2538
# Log statements to standard error
2639
logging.basicConfig(level=logging.DEBUG)
2740
self.logger = logging.getLogger("cs50")
2841

29-
# Create engine, raising exception if back end's module not installed
30-
self.engine = sqlalchemy.create_engine(url, **kwargs)
42+
# Test database
43+
try:
44+
self.logger.disabled = True
45+
self.execute("SELECT 1")
46+
except sqlalchemy.exc.OperationalError as e:
47+
e = RuntimeError(self._parse(e))
48+
e.__cause__ = None
49+
raise e
50+
else:
51+
self.logger.disabled = False
52+
53+
def _parse(self, e):
54+
"""Parses an exception, returns its message."""
55+
56+
# MySQL
57+
matches = re.search(r"^\(_mysql_exceptions\.OperationalError\) \(\d+, \"(.+)\"\)$", str(e))
58+
if matches:
59+
return matches.group(1)
60+
61+
# PostgreSQL
62+
matches = re.search(r"^\((psycopg2\.OperationalError)\) (.+)$", str(e))
63+
if matches:
64+
return matches.group(1)
65+
66+
# SQLite
67+
matches = re.search(r"^\(sqlite3\.OperationalError\) (.+)$", str(e))
68+
if matches:
69+
return matches.group(1)
70+
71+
# Default
72+
return str(e)
3173

3274
def execute(self, text, **params):
3375
"""
@@ -119,12 +161,12 @@ def process(value):
119161
# http://docs.sqlalchemy.org/en/latest/faq/sqlexpressions.html#how-do-i-render-sql-expressions-as-strings-possibly-with-bound-parameters-inlined
120162
statement = str(statement.compile(compile_kwargs={"literal_binds": True}))
121163

164+
# Statement for logging
165+
log = re.sub(r"\n\s*", " ", sqlparse.format(statement, reindent=True))
166+
122167
# Execute statement
123168
result = self.engine.execute(statement)
124169

125-
# Log statement
126-
self.logger.debug(re.sub(r"\n\s*", " ", sqlparse.format(statement, reindent=True)))
127-
128170
# If SELECT (or INSERT with RETURNING), return result set as list of dict objects
129171
if re.search(r"^\s*SELECT", statement, re.I):
130172

@@ -135,23 +177,36 @@ def process(value):
135177
for column in row:
136178
if isinstance(row[column], decimal.Decimal):
137179
row[column] = float(row[column])
138-
return rows
180+
ret = rows
139181

140182
# If INSERT, return primary key value for a newly inserted row
141183
elif re.search(r"^\s*INSERT", statement, re.I):
142184
if self.engine.url.get_backend_name() in ["postgres", "postgresql"]:
143185
result = self.engine.execute(sqlalchemy.text("SELECT LASTVAL()"))
144-
return result.first()[0]
186+
ret = result.first()[0]
145187
else:
146-
return result.lastrowid
188+
ret = result.lastrowid
147189

148190
# If DELETE or UPDATE, return number of rows matched
149191
elif re.search(r"^\s*(?:DELETE|UPDATE)", statement, re.I):
150-
return result.rowcount
192+
ret = result.rowcount
151193

152194
# If some other statement, return True unless exception
153-
return True
195+
ret = True
154196

155197
# If constraint violated, return None
156198
except sqlalchemy.exc.IntegrityError:
199+
self.logger.debug(termcolor.colored(log, "yellow"))
157200
return None
201+
202+
# If user errror
203+
except sqlalchemy.exc.OperationalError as e:
204+
self.logger.debug(termcolor.colored(log, "red"))
205+
e = RuntimeError(self._parse(e))
206+
e.__cause__ = None
207+
raise e
208+
209+
# Return value
210+
else:
211+
self.logger.debug(termcolor.colored(log, "green"))
212+
return ret

0 commit comments

Comments
 (0)
Please sign in to comment.