1
- import unittest
2
1
from cs50 .sql import SQL
2
+ import sys
3
+ import unittest
4
+ import warnings
3
5
4
6
class SQLTests (unittest .TestCase ):
7
+ def multi_inserts_enabled (self ):
8
+ return True
9
+
5
10
def test_delete_returns_affected_rows (self ):
6
11
rows = [
7
12
{"id" : 1 , "val" : "foo" },
@@ -22,6 +27,8 @@ def test_delete_returns_affected_rows(self):
22
27
def test_insert_returns_last_row_id (self ):
23
28
self .assertEqual (self .db .execute ("INSERT INTO cs50(val) VALUES('foo')" ), 1 )
24
29
self .assertEqual (self .db .execute ("INSERT INTO cs50(val) VALUES('bar')" ), 2 )
30
+ if self .multi_inserts_enabled ():
31
+ self .assertEqual (self .db .execute ("INSERT INTO cs50(val) VALUES('baz'); INSERT INTO cs50(val) VALUES('qux')" ), 4 )
25
32
26
33
def test_select_all (self ):
27
34
self .assertEqual (self .db .execute ("SELECT * FROM cs50" ), [])
@@ -70,54 +77,44 @@ def test_update_returns_affected_rows(self):
70
77
self .assertEqual (self .db .execute ("UPDATE cs50 SET val = 'foo' WHERE id > 1" ), 2 )
71
78
self .assertEqual (self .db .execute ("UPDATE cs50 SET val = 'foo' WHERE id = -50" ), 0 )
72
79
73
- class MySQLTests (SQLTests ):
74
- @classmethod
75
- def setUpClass (self ):
76
- self .db = SQL ("mysql://root@localhost/cs50_sql_tests" )
77
-
78
- def setUp (self ):
79
- self .db .execute ("CREATE TABLE cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), PRIMARY KEY (id))" )
80
-
81
80
def tearDown (self ):
82
81
self .db .execute ("DROP TABLE cs50" )
83
82
84
83
@classmethod
85
84
def tearDownClass (self ):
86
- self .db .execute ("DROP TABLE IF EXISTS cs50" )
85
+ try :
86
+ self .db .execute ("DROP TABLE IF EXISTS cs50" )
87
+ except Warning as e :
88
+ # suppress "unknown table"
89
+ if not str (e ).startswith ("(1051" ):
90
+ raise e
87
91
88
- class PostgresTests (SQLTests ):
92
+ class MySQLTests (SQLTests ):
89
93
@classmethod
90
94
def setUpClass (self ):
91
- self .db = SQL ("postgresql ://postgres:postgres @localhost/cs50_sql_tests " )
95
+ self .db = SQL ("mysql ://root @localhost/test " )
92
96
93
97
def setUp (self ):
94
- self .db .execute ("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16))" )
95
-
96
- def tearDown (self ):
97
- self .db .execute ("DROP TABLE cs50" )
98
+ self .db .execute ("CREATE TABLE cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), PRIMARY KEY (id))" )
98
99
100
+ class PostgresTests (SQLTests ):
99
101
@classmethod
100
- def tearDownClass (self ):
101
- self .db . execute ( "DROP TABLE IF EXISTS cs50 " )
102
+ def setUpClass (self ):
103
+ self .db = SQL ( "postgresql://postgres@localhost/test " )
102
104
103
- def test_insert_returns_last_row_id (self ):
104
- self .assertEqual (self .db .execute ("INSERT INTO cs50(val) VALUES('foo')" ), 1 )
105
- self .assertEqual (self .db .execute ("INSERT INTO cs50(val) VALUES('bar')" ), 2 )
105
+ def setUp (self ):
106
+ self .db .execute ("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16))" )
106
107
107
108
class SQLiteTests (SQLTests ):
108
109
@classmethod
109
110
def setUpClass (self ):
110
- self .db = SQL ("sqlite:///cs50_sql_tests .db" )
111
+ self .db = SQL ("sqlite:///test .db" )
111
112
112
113
def setUp (self ):
113
114
self .db .execute ("CREATE TABLE cs50(id INTEGER PRIMARY KEY, val TEXT)" )
114
115
115
- def tearDown (self ):
116
- self .db .execute ("DROP TABLE cs50" )
117
-
118
- @classmethod
119
- def tearDownClass (self ):
120
- self .db .execute ("DROP TABLE IF EXISTS cs50" )
116
+ def multi_inserts_enabled (self ):
117
+ return False
121
118
122
119
if __name__ == "__main__" :
123
120
suite = unittest .TestSuite ([
@@ -126,4 +123,4 @@ def tearDownClass(self):
126
123
unittest .TestLoader ().loadTestsFromTestCase (PostgresTests )
127
124
])
128
125
129
- unittest .TextTestRunner (verbosity = 2 ).run (suite )
126
+ sys . exit ( not unittest .TextTestRunner (verbosity = 2 ).run (suite ). wasSuccessful () )
0 commit comments