diff --git a/src/cs50/sql.py b/src/cs50/sql.py index cd8ae88..bce5402 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -89,6 +89,9 @@ def connect(dbapi_connection, connection_record): finally: self._logger.disabled = disabled + # Whether we've registered a Flask teardown function for this connection + self.teardown_appcontext_added = False + def __del__(self): """Close database connection.""" if hasattr(self, "_connection"): @@ -286,11 +289,15 @@ def execute(self, sql, *args, **kwargs): # Connect now flask.g._connection = self._engine.connect() - # Disconnect later - @flask.current_app.teardown_appcontext - def shutdown_session(exception=None): - if hasattr(flask.g, "_connection"): - flask.g._connection.close() + # Disconnect later - but only once + if not self.teardown_appcontext_added: + self.teardown_appcontext_added = True + + # Register shutdown_session on app context teardown + @flask.current_app.teardown_appcontext + def shutdown_session(exception=None): + if hasattr(flask.g, "_connection"): + flask.g._connection.close() # Use this connection connection = flask.g._connection diff --git a/tests/flask/application.py b/tests/flask/application.py index 939a8f9..710a29b 100644 --- a/tests/flask/application.py +++ b/tests/flask/application.py @@ -9,7 +9,7 @@ app = Flask(__name__) -db = cs50.SQL("sqlite:///../sqlite.db") +db = cs50.SQL("sqlite:///../test.db") @app.route("/") def index():