Skip to content

Commit 8708774

Browse files
author
Kareem Zidane
committedApr 26, 2018
handling connect on engine level
1 parent 89657ea commit 8708774

File tree

3 files changed

+34
-19
lines changed

3 files changed

+34
-19
lines changed
 

‎.travis.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ install:
1414
before_script:
1515
- mysql -e 'CREATE DATABASE IF NOT EXISTS test;'
1616
- psql -c 'create database test;' -U postgres
17-
- touch test.db
17+
- touch test.db test1.db
1818
script: python tests/sql.py
1919
after_script: rm -f test.db
2020
jobs:

‎src/cs50/sql.py

+22-15
Original file line numberDiff line numberDiff line change
@@ -33,24 +33,18 @@ def __init__(self, url, **kwargs):
3333
if not os.path.isfile(matches.group(1)):
3434
raise RuntimeError("not a file: {}".format(matches.group(1)))
3535

36-
# Optionally enable foreign key constraints
37-
# http://docs.sqlalchemy.org/en/latest/dialects/sqlite.html#foreign-key-support
38-
if kwargs.pop("pragma_foreign_keys", False):
39-
@sqlalchemy.event.listens_for(sqlalchemy.engine.Engine, "connect")
40-
def _set_sqlite_pragma(dbapi_connection, connection_record):
41-
"""Enables foreign key support."""
36+
pragma_foreign_keys = kwargs.pop("pragma_foreign_keys", False)
4237

43-
# Ensure backend is sqlite
44-
if type(dbapi_connection) is sqlite3.Connection:
45-
cursor = dbapi_connection.cursor()
46-
47-
# Respect foreign key constraints by default
48-
cursor.execute("PRAGMA foreign_keys=ON")
49-
cursor.close()
38+
# Create engine, raising exception if back end's module not installed
39+
self.engine = sqlalchemy.create_engine(url, **kwargs)
5040

41+
# Whether to enable foreign key constraints
42+
if pragma_foreign_keys:
43+
sqlalchemy.event.listen(self.engine, "connect", _on_connect)
44+
else:
45+
# Create engine, raising exception if back end's module not installed
46+
self.engine = sqlalchemy.create_engine(url, **kwargs)
5147

52-
# Create engine, raising exception if back end's module not installed
53-
self.engine = sqlalchemy.create_engine(url, **kwargs)
5448

5549
# Log statements to standard error
5650
logging.basicConfig(level=logging.DEBUG)
@@ -229,3 +223,16 @@ def process(value):
229223
else:
230224
self.logger.debug(termcolor.colored(log, "green"))
231225
return ret
226+
227+
228+
# http://docs.sqlalchemy.org/en/latest/dialects/sqlite.html#foreign-key-support
229+
def _on_connect(dbapi_connection, connection_record):
230+
"""Enables foreign key support."""
231+
232+
# Ensure backend is sqlite
233+
if type(dbapi_connection) is sqlite3.Connection:
234+
cursor = dbapi_connection.cursor()
235+
236+
# Respect foreign key constraints by default
237+
cursor.execute("PRAGMA foreign_keys=ON")
238+
cursor.close()

‎tests/sql.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -107,18 +107,26 @@ class SQLiteTests(SQLTests):
107107
@classmethod
108108
def setUpClass(self):
109109
self.db = SQL("sqlite:///test.db")
110+
self.db1 = SQL("sqlite:///test1.db", pragma_foreign_keys=True)
110111

111112
def setUp(self):
112113
self.db.execute("CREATE TABLE cs50(id INTEGER PRIMARY KEY, val TEXT)")
113114

114-
def multi_inserts_enabled(self):
115-
return False
115+
def test_foreign_key_support(self):
116+
self.db.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY)")
117+
self.db.execute("CREATE TABLE bar(foo_id INTEGER, FOREIGN KEY (foo_id) REFERENCES foo(id))")
118+
self.assertEqual(self.db.execute("INSERT INTO bar VALUES(50)"), 1)
119+
120+
self.db1.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY)")
121+
self.db1.execute("CREATE TABLE bar(foo_id INTEGER, FOREIGN KEY (foo_id) REFERENCES foo(id))")
122+
self.assertEqual(self.db1.execute("INSERT INTO bar VALUES(50)"), None)
116123

117124
if __name__ == "__main__":
118125
suite = unittest.TestSuite([
119126
unittest.TestLoader().loadTestsFromTestCase(SQLiteTests),
120127
unittest.TestLoader().loadTestsFromTestCase(MySQLTests),
121128
unittest.TestLoader().loadTestsFromTestCase(PostgresTests)
122129
])
123-
logging.getLogger("cs50.sql").disabled = True
130+
131+
logging.getLogger("cs50").disabled = True
124132
sys.exit(not unittest.TextTestRunner(verbosity=2).run(suite).wasSuccessful())

0 commit comments

Comments
 (0)