diff --git a/src/cs50/sql.py b/src/cs50/sql.py index 4cfbd78..4b54d97 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -1,4 +1,5 @@ import datetime +import decimal import importlib import logging import re @@ -31,7 +32,6 @@ def execute(self, text, **params): """ Execute a SQL statement. """ - class UserDefinedType(sqlalchemy.TypeDecorator): """ Add support for expandable values, a la https://bitbucket.org/zzzeek/sqlalchemy/issues/3953/expanding-parameter. @@ -122,12 +122,19 @@ def process(value): self.logger.debug(statement) # if SELECT (or INSERT with RETURNING), return result set as list of dict objects - if re.search(r"^\s*SELECT\s+", statement, re.I): - rows = result.fetchall() - return [dict(row) for row in rows] + if re.search(r"^\s*SELECT", statement, re.I): + + # coerce any decimal.Decimal objects to float objects + # https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ + rows = [dict(row) for row in result.fetchall()] + for row in rows: + for column in row: + if isinstance(row[column], decimal.Decimal): + row[column] = float(row[column]) + return rows # if INSERT, return primary key value for a newly inserted row - elif re.search(r"^\s*INSERT\s+", statement, re.I): + elif re.search(r"^\s*INSERT", statement, re.I): if self.engine.url.get_backend_name() in ["postgres", "postgresql"]: result = self.engine.execute(sqlalchemy.text("SELECT LASTVAL()")) return result.first()[0] @@ -135,7 +142,7 @@ def process(value): return result.lastrowid # if DELETE or UPDATE, return number of rows matched - elif re.search(r"^\s*(?:DELETE|UPDATE)\s+", statement, re.I): + elif re.search(r"^\s*(?:DELETE|UPDATE)", statement, re.I): return result.rowcount # if some other statement, return True unless exception