Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixing support for multithreading #142

Merged
merged 2 commits into from
Dec 14, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
.*
!.gitignore
!.travis.yml
dist/
*.db
*.egg-info/
*.pyc
dist/
test.db
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -42,7 +42,6 @@ s = cs50.get_string();
```
1. Run `service postgresql start`.
1. Run `psql -c 'create database test;' -U postgres`.
1. Run `touch test.db`.

### Sample Tests

2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -16,5 +16,5 @@
package_dir={"": "src"},
packages=["cs50"],
url="https://github.com/cs50/python-cs50",
version="6.0.1"
version="6.0.2"
)
54 changes: 28 additions & 26 deletions src/cs50/sql.py
Original file line number Diff line number Diff line change
@@ -43,6 +43,7 @@ def __init__(self, url, **kwargs):
import os
import re
import sqlalchemy
import sqlalchemy.orm
import sqlite3

# Require that file already exist for SQLite
@@ -72,12 +73,12 @@ def connect(dbapi_connection, connection_record):
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()

# Autocommit by default
self._autocommit = True

# Register listener
sqlalchemy.event.listen(self._engine, "connect", connect)

# Autocommit by default
self._autocommit = True

# Test database
disabled = self._logger.disabled
self._logger.disabled = True
@@ -96,9 +97,9 @@ def __del__(self):

def _disconnect(self):
"""Close database connection."""
if hasattr(self, "_connection"):
self._connection.close()
delattr(self, "_connection")
if hasattr(self, "_session"):
self._session.remove()
delattr(self, "_session")

@_enable_logging
def execute(self, sql, *args, **kwargs):
@@ -275,33 +276,34 @@ def execute(self, sql, *args, **kwargs):
# Infer whether app is defined
assert flask.current_app

# If no connections to any databases yet
if not hasattr(flask.g, "_connections"):
setattr(flask.g, "_connections", {})
connections = getattr(flask.g, "_connections")
# If no sessions for any databases yet
if not hasattr(flask.g, "_sessions"):
setattr(flask.g, "_sessions", {})
sessions = getattr(flask.g, "_sessions")

# If not yet connected to this database
# If no session yet for this database
# https://flask.palletsprojects.com/en/1.1.x/appcontext/#storing-data
if self not in connections:
# https://stackoverflow.com/a/34010159
if self not in sessions:

# Connect to database
connections[self] = self._engine.connect()
sessions[self] = sqlalchemy.orm.scoping.scoped_session(sqlalchemy.orm.sessionmaker(bind=self._engine))

# Disconnect from database later
# Remove session later
if _teardown_appcontext not in flask.current_app.teardown_appcontext_funcs:
flask.current_app.teardown_appcontext(_teardown_appcontext)

# Use this connection
connection = connections[self]
# Use this session
session = sessions[self]

except (ModuleNotFoundError, AssertionError):

# If no connection yet
if not hasattr(self, "_connection"):
self._connection = self._engine.connect()
if not hasattr(self, "_session"):
self._session = sqlalchemy.orm.scoping.scoped_session(sqlalchemy.orm.sessionmaker(bind=self._engine))

# Use this connection
connection = self._connection
# Use this session
session = self._session

# Catch SQLAlchemy warnings
with warnings.catch_warnings():
@@ -321,10 +323,10 @@ def execute(self, sql, *args, **kwargs):

# Execute statement
if self._autocommit:
connection.execute(sqlalchemy.text("BEGIN"))
result = connection.execute(sqlalchemy.text(statement))
session.execute(sqlalchemy.text("BEGIN"))
result = session.execute(sqlalchemy.text(statement))
if self._autocommit:
connection.execute(sqlalchemy.text("COMMIT"))
session.execute(sqlalchemy.text("COMMIT"))

# Check for end of transaction
if command in ["COMMIT", "ROLLBACK"]:
@@ -357,7 +359,7 @@ def execute(self, sql, *args, **kwargs):
elif command == "INSERT":
if self._engine.url.get_backend_name() in ["postgres", "postgresql"]:
try:
result = connection.execute("SELECT LASTVAL()")
result = session.execute("SELECT LASTVAL()")
ret = result.first()[0]
except sqlalchemy.exc.OperationalError: # If lastval is not yet defined in this session
ret = None
@@ -538,5 +540,5 @@ def _parse_placeholder(token):
def _teardown_appcontext(exception=None):
"""Closes context's database connection, if any."""
import flask
for connection in flask.g.pop("_connections", {}).values():
connection.close()
for session in flask.g.pop("_sessions", {}).values():
session.remove()