Skip to content

normalizing whitespace in log, removing .logger #42

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -16,5 +16,5 @@
package_dir={"": "src"},
packages=["cs50"],
url="https://github.com/cs50/python-cs50",
version="2.3.1"
version="2.3.2"
)
42 changes: 21 additions & 21 deletions src/cs50/sql.py
Original file line number Diff line number Diff line change
@@ -22,11 +22,11 @@ def __init__(self, url, **kwargs):
http://docs.sqlalchemy.org/en/latest/dialects/index.html
"""

# log statements to standard error
# Log statements to standard error
logging.basicConfig(level=logging.DEBUG)
self.logger = logging.getLogger("cs50")
self.logger = logging.getLogger(__name__)

# create engine, raising exception if back end's module not installed
# Create engine, raising exception if back end's module not installed
self.engine = sqlalchemy.create_engine(url, **kwargs)

def execute(self, text, **params):
@@ -37,11 +37,11 @@ class UserDefinedType(sqlalchemy.TypeDecorator):
"""
Add support for expandable values, a la https://bitbucket.org/zzzeek/sqlalchemy/issues/3953/expanding-parameter.
"""

impl = sqlalchemy.types.UserDefinedType

def process_literal_param(self, value, dialect):
"""Receive a literal parameter value to be rendered inline within a statement."""

def process(value):
"""Render a literal value, escaping as needed."""

@@ -84,48 +84,48 @@ def process(value):
# unsupported value
raise RuntimeError("unsupported value")

# process value(s), separating with commas as needed
# Process value(s), separating with commas as needed
if type(value) is list:
return ", ".join([process(v) for v in value])
else:
return process(value)

# allow only one statement at a time
# Allow only one statement at a time
if len(sqlparse.split(text)) > 1:
raise RuntimeError("too many statements at once")

# raise exceptions for warnings
# Raise exceptions for warnings
warnings.filterwarnings("error")

# prepare, execute statement
# Prepare, execute statement
try:

# construct a new TextClause clause
# Construct a new TextClause clause
statement = sqlalchemy.text(text)

# iterate over parameters
# Iterate over parameters
for key, value in params.items():

# translate None to NULL
# Translate None to NULL
if value is None:
value = sqlalchemy.sql.null()

# bind parameters before statement reaches database, so that bound parameters appear in exceptions
# Bind parameters before statement reaches database, so that bound parameters appear in exceptions
# http://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.text
statement = statement.bindparams(sqlalchemy.bindparam(
key, value=value, type_=UserDefinedType()))

# stringify bound parameters
# Stringify bound parameters
# http://docs.sqlalchemy.org/en/latest/faq/sqlexpressions.html#how-do-i-render-sql-expressions-as-strings-possibly-with-bound-parameters-inlined
statement = str(statement.compile(compile_kwargs={"literal_binds": True}))

# execute statement
# Execute statement
result = self.engine.execute(statement)

# log statement
self.logger.debug(statement)
# Log statement
self.logger.debug(re.sub(r"\n\s*", " ", sqlparse.format(statement, reindent=True)))

# if SELECT (or INSERT with RETURNING), return result set as list of dict objects
# If SELECT (or INSERT with RETURNING), return result set as list of dict objects
if re.search(r"^\s*SELECT", statement, re.I):

# coerce any decimal.Decimal objects to float objects
@@ -137,21 +137,21 @@ def process(value):
row[column] = float(row[column])
return rows

# if INSERT, return primary key value for a newly inserted row
# If INSERT, return primary key value for a newly inserted row
elif re.search(r"^\s*INSERT", statement, re.I):
if self.engine.url.get_backend_name() in ["postgres", "postgresql"]:
result = self.engine.execute(sqlalchemy.text("SELECT LASTVAL()"))
return result.first()[0]
else:
return result.lastrowid

# if DELETE or UPDATE, return number of rows matched
# If DELETE or UPDATE, return number of rows matched
elif re.search(r"^\s*(?:DELETE|UPDATE)", statement, re.I):
return result.rowcount

# if some other statement, return True unless exception
# If some other statement, return True unless exception
return True

# if constraint violated, return None
# If constraint violated, return None
except sqlalchemy.exc.IntegrityError:
return None