diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 0eb0e2c..30d894b 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -26,18 +26,14 @@ jobs: python-version: '3.6' - name: Setup databases run: | - python setup.py install - pip install mysqlclient - pip install psycopg2-binary - touch test.db test1.db + pip install . + pip install mysqlclient psycopg2-binary - name: Run tests run: python tests/sql.py - name: Install pypa/build - run: | - python -m pip install build --user + run: python -m pip install build --user - name: Build a binary wheel and a source tarball - run: | - python -m build --sdist --wheel --outdir dist/ . + run: python -m build --sdist --wheel --outdir dist/ . - name: Deploy to PyPI if: ${{ github.ref == 'refs/heads/main' }} uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/.gitignore b/.gitignore index 0a2a684..0ce3062 100644 --- a/.gitignore +++ b/.gitignore @@ -5,5 +5,6 @@ *.db *.egg-info/ *.pyc +build/ dist/ test.db diff --git a/setup.py b/setup.py index de271f8..e5f01ce 100644 --- a/setup.py +++ b/setup.py @@ -16,5 +16,5 @@ package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="7.0.0" + version="7.0.1" ) diff --git a/src/cs50/_engine.py b/src/cs50/_engine.py index d74992c..55489d1 100644 --- a/src/cs50/_engine.py +++ b/src/cs50/_engine.py @@ -1,4 +1,5 @@ import threading +import warnings from ._engine_util import create_engine @@ -11,6 +12,7 @@ class Engine: """ def __init__(self, url): + url = _replace_scheme_if_postgres(url) self._engine = create_engine(url) def get_transaction_connection(self): @@ -64,3 +66,23 @@ def _thread_local_connections(): connections = thread_local_data.connections = {} return connections + +def _replace_scheme_if_postgres(url): + """ + Replaces the postgres scheme with the postgresql scheme if possible since the postgres scheme + is deprecated. + + :returns: url with postgresql scheme if the scheme was postgres; otherwise returns url as is + """ + + if url.startswith("postgres://"): + with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "The postgres:// scheme is deprecated and will not be supported in the next major" + + " release of the library. Please use the postgresql:// scheme instead.", + DeprecationWarning + ) + url = f"postgresql{url[len('postgres'):]}" + + return url diff --git a/src/cs50/_sql_sanitizer.py b/src/cs50/_sql_sanitizer.py index 17fc5fa..3803bb8 100644 --- a/src/cs50/_sql_sanitizer.py +++ b/src/cs50/_sql_sanitizer.py @@ -66,9 +66,7 @@ def escape(self, value): return sqlparse.sql.Token(sqlparse.tokens.String, string_processor(value)) if value is None: - return sqlparse.sql.Token( - sqlparse.tokens.Keyword, - sqlalchemy.types.NullType().literal_processor(self._dialect)(value)) + return sqlparse.sql.Token(sqlparse.tokens.Keyword, "NULL") raise RuntimeError(f"unsupported value: {value}") diff --git a/tests/sql.py b/tests/sql.py index 89853a7..cf8c5ae 100644 --- a/tests/sql.py +++ b/tests/sql.py @@ -9,7 +9,6 @@ class SQLTests(unittest.TestCase): - def test_multiple_statements(self): self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO cs50(val) VALUES('baz'); INSERT INTO cs50(val) VALUES('qux')") @@ -133,20 +132,10 @@ def test_identifier_case(self): self.assertIn("count", self.db.execute("SELECT 1 AS count")[0]) def tearDown(self): - self.db.execute("DROP TABLE cs50") + self.db.execute("DROP TABLE IF EXISTS cs50") self.db.execute("DROP TABLE IF EXISTS foo") self.db.execute("DROP TABLE IF EXISTS bar") - @classmethod - def tearDownClass(self): - try: - self.db.execute("DROP TABLE IF EXISTS cs50") - except Warning as e: - # suppress "unknown table" - if not str(e).startswith("(1051"): - raise e - - class MySQLTests(SQLTests): @classmethod def setUpClass(self): @@ -156,7 +145,6 @@ def setUp(self): self.db.execute("CREATE TABLE IF NOT EXISTS cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), bin BLOB, PRIMARY KEY (id))") self.db.execute("DELETE FROM cs50") - class PostgresTests(SQLTests): @classmethod def setUpClass(self): @@ -169,9 +157,11 @@ def setUp(self): def test_cte(self): self.assertEqual(self.db.execute("WITH foo AS ( SELECT 1 AS bar ) SELECT bar FROM foo"), [{"bar": 1}]) + def test_postgres_scheme(self): + db = SQL("postgres://postgres:postgres@127.0.0.1/test") + db.execute("SELECT 1") class SQLiteTests(SQLTests): - @classmethod def setUpClass(self): open("test.db", "w").close() @@ -283,7 +273,6 @@ def test_named(self): self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (:bar, :baz)", bar='bar', baz='baz', qux='qux') self.assertRaises(RuntimeError, self.db.execute, "INSERT INTO foo VALUES (:bar, :baz)", 'baz', bar='bar') - def test_numeric(self): self.db.execute("CREATE TABLE foo (firstname STRING, lastname STRING)") @@ -319,6 +308,9 @@ def test_numeric(self): def test_cte(self): self.assertEqual(self.db.execute("WITH foo AS ( SELECT 1 AS bar ) SELECT bar FROM foo"), [{"bar": 1}]) + def test_none(self): + self.db.execute("CREATE TABLE foo (val INTEGER)") + self.db.execute("SELECT * FROM foo WHERE val = ?", None) if __name__ == "__main__": suite = unittest.TestSuite([