Skip to content

Commit 4afc998

Browse files
author
Kareem Zidane
authoredMar 10, 2018
Merge pull request #49 from cs50/develop
2.4.0
2 parents 7f30db9 + f19934d commit 4afc998

File tree

8 files changed

+112
-35
lines changed

8 files changed

+112
-35
lines changed
 

‎.travis.yml

+3-1
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@ services:
1010
install:
1111
- python setup.py install
1212
- pip install mysqlclient
13-
- pip install psycopg2
13+
- pip install psycopg2-binary
1414
before_script:
1515
- mysql -e 'CREATE DATABASE IF NOT EXISTS test;'
1616
- psql -c 'create database test;' -U postgres
17+
- touch test.db
1718
script: python tests/sql.py
19+
after_script: rm -f test.db
1820
jobs:
1921
include:
2022
- stage: deploy

‎setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,5 @@
1616
package_dir={"": "src"},
1717
packages=["cs50"],
1818
url="https://github.com/cs50/python-cs50",
19-
version="2.3.3"
19+
version="2.4.0"
2020
)

‎src/cs50/cs50.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def get_char(prompt=None):
8383
if len(s) == 1:
8484
return s[0]
8585

86-
# temporarily here for backwards compatibility
86+
# Temporarily here for backwards compatibility
8787
if prompt is None:
8888
print("Retry: ", end="")
8989

@@ -104,7 +104,7 @@ def get_float(prompt=None):
104104
except ValueError:
105105
pass
106106

107-
# temporarily here for backwards compatibility
107+
# Temporarily here for backwards compatibility
108108
if prompt is None:
109109
print("Retry: ", end="")
110110

@@ -122,12 +122,12 @@ def get_int(prompt=None):
122122
if re.search(r"^[+-]?\d+$", s):
123123
try:
124124
i = int(s, 10)
125-
if type(i) is int: # could become long in Python 2
125+
if type(i) is int: # Could become long in Python 2
126126
return i
127127
except ValueError:
128128
pass
129129

130-
# temporarily here for backwards compatibility
130+
# Temporarily here for backwards compatibility
131131
if prompt is None:
132132
print("Retry: ", end="")
133133

@@ -149,7 +149,7 @@ def get_long(prompt=None):
149149
except ValueError:
150150
pass
151151

152-
# temporarily here for backwards compatibility
152+
# Temporarily here for backwards compatibility
153153
if prompt is None:
154154
print("Retry: ", end="")
155155

‎src/cs50/sql.py

+84-28
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
import decimal
33
import importlib
44
import logging
5+
import os
56
import re
67
import sqlalchemy
78
import sqlparse
89
import sys
10+
import termcolor
911
import warnings
1012

1113

@@ -22,12 +24,52 @@ def __init__(self, url, **kwargs):
2224
http://docs.sqlalchemy.org/en/latest/dialects/index.html
2325
"""
2426

25-
# log statements to standard error
27+
# Require that file already exist for SQLite
28+
matches = re.search(r"^sqlite:///(.+)$", url)
29+
if matches:
30+
if not os.path.exists(matches.group(1)):
31+
raise RuntimeError("does not exist: {}".format(matches.group(1)))
32+
if not os.path.isfile(matches.group(1)):
33+
raise RuntimeError("not a file: {}".format(matches.group(1)))
34+
35+
# Create engine, raising exception if back end's module not installed
36+
self.engine = sqlalchemy.create_engine(url, **kwargs)
37+
38+
# Log statements to standard error
2639
logging.basicConfig(level=logging.DEBUG)
2740
self.logger = logging.getLogger("cs50")
2841

29-
# create engine, raising exception if back end's module not installed
30-
self.engine = sqlalchemy.create_engine(url, **kwargs)
42+
# Test database
43+
try:
44+
self.logger.disabled = True
45+
self.execute("SELECT 1")
46+
except sqlalchemy.exc.OperationalError as e:
47+
e = RuntimeError(self._parse(e))
48+
e.__cause__ = None
49+
raise e
50+
else:
51+
self.logger.disabled = False
52+
53+
def _parse(self, e):
54+
"""Parses an exception, returns its message."""
55+
56+
# MySQL
57+
matches = re.search(r"^\(_mysql_exceptions\.OperationalError\) \(\d+, \"(.+)\"\)$", str(e))
58+
if matches:
59+
return matches.group(1)
60+
61+
# PostgreSQL
62+
matches = re.search(r"^\(psycopg2\.OperationalError\) (.+)$", str(e))
63+
if matches:
64+
return matches.group(1)
65+
66+
# SQLite
67+
matches = re.search(r"^\(sqlite3\.OperationalError\) (.+)$", str(e))
68+
if matches:
69+
return matches.group(1)
70+
71+
# Default
72+
return str(e)
3173

3274
def execute(self, text, **params):
3375
"""
@@ -81,77 +123,91 @@ def process(value):
81123
elif isinstance(value, sqlalchemy.sql.elements.Null):
82124
return sqlalchemy.types.NullType().literal_processor(dialect)(value)
83125

84-
# unsupported value
126+
# Unsupported value
85127
raise RuntimeError("unsupported value")
86128

87-
# process value(s), separating with commas as needed
129+
# Process value(s), separating with commas as needed
88130
if type(value) is list:
89131
return ", ".join([process(v) for v in value])
90132
else:
91133
return process(value)
92134

93-
# allow only one statement at a time
135+
# Allow only one statement at a time
94136
if len(sqlparse.split(text)) > 1:
95137
raise RuntimeError("too many statements at once")
96138

97-
# raise exceptions for warnings
139+
# Raise exceptions for warnings
98140
warnings.filterwarnings("error")
99141

100-
# prepare, execute statement
142+
# Prepare, execute statement
101143
try:
102144

103-
# construct a new TextClause clause
145+
# Construct a new TextClause clause
104146
statement = sqlalchemy.text(text)
105147

106-
# iterate over parameters
148+
# Iterate over parameters
107149
for key, value in params.items():
108150

109-
# translate None to NULL
151+
# Translate None to NULL
110152
if value is None:
111153
value = sqlalchemy.sql.null()
112154

113-
# bind parameters before statement reaches database, so that bound parameters appear in exceptions
155+
# Bind parameters before statement reaches database, so that bound parameters appear in exceptions
114156
# http://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.text
115157
statement = statement.bindparams(sqlalchemy.bindparam(
116158
key, value=value, type_=UserDefinedType()))
117159

118-
# stringify bound parameters
160+
# Stringify bound parameters
119161
# http://docs.sqlalchemy.org/en/latest/faq/sqlexpressions.html#how-do-i-render-sql-expressions-as-strings-possibly-with-bound-parameters-inlined
120162
statement = str(statement.compile(compile_kwargs={"literal_binds": True}))
121163

122-
# execute statement
123-
result = self.engine.execute(statement)
164+
# Statement for logging
165+
log = re.sub(r"\n\s*", " ", sqlparse.format(statement, reindent=True))
124166

125-
# log statement
126-
self.logger.debug(re.sub(r"\n\s*", " ", sqlparse.format(statement, reindent=True)))
167+
# Execute statement
168+
result = self.engine.execute(statement)
127169

128-
# if SELECT (or INSERT with RETURNING), return result set as list of dict objects
170+
# If SELECT (or INSERT with RETURNING), return result set as list of dict objects
129171
if re.search(r"^\s*SELECT", statement, re.I):
130172

131-
# coerce any decimal.Decimal objects to float objects
173+
# Coerce any decimal.Decimal objects to float objects
132174
# https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ
133175
rows = [dict(row) for row in result.fetchall()]
134176
for row in rows:
135177
for column in row:
136178
if isinstance(row[column], decimal.Decimal):
137179
row[column] = float(row[column])
138-
return rows
180+
ret = rows
139181

140-
# if INSERT, return primary key value for a newly inserted row
182+
# If INSERT, return primary key value for a newly inserted row
141183
elif re.search(r"^\s*INSERT", statement, re.I):
142184
if self.engine.url.get_backend_name() in ["postgres", "postgresql"]:
143185
result = self.engine.execute(sqlalchemy.text("SELECT LASTVAL()"))
144-
return result.first()[0]
186+
ret = result.first()[0]
145187
else:
146-
return result.lastrowid
188+
ret = result.lastrowid
147189

148-
# if DELETE or UPDATE, return number of rows matched
190+
# If DELETE or UPDATE, return number of rows matched
149191
elif re.search(r"^\s*(?:DELETE|UPDATE)", statement, re.I):
150-
return result.rowcount
192+
ret = result.rowcount
151193

152-
# if some other statement, return True unless exception
153-
return True
194+
# If some other statement, return True unless exception
195+
else:
196+
ret = True
154197

155-
# if constraint violated, return None
198+
# If constraint violated, return None
156199
except sqlalchemy.exc.IntegrityError:
200+
self.logger.debug(termcolor.colored(log, "yellow"))
157201
return None
202+
203+
# If user errror
204+
except sqlalchemy.exc.OperationalError as e:
205+
self.logger.debug(termcolor.colored(log, "red"))
206+
e = RuntimeError(self._parse(e))
207+
e.__cause__ = None
208+
raise e
209+
210+
# Return value
211+
else:
212+
self.logger.debug(termcolor.colored(log, "green"))
213+
return ret

‎tests/flask/application.py

+3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import requests
2+
import sys
23
from flask import Flask, render_template
34

5+
sys.path.insert(0, "../../src")
6+
47
import cs50
58

69
app = Flask(__name__)

‎tests/mysql.py

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import sys
2+
3+
sys.path.insert(0, "../src")
4+
5+
from cs50 import SQL
6+
7+
db = SQL("mysql://root@localhost/test")
8+
db.execute("SELECT 1")

‎tests/sqlite.py

+4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
import sys
2+
3+
sys.path.insert(0, "../src")
4+
15
from cs50 import SQL
26

37
db = SQL("sqlite:///sqlite.db")

‎tests/tb.py

+4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
import sys
2+
3+
sys.path.insert(0, "../src")
4+
15
import cs50
26
import requests
37

0 commit comments

Comments
 (0)
Please sign in to comment.