Skip to content

Commit 89f7279

Browse files
authoredSep 5, 2019
Merge pull request #88 from cs50/warnings
used catch_warnings
2 parents a1a901e + 85adc53 commit 89f7279

File tree

2 files changed

+58
-55
lines changed

2 files changed

+58
-55
lines changed
 

‎setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,5 @@
1616
package_dir={"": "src"},
1717
packages=["cs50"],
1818
url="https://github.com/cs50/python-cs50",
19-
version="4.0.2"
19+
version="4.0.3"
2020
)

‎src/cs50/sql.py

+57-54
Original file line numberDiff line numberDiff line change
@@ -215,62 +215,65 @@ def execute(self, sql, *args, **kwargs):
215215
# Join tokens into statement
216216
statement = "".join([str(token) for token in tokens])
217217

218-
# Raise exceptions for warnings
219-
warnings.filterwarnings("error")
220-
221-
# Prepare, execute statement
222-
try:
223-
224-
# Execute statement
225-
result = self.engine.execute(sqlalchemy.text(statement))
218+
# Catch SQLAlchemy warnings
219+
with warnings.catch_warnings():
220+
221+
# Raise exceptions for warnings
222+
warnings.simplefilter("error")
223+
224+
# Prepare, execute statement
225+
try:
226+
227+
# Execute statement
228+
result = self.engine.execute(sqlalchemy.text(statement))
229+
230+
# Return value
231+
ret = True
232+
if tokens[0].ttype == sqlparse.tokens.Keyword.DML:
233+
234+
# Uppercase token's value
235+
value = tokens[0].value.upper()
236+
237+
# If SELECT, return result set as list of dict objects
238+
if value == "SELECT":
239+
240+
# Coerce any decimal.Decimal objects to float objects
241+
# https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ
242+
rows = [dict(row) for row in result.fetchall()]
243+
for row in rows:
244+
for column in row:
245+
if type(row[column]) is decimal.Decimal:
246+
row[column] = float(row[column])
247+
ret = rows
248+
249+
# If INSERT, return primary key value for a newly inserted row
250+
elif value == "INSERT":
251+
if self.engine.url.get_backend_name() in ["postgres", "postgresql"]:
252+
result = self.engine.execute("SELECT LASTVAL()")
253+
ret = result.first()[0]
254+
else:
255+
ret = result.lastrowid
256+
257+
# If DELETE or UPDATE, return number of rows matched
258+
elif value in ["DELETE", "UPDATE"]:
259+
ret = result.rowcount
260+
261+
# If constraint violated, return None
262+
except sqlalchemy.exc.IntegrityError:
263+
self._logger.debug(termcolor.colored(statement, "yellow"))
264+
return None
265+
266+
# If user errror
267+
except sqlalchemy.exc.OperationalError as e:
268+
self._logger.debug(termcolor.colored(statement, "red"))
269+
e = RuntimeError(_parse_exception(e))
270+
e.__cause__ = None
271+
raise e
226272

227273
# Return value
228-
ret = True
229-
if tokens[0].ttype == sqlparse.tokens.Keyword.DML:
230-
231-
# Uppercase token's value
232-
value = tokens[0].value.upper()
233-
234-
# If SELECT, return result set as list of dict objects
235-
if value == "SELECT":
236-
237-
# Coerce any decimal.Decimal objects to float objects
238-
# https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ
239-
rows = [dict(row) for row in result.fetchall()]
240-
for row in rows:
241-
for column in row:
242-
if type(row[column]) is decimal.Decimal:
243-
row[column] = float(row[column])
244-
ret = rows
245-
246-
# If INSERT, return primary key value for a newly inserted row
247-
elif value == "INSERT":
248-
if self.engine.url.get_backend_name() in ["postgres", "postgresql"]:
249-
result = self.engine.execute("SELECT LASTVAL()")
250-
ret = result.first()[0]
251-
else:
252-
ret = result.lastrowid
253-
254-
# If DELETE or UPDATE, return number of rows matched
255-
elif value in ["DELETE", "UPDATE"]:
256-
ret = result.rowcount
257-
258-
# If constraint violated, return None
259-
except sqlalchemy.exc.IntegrityError:
260-
self._logger.debug(termcolor.colored(statement, "yellow"))
261-
return None
262-
263-
# If user errror
264-
except sqlalchemy.exc.OperationalError as e:
265-
self._logger.debug(termcolor.colored(statement, "red"))
266-
e = RuntimeError(_parse_exception(e))
267-
e.__cause__ = None
268-
raise e
269-
270-
# Return value
271-
else:
272-
self._logger.debug(termcolor.colored(statement, "green"))
273-
return ret
274+
else:
275+
self._logger.debug(termcolor.colored(statement, "green"))
276+
return ret
274277

275278
def _escape(self, value):
276279
"""

0 commit comments

Comments
 (0)
Please sign in to comment.