diff --git a/.gitignore b/.gitignore index 5a13495..65f1e1f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,8 @@ .* !.gitignore !.travis.yml -dist/ *.db *.egg-info/ *.pyc +dist/ +test.db diff --git a/README.md b/README.md index 0fb6d64..3d6eed8 100644 --- a/README.md +++ b/README.md @@ -42,7 +42,6 @@ s = cs50.get_string(); ``` 1. Run `service postgresql start`. 1. Run `psql -c 'create database test;' -U postgres`. -1. Run `touch test.db`. ### Sample Tests diff --git a/setup.py b/setup.py index c8a5f5b..d7cd3f2 100644 --- a/setup.py +++ b/setup.py @@ -16,5 +16,5 @@ package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="6.0.1" + version="6.0.2" ) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index e405dc9..1ced4b3 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -43,6 +43,7 @@ def __init__(self, url, **kwargs): import os import re import sqlalchemy + import sqlalchemy.orm import sqlite3 # Require that file already exist for SQLite @@ -72,12 +73,12 @@ def connect(dbapi_connection, connection_record): cursor.execute("PRAGMA foreign_keys=ON") cursor.close() - # Autocommit by default - self._autocommit = True - # Register listener sqlalchemy.event.listen(self._engine, "connect", connect) + # Autocommit by default + self._autocommit = True + # Test database disabled = self._logger.disabled self._logger.disabled = True @@ -96,9 +97,9 @@ def __del__(self): def _disconnect(self): """Close database connection.""" - if hasattr(self, "_connection"): - self._connection.close() - delattr(self, "_connection") + if hasattr(self, "_session"): + self._session.remove() + delattr(self, "_session") @_enable_logging def execute(self, sql, *args, **kwargs): @@ -275,33 +276,34 @@ def execute(self, sql, *args, **kwargs): # Infer whether app is defined assert flask.current_app - # If no connections to any databases yet - if not hasattr(flask.g, "_connections"): - setattr(flask.g, "_connections", {}) - connections = getattr(flask.g, "_connections") + # If no sessions for any databases yet + if not hasattr(flask.g, "_sessions"): + setattr(flask.g, "_sessions", {}) + sessions = getattr(flask.g, "_sessions") - # If not yet connected to this database + # If no session yet for this database # https://flask.palletsprojects.com/en/1.1.x/appcontext/#storing-data - if self not in connections: + # https://stackoverflow.com/a/34010159 + if self not in sessions: # Connect to database - connections[self] = self._engine.connect() + sessions[self] = sqlalchemy.orm.scoping.scoped_session(sqlalchemy.orm.sessionmaker(bind=self._engine)) - # Disconnect from database later + # Remove session later if _teardown_appcontext not in flask.current_app.teardown_appcontext_funcs: flask.current_app.teardown_appcontext(_teardown_appcontext) - # Use this connection - connection = connections[self] + # Use this session + session = sessions[self] except (ModuleNotFoundError, AssertionError): # If no connection yet - if not hasattr(self, "_connection"): - self._connection = self._engine.connect() + if not hasattr(self, "_session"): + self._session = sqlalchemy.orm.scoping.scoped_session(sqlalchemy.orm.sessionmaker(bind=self._engine)) - # Use this connection - connection = self._connection + # Use this session + session = self._session # Catch SQLAlchemy warnings with warnings.catch_warnings(): @@ -321,10 +323,10 @@ def execute(self, sql, *args, **kwargs): # Execute statement if self._autocommit: - connection.execute(sqlalchemy.text("BEGIN")) - result = connection.execute(sqlalchemy.text(statement)) + session.execute(sqlalchemy.text("BEGIN")) + result = session.execute(sqlalchemy.text(statement)) if self._autocommit: - connection.execute(sqlalchemy.text("COMMIT")) + session.execute(sqlalchemy.text("COMMIT")) # Check for end of transaction if command in ["COMMIT", "ROLLBACK"]: @@ -357,7 +359,7 @@ def execute(self, sql, *args, **kwargs): elif command == "INSERT": if self._engine.url.get_backend_name() in ["postgres", "postgresql"]: try: - result = connection.execute("SELECT LASTVAL()") + result = session.execute("SELECT LASTVAL()") ret = result.first()[0] except sqlalchemy.exc.OperationalError: # If lastval is not yet defined in this session ret = None @@ -538,5 +540,5 @@ def _parse_placeholder(token): def _teardown_appcontext(exception=None): """Closes context's database connection, if any.""" import flask - for connection in flask.g.pop("_connections", {}).values(): - connection.close() + for session in flask.g.pop("_sessions", {}).values(): + session.remove()