Skip to content

used catch_warnings #88

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 5, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -16,5 +16,5 @@
package_dir={"": "src"},
packages=["cs50"],
url="https://github.com/cs50/python-cs50",
version="4.0.2"
version="4.0.3"
)
111 changes: 57 additions & 54 deletions src/cs50/sql.py
Original file line number Diff line number Diff line change
@@ -215,62 +215,65 @@ def execute(self, sql, *args, **kwargs):
# Join tokens into statement
statement = "".join([str(token) for token in tokens])

# Raise exceptions for warnings
warnings.filterwarnings("error")

# Prepare, execute statement
try:

# Execute statement
result = self.engine.execute(sqlalchemy.text(statement))
# Catch SQLAlchemy warnings
with warnings.catch_warnings():

# Raise exceptions for warnings
warnings.simplefilter("error")

# Prepare, execute statement
try:

# Execute statement
result = self.engine.execute(sqlalchemy.text(statement))

# Return value
ret = True
if tokens[0].ttype == sqlparse.tokens.Keyword.DML:

# Uppercase token's value
value = tokens[0].value.upper()

# If SELECT, return result set as list of dict objects
if value == "SELECT":

# Coerce any decimal.Decimal objects to float objects
# https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ
rows = [dict(row) for row in result.fetchall()]
for row in rows:
for column in row:
if type(row[column]) is decimal.Decimal:
row[column] = float(row[column])
ret = rows

# If INSERT, return primary key value for a newly inserted row
elif value == "INSERT":
if self.engine.url.get_backend_name() in ["postgres", "postgresql"]:
result = self.engine.execute("SELECT LASTVAL()")
ret = result.first()[0]
else:
ret = result.lastrowid

# If DELETE or UPDATE, return number of rows matched
elif value in ["DELETE", "UPDATE"]:
ret = result.rowcount

# If constraint violated, return None
except sqlalchemy.exc.IntegrityError:
self._logger.debug(termcolor.colored(statement, "yellow"))
return None

# If user errror
except sqlalchemy.exc.OperationalError as e:
self._logger.debug(termcolor.colored(statement, "red"))
e = RuntimeError(_parse_exception(e))
e.__cause__ = None
raise e

# Return value
ret = True
if tokens[0].ttype == sqlparse.tokens.Keyword.DML:

# Uppercase token's value
value = tokens[0].value.upper()

# If SELECT, return result set as list of dict objects
if value == "SELECT":

# Coerce any decimal.Decimal objects to float objects
# https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ
rows = [dict(row) for row in result.fetchall()]
for row in rows:
for column in row:
if type(row[column]) is decimal.Decimal:
row[column] = float(row[column])
ret = rows

# If INSERT, return primary key value for a newly inserted row
elif value == "INSERT":
if self.engine.url.get_backend_name() in ["postgres", "postgresql"]:
result = self.engine.execute("SELECT LASTVAL()")
ret = result.first()[0]
else:
ret = result.lastrowid

# If DELETE or UPDATE, return number of rows matched
elif value in ["DELETE", "UPDATE"]:
ret = result.rowcount

# If constraint violated, return None
except sqlalchemy.exc.IntegrityError:
self._logger.debug(termcolor.colored(statement, "yellow"))
return None

# If user errror
except sqlalchemy.exc.OperationalError as e:
self._logger.debug(termcolor.colored(statement, "red"))
e = RuntimeError(_parse_exception(e))
e.__cause__ = None
raise e

# Return value
else:
self._logger.debug(termcolor.colored(statement, "green"))
return ret
else:
self._logger.debug(termcolor.colored(statement, "green"))
return ret

def _escape(self, value):
"""