Skip to content

Commit c4bb169

Browse files
authoredDec 14, 2020
Merge pull request #142 from cs50/scoped_session
fixing support for multithreading
2 parents 9612a28 + cb03ec1 commit c4bb169

File tree

4 files changed

+31
-29
lines changed

4 files changed

+31
-29
lines changed
 

‎.gitignore

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
.*
22
!.gitignore
33
!.travis.yml
4-
dist/
54
*.db
65
*.egg-info/
76
*.pyc
7+
dist/
8+
test.db

‎README.md

-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ s = cs50.get_string();
4242
```
4343
1. Run `service postgresql start`.
4444
1. Run `psql -c 'create database test;' -U postgres`.
45-
1. Run `touch test.db`.
4645

4746
### Sample Tests
4847

‎setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,5 @@
1616
package_dir={"": "src"},
1717
packages=["cs50"],
1818
url="https://github.com/cs50/python-cs50",
19-
version="6.0.1"
19+
version="6.0.2"
2020
)

‎src/cs50/sql.py

+28-26
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def __init__(self, url, **kwargs):
4343
import os
4444
import re
4545
import sqlalchemy
46+
import sqlalchemy.orm
4647
import sqlite3
4748

4849
# Require that file already exist for SQLite
@@ -72,12 +73,12 @@ def connect(dbapi_connection, connection_record):
7273
cursor.execute("PRAGMA foreign_keys=ON")
7374
cursor.close()
7475

75-
# Autocommit by default
76-
self._autocommit = True
77-
7876
# Register listener
7977
sqlalchemy.event.listen(self._engine, "connect", connect)
8078

79+
# Autocommit by default
80+
self._autocommit = True
81+
8182
# Test database
8283
disabled = self._logger.disabled
8384
self._logger.disabled = True
@@ -96,9 +97,9 @@ def __del__(self):
9697

9798
def _disconnect(self):
9899
"""Close database connection."""
99-
if hasattr(self, "_connection"):
100-
self._connection.close()
101-
delattr(self, "_connection")
100+
if hasattr(self, "_session"):
101+
self._session.remove()
102+
delattr(self, "_session")
102103

103104
@_enable_logging
104105
def execute(self, sql, *args, **kwargs):
@@ -275,33 +276,34 @@ def execute(self, sql, *args, **kwargs):
275276
# Infer whether app is defined
276277
assert flask.current_app
277278

278-
# If no connections to any databases yet
279-
if not hasattr(flask.g, "_connections"):
280-
setattr(flask.g, "_connections", {})
281-
connections = getattr(flask.g, "_connections")
279+
# If no sessions for any databases yet
280+
if not hasattr(flask.g, "_sessions"):
281+
setattr(flask.g, "_sessions", {})
282+
sessions = getattr(flask.g, "_sessions")
282283

283-
# If not yet connected to this database
284+
# If no session yet for this database
284285
# https://flask.palletsprojects.com/en/1.1.x/appcontext/#storing-data
285-
if self not in connections:
286+
# https://stackoverflow.com/a/34010159
287+
if self not in sessions:
286288

287289
# Connect to database
288-
connections[self] = self._engine.connect()
290+
sessions[self] = sqlalchemy.orm.scoping.scoped_session(sqlalchemy.orm.sessionmaker(bind=self._engine))
289291

290-
# Disconnect from database later
292+
# Remove session later
291293
if _teardown_appcontext not in flask.current_app.teardown_appcontext_funcs:
292294
flask.current_app.teardown_appcontext(_teardown_appcontext)
293295

294-
# Use this connection
295-
connection = connections[self]
296+
# Use this session
297+
session = sessions[self]
296298

297299
except (ModuleNotFoundError, AssertionError):
298300

299301
# If no connection yet
300-
if not hasattr(self, "_connection"):
301-
self._connection = self._engine.connect()
302+
if not hasattr(self, "_session"):
303+
self._session = sqlalchemy.orm.scoping.scoped_session(sqlalchemy.orm.sessionmaker(bind=self._engine))
302304

303-
# Use this connection
304-
connection = self._connection
305+
# Use this session
306+
session = self._session
305307

306308
# Catch SQLAlchemy warnings
307309
with warnings.catch_warnings():
@@ -321,10 +323,10 @@ def execute(self, sql, *args, **kwargs):
321323

322324
# Execute statement
323325
if self._autocommit:
324-
connection.execute(sqlalchemy.text("BEGIN"))
325-
result = connection.execute(sqlalchemy.text(statement))
326+
session.execute(sqlalchemy.text("BEGIN"))
327+
result = session.execute(sqlalchemy.text(statement))
326328
if self._autocommit:
327-
connection.execute(sqlalchemy.text("COMMIT"))
329+
session.execute(sqlalchemy.text("COMMIT"))
328330

329331
# Check for end of transaction
330332
if command in ["COMMIT", "ROLLBACK"]:
@@ -357,7 +359,7 @@ def execute(self, sql, *args, **kwargs):
357359
elif command == "INSERT":
358360
if self._engine.url.get_backend_name() in ["postgres", "postgresql"]:
359361
try:
360-
result = connection.execute("SELECT LASTVAL()")
362+
result = session.execute("SELECT LASTVAL()")
361363
ret = result.first()[0]
362364
except sqlalchemy.exc.OperationalError: # If lastval is not yet defined in this session
363365
ret = None
@@ -538,5 +540,5 @@ def _parse_placeholder(token):
538540
def _teardown_appcontext(exception=None):
539541
"""Closes context's database connection, if any."""
540542
import flask
541-
for connection in flask.g.pop("_connections", {}).values():
542-
connection.close()
543+
for session in flask.g.pop("_sessions", {}).values():
544+
session.remove()

0 commit comments

Comments
 (0)
Please sign in to comment.