From 6e1798224295a0e0aaa076164694a8e262f6ca03 Mon Sep 17 00:00:00 2001
From: Joshua Archibald <jarchibald121@gmail.com>
Date: Thu, 4 Jun 2020 12:17:45 -0500
Subject: [PATCH 1/6] Use sessions to handle transactions, allowing for both
 auto and manual commit modes. Registers Flask appcontext teardown function
 only once per database instance, and also allows for multiple database
 connections in a single Flask request.

Add unit tests for SQL savepoints, autocommit mode, manual transaction
mode. Add integration tests for Flask.
---
 src/cs50/sql.py            | 66 +++++++++++++++++++++++++-------------
 tests/flask/application.py | 56 ++++++++++++++++++++++++++++++--
 tests/flask/test.py        | 49 ++++++++++++++++++++++++++++
 tests/sql.py               | 29 +++++++++++++++--
 4 files changed, 173 insertions(+), 27 deletions(-)
 create mode 100644 tests/flask/test.py

diff --git a/src/cs50/sql.py b/src/cs50/sql.py
index cd8ae88..55cf058 100644
--- a/src/cs50/sql.py
+++ b/src/cs50/sql.py
@@ -43,6 +43,7 @@ def __init__(self, url, **kwargs):
         import os
         import re
         import sqlalchemy
+        import sqlalchemy.orm as orm
         import sqlite3
 
         # Get logger
@@ -56,9 +57,16 @@ def __init__(self, url, **kwargs):
             if not os.path.isfile(matches.group(1)):
                 raise RuntimeError("not a file: {}".format(matches.group(1)))
 
+        # Record the URL (used in testing)
+        self.url = url
+
         # Create engine, disabling SQLAlchemy's own autocommit mode, raising exception if back end's module not installed
         self._engine = sqlalchemy.create_engine(url, **kwargs).execution_options(autocommit=False)
 
+        # Create a variable to hold the session. If None, autocommit is on.
+        self.Session = orm.sessionmaker(bind=self._engine)
+        self._session = None
+
         # Listener for connections
         def connect(dbapi_connection, connection_record):
 
@@ -90,9 +98,9 @@ def connect(dbapi_connection, connection_record):
             self._logger.disabled = disabled
 
     def __del__(self):
-        """Close database connection."""
-        if hasattr(self, "_connection"):
-            self._connection.close()
+        """Close database session and connection."""
+        if self._session is not None:
+            self._session.close()
 
     @_enable_logging
     def execute(self, sql, *args, **kwargs):
@@ -125,6 +133,12 @@ def execute(self, sql, *args, **kwargs):
             if token.ttype in [sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML]:
                 command = token.value.upper()
                 break
+
+            # Begin a new transaction session, if done manually
+            elif token.value.upper() in ["BEGIN", "START"]:
+                if self._session is not None:
+                    self._session.close()
+                self._session = self.Session()
         else:
             command = None
 
@@ -272,6 +286,11 @@ def execute(self, sql, *args, **kwargs):
         statement = "".join([str(token) for token in tokens])
 
         # Connect to database (for transactions' sake)
+        session = self._session
+        if session is None:
+            session = self.Session()
+
+        # Set up a Flask app teardown function to close session at teardown
         try:
 
             # Infer whether Flask is installed
@@ -280,29 +299,18 @@ def execute(self, sql, *args, **kwargs):
             # Infer whether app is defined
             assert flask.current_app
 
-            # If no connection for app's current request yet
-            if not hasattr(flask.g, "_connection"):
-
-                # Connect now
-                flask.g._connection = self._engine.connect()
+            # Disconnect later - but only once
+            if not hasattr(self, "teardown_appcontext_added"):
+                self.teardown_appcontext_added = True
 
-                # Disconnect later
+                # Register shutdown_session on app context teardown
                 @flask.current_app.teardown_appcontext
                 def shutdown_session(exception=None):
-                    if hasattr(flask.g, "_connection"):
-                        flask.g._connection.close()
-
-            # Use this connection
-            connection = flask.g._connection
+                    if self._session is not None:
+                        self._session.close()
 
         except (ModuleNotFoundError, AssertionError):
-
-            # If no connection yet
-            if not hasattr(self, "_connection"):
-                self._connection = self._engine.connect()
-
-            # Use this connection
-            connection = self._connection
+            pass
 
         # Catch SQLAlchemy warnings
         with warnings.catch_warnings():
@@ -317,7 +325,7 @@ def shutdown_session(exception=None):
                 _statement = "".join([str(bytes) if token.ttype == sqlparse.tokens.Other else str(token) for token in tokens])
 
                 # Execute statement
-                result = connection.execute(sqlalchemy.text(statement))
+                result = session.execute(sqlalchemy.text(statement))
 
                 # Return value
                 ret = True
@@ -346,7 +354,7 @@ def shutdown_session(exception=None):
                 elif command == "INSERT":
                     if self._engine.url.get_backend_name() in ["postgres", "postgresql"]:
                         try:
-                            result = connection.execute("SELECT LASTVAL()")
+                            result = session.execute("SELECT LASTVAL()")
                             ret = result.first()[0]
                         except sqlalchemy.exc.OperationalError:  # If lastval is not yet defined in this session
                             ret = None
@@ -357,6 +365,18 @@ def shutdown_session(exception=None):
                 elif command in ["DELETE", "UPDATE"]:
                     ret = result.rowcount
 
+                # If COMMIT or ROLLBACK, turn on autocommit mode
+                elif command in ["COMMIT", "ROLLBACK"] and "TO" not in statement:
+                    session.close()
+                    self._session = None
+
+
+                # If autocommit is on, commit and close
+                if self._session is None and command not in ["COMMIT", "ROLLBACK"]:
+                    if command not in ["SELECT"]:
+                        session.commit()
+                    session.close()
+
             # If constraint violated, return None
             except sqlalchemy.exc.IntegrityError as e:
                 self._logger.debug(termcolor.colored(statement, "yellow"))
diff --git a/tests/flask/application.py b/tests/flask/application.py
index 939a8f9..e3f0768 100644
--- a/tests/flask/application.py
+++ b/tests/flask/application.py
@@ -1,3 +1,5 @@
+import logging
+import os
 import requests
 import sys
 from flask import Flask, render_template
@@ -9,14 +11,64 @@
 
 app = Flask(__name__)
 
-db = cs50.SQL("sqlite:///../sqlite.db")
+logging.disable(logging.CRITICAL)
+os.environ["WERKZEUG_RUN_MAIN"] = "true"
+
+db = cs50.SQL("sqlite:///../test.db")
 
 @app.route("/")
 def index():
-    db.execute("SELECT 1")
     """
     def f():
         res = requests.get("cs50.harvard.edu")
     f()
     """
     return render_template("index.html")
+
+@app.route("/autocommit")
+def autocommit():
+    db.execute("INSERT INTO test (val) VALUES (?)", "def")
+    db2 = cs50.SQL(db.url)
+    ret = db2.execute("SELECT val FROM test WHERE val=?", "def")
+    return str(ret == [{"val": "def"}])
+
+@app.route("/create")
+def create():
+    ret = db.execute("CREATE TABLE test (id INTEGER PRIMARY KEY AUTOINCREMENT, val VARCHAR(16))")
+    return str(ret)
+
+@app.route("/delete")
+def delete():
+    ret = db.execute("DELETE FROM test")
+    return str(ret > 0)
+
+@app.route("/drop")
+def drop():
+    ret = db.execute("DROP TABLE test")
+    return str(ret)
+
+@app.route("/insert")
+def insert():
+    ret = db.execute("INSERT INTO test (val) VALUES (?)", "abc")
+    return str(ret > 0)
+
+@app.route("/multiple_connections")
+def multiple_connections():
+    ctx = len(app.teardown_appcontext_funcs)
+    db1 = cs50.SQL(db.url)
+    td1 = (len(app.teardown_appcontext_funcs) == ctx + 1)
+    db2 = cs50.SQL(db.url)
+    td2 = (len(app.teardown_appcontext_funcs) == ctx + 2)
+    return str(td1 and td2)
+
+@app.route("/select")
+def select():
+    ret = db.execute("SELECT val FROM test")
+    return str(ret == [{"val": "abc"}])
+
+@app.route("/single_teardown")
+def single_teardown():
+    db.execute("SELECT * FROM test")
+    ctx = len(app.teardown_appcontext_funcs)
+    db.execute("SELECT COUNT(id) FROM test")
+    return str(ctx == len(app.teardown_appcontext_funcs))
diff --git a/tests/flask/test.py b/tests/flask/test.py
new file mode 100644
index 0000000..9a4134b
--- /dev/null
+++ b/tests/flask/test.py
@@ -0,0 +1,49 @@
+from application import app
+import logging
+import requests
+import sys
+import threading
+import time
+import unittest
+
+
+def request(route):
+    r = requests.get("http://localhost:5000/{}".format(route))
+    return r.text == "True"
+
+class FlaskTests(unittest.TestCase):
+
+    def test__create(self):
+        self.assertTrue(request("create"))
+ 
+    def test_autocommit(self):
+        self.assertTrue(request("autocommit"))
+
+    def test_delete(self):
+        self.assertTrue(request("delete"))
+
+    def test_insert(self):
+        self.assertTrue(request("insert"))
+
+    def test_multiple_connections(self):
+        self.assertTrue(request("multiple_connections"))
+
+    def test_select(self):
+        self.assertTrue(request("select"))
+
+    def test_single_teardown(self):
+        self.assertTrue(request("single_teardown"))
+
+    def test_zdrop(self):
+        self.assertTrue(request("drop"))
+
+
+if __name__ == "__main__":
+    t = threading.Thread(target=app.run, daemon=True)
+    t.start()
+
+    suite = unittest.TestSuite([
+        unittest.TestLoader().loadTestsFromTestCase(FlaskTests)
+    ])
+
+    sys.exit(not unittest.TextTestRunner(verbosity=2).run(suite).wasSuccessful())
diff --git a/tests/sql.py b/tests/sql.py
index 9ad463f..57974a6 100644
--- a/tests/sql.py
+++ b/tests/sql.py
@@ -115,11 +115,22 @@ def test_blob(self):
             self.db.execute("INSERT INTO cs50(bin) VALUES(:bin)", bin=row["bin"])
         self.assertEqual(self.db.execute("SELECT id, bin FROM cs50"), rows)
 
+    def test_autocommit(self):
+        self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('foo')"), 1)
+        self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('bar')"), 2)
+
+        # Load a new database instance to confirm the INSERTs were committed
+        db2 = SQL(self.db.url)
+        self.assertEqual(db2.execute("DELETE FROM cs50 WHERE id < 3"), 2)
+
     def test_commit(self):
         self.db.execute("BEGIN")
         self.db.execute("INSERT INTO cs50 (val) VALUES('foo')")
         self.db.execute("COMMIT")
-        self.assertEqual(self.db.execute("SELECT val FROM cs50"), [{"val": "foo"}])
+
+        # Load a new database instance to confirm the INSERT was committed
+        db2 = SQL(self.db.url)
+        self.assertEqual(db2.execute("SELECT val FROM cs50"), [{"val": "foo"}])
 
     def test_rollback(self):
         self.db.execute("BEGIN")
@@ -128,6 +139,17 @@ def test_rollback(self):
         self.db.execute("ROLLBACK")
         self.assertEqual(self.db.execute("SELECT val FROM cs50"), [])
 
+    def test_savepoint(self):
+        self.db.execute("BEGIN")
+        self.db.execute("INSERT INTO cs50 (val) VALUES('foo')")
+        self.db.execute("SAVEPOINT sp1")
+        self.db.execute("INSERT INTO cs50 (val) VALUES('bar')")
+        self.assertEqual(self.db.execute("SELECT val FROM cs50"), [{"val": "foo"}, {"val": "bar"}])
+        self.db.execute("ROLLBACK TO sp1")
+        self.assertEqual(self.db.execute("SELECT val FROM cs50"), [{"val": "foo"}])
+        self.db.execute("ROLLBACK")
+        self.assertEqual(self.db.execute("SELECT val FROM cs50"), [])
+
     def tearDown(self):
         self.db.execute("DROP TABLE cs50")
         self.db.execute("DROP TABLE IF EXISTS foo")
@@ -146,6 +168,7 @@ class MySQLTests(SQLTests):
     @classmethod
     def setUpClass(self):
         self.db = SQL("mysql://root@localhost/test")
+        print("\nMySQL tests")
 
     def setUp(self):
         self.db.execute("CREATE TABLE cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), bin BLOB, PRIMARY KEY (id))")
@@ -153,7 +176,8 @@ def setUp(self):
 class PostgresTests(SQLTests):
     @classmethod
     def setUpClass(self):
-        self.db = SQL("postgresql://postgres@localhost/test")
+        self.db = SQL("postgresql://root:test@localhost/test")
+        print("\nPOSTGRES tests")
 
     def setUp(self):
         self.db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)")
@@ -166,6 +190,7 @@ class SQLiteTests(SQLTests):
     def setUpClass(self):
         open("test.db", "w").close()
         self.db = SQL("sqlite:///test.db")
+        print("\nSQLite tests")
 
     def setUp(self):
         self.db.execute("CREATE TABLE cs50(id INTEGER PRIMARY KEY, val TEXT, bin BLOB)")

From 0dff7160cc684ffb03c3baeb092e803ddff2abe0 Mon Sep 17 00:00:00 2001
From: Joshua Archibald <jarchibald121@gmail.com>
Date: Thu, 4 Jun 2020 14:29:38 -0500
Subject: [PATCH 2/6] Style fixes. Minor design improvements, including
 removing SQL class URL variable, and always committing session so as to
 release locks.

---
 src/cs50/sql.py            | 21 +++++++++------------
 tests/flask/application.py | 12 +++++++-----
 tests/flask/test.py        |  2 +-
 tests/sql.py               | 13 ++++++++-----
 4 files changed, 25 insertions(+), 23 deletions(-)

diff --git a/src/cs50/sql.py b/src/cs50/sql.py
index 55cf058..8acf194 100644
--- a/src/cs50/sql.py
+++ b/src/cs50/sql.py
@@ -43,7 +43,7 @@ def __init__(self, url, **kwargs):
         import os
         import re
         import sqlalchemy
-        import sqlalchemy.orm as orm
+        import sqlalchemy.orm
         import sqlite3
 
         # Get logger
@@ -57,14 +57,11 @@ def __init__(self, url, **kwargs):
             if not os.path.isfile(matches.group(1)):
                 raise RuntimeError("not a file: {}".format(matches.group(1)))
 
-        # Record the URL (used in testing)
-        self.url = url
-
         # Create engine, disabling SQLAlchemy's own autocommit mode, raising exception if back end's module not installed
         self._engine = sqlalchemy.create_engine(url, **kwargs).execution_options(autocommit=False)
 
         # Create a variable to hold the session. If None, autocommit is on.
-        self.Session = orm.sessionmaker(bind=self._engine)
+        self._Session = sqlalchemy.orm.session.sessionmaker(bind=self._engine)
         self._session = None
 
         # Listener for connections
@@ -101,6 +98,7 @@ def __del__(self):
         """Close database session and connection."""
         if self._session is not None:
             self._session.close()
+            self._session = None
 
     @_enable_logging
     def execute(self, sql, *args, **kwargs):
@@ -134,11 +132,11 @@ def execute(self, sql, *args, **kwargs):
                 command = token.value.upper()
                 break
 
-            # Begin a new transaction session, if done manually
+            # Begin a new session, if transaction started by caller (not using autocommit)
             elif token.value.upper() in ["BEGIN", "START"]:
                 if self._session is not None:
                     self._session.close()
-                self._session = self.Session()
+                self._session = self._Session()
         else:
             command = None
 
@@ -288,7 +286,7 @@ def execute(self, sql, *args, **kwargs):
         # Connect to database (for transactions' sake)
         session = self._session
         if session is None:
-            session = self.Session()
+            session = self._Session()
 
         # Set up a Flask app teardown function to close session at teardown
         try:
@@ -303,11 +301,12 @@ def execute(self, sql, *args, **kwargs):
             if not hasattr(self, "teardown_appcontext_added"):
                 self.teardown_appcontext_added = True
 
-                # Register shutdown_session on app context teardown
                 @flask.current_app.teardown_appcontext
                 def shutdown_session(exception=None):
+                    """Close any existing session on app context teardown."""
                     if self._session is not None:
                         self._session.close()
+                        self._session = None
 
         except (ModuleNotFoundError, AssertionError):
             pass
@@ -370,11 +369,9 @@ def shutdown_session(exception=None):
                     session.close()
                     self._session = None
 
-
                 # If autocommit is on, commit and close
                 if self._session is None and command not in ["COMMIT", "ROLLBACK"]:
-                    if command not in ["SELECT"]:
-                        session.commit()
+                    session.commit()
                     session.close()
 
             # If constraint violated, return None
diff --git a/tests/flask/application.py b/tests/flask/application.py
index e3f0768..404b1d4 100644
--- a/tests/flask/application.py
+++ b/tests/flask/application.py
@@ -2,19 +2,21 @@
 import os
 import requests
 import sys
-from flask import Flask, render_template
 
 sys.path.insert(0, "../../src")
 
 import cs50
 import cs50.flask
 
+from flask import Flask, render_template
+
 app = Flask(__name__)
 
 logging.disable(logging.CRITICAL)
 os.environ["WERKZEUG_RUN_MAIN"] = "true"
 
-db = cs50.SQL("sqlite:///../test.db")
+db_url = "sqlite:///../test.db"
+db = cs50.SQL(db_url)
 
 @app.route("/")
 def index():
@@ -28,7 +30,7 @@ def f():
 @app.route("/autocommit")
 def autocommit():
     db.execute("INSERT INTO test (val) VALUES (?)", "def")
-    db2 = cs50.SQL(db.url)
+    db2 = cs50.SQL(db_url)
     ret = db2.execute("SELECT val FROM test WHERE val=?", "def")
     return str(ret == [{"val": "def"}])
 
@@ -55,9 +57,9 @@ def insert():
 @app.route("/multiple_connections")
 def multiple_connections():
     ctx = len(app.teardown_appcontext_funcs)
-    db1 = cs50.SQL(db.url)
+    db1 = cs50.SQL(db_url)
     td1 = (len(app.teardown_appcontext_funcs) == ctx + 1)
-    db2 = cs50.SQL(db.url)
+    db2 = cs50.SQL(db_url)
     td2 = (len(app.teardown_appcontext_funcs) == ctx + 2)
     return str(td1 and td2)
 
diff --git a/tests/flask/test.py b/tests/flask/test.py
index 9a4134b..0b084d6 100644
--- a/tests/flask/test.py
+++ b/tests/flask/test.py
@@ -1,4 +1,3 @@
-from application import app
 import logging
 import requests
 import sys
@@ -6,6 +5,7 @@
 import time
 import unittest
 
+from application import app
 
 def request(route):
     r = requests.get("http://localhost:5000/{}".format(route))
diff --git a/tests/sql.py b/tests/sql.py
index 57974a6..7694ad9 100644
--- a/tests/sql.py
+++ b/tests/sql.py
@@ -120,7 +120,7 @@ def test_autocommit(self):
         self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('bar')"), 2)
 
         # Load a new database instance to confirm the INSERTs were committed
-        db2 = SQL(self.db.url)
+        db2 = SQL(self.db_url)
         self.assertEqual(db2.execute("DELETE FROM cs50 WHERE id < 3"), 2)
 
     def test_commit(self):
@@ -129,7 +129,7 @@ def test_commit(self):
         self.db.execute("COMMIT")
 
         # Load a new database instance to confirm the INSERT was committed
-        db2 = SQL(self.db.url)
+        db2 = SQL(self.db_url)
         self.assertEqual(db2.execute("SELECT val FROM cs50"), [{"val": "foo"}])
 
     def test_rollback(self):
@@ -167,7 +167,8 @@ def tearDownClass(self):
 class MySQLTests(SQLTests):
     @classmethod
     def setUpClass(self):
-        self.db = SQL("mysql://root@localhost/test")
+        self.db_url = "mysql://root@localhost/test"
+        self.db = SQL(self.db_url)
         print("\nMySQL tests")
 
     def setUp(self):
@@ -176,7 +177,8 @@ def setUp(self):
 class PostgresTests(SQLTests):
     @classmethod
     def setUpClass(self):
-        self.db = SQL("postgresql://root:test@localhost/test")
+        self.db_url = "postgresql://root:test@localhost/test"
+        self.db = SQL(self.db_url)
         print("\nPOSTGRES tests")
 
     def setUp(self):
@@ -189,7 +191,8 @@ class SQLiteTests(SQLTests):
     @classmethod
     def setUpClass(self):
         open("test.db", "w").close()
-        self.db = SQL("sqlite:///test.db")
+        self.db_url = "sqlite:///test.db"
+        self.db = SQL(self.db_url)
         print("\nSQLite tests")
 
     def setUp(self):

From c227d18b5200b4474eedc545b76f0143c1ca1e51 Mon Sep 17 00:00:00 2001
From: Joshua Archibald <jarchibald121@gmail.com>
Date: Thu, 4 Jun 2020 14:35:17 -0500
Subject: [PATCH 3/6] Fix sql.py in tests for Travis CI.

---
 tests/sql.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/sql.py b/tests/sql.py
index 7694ad9..8742702 100644
--- a/tests/sql.py
+++ b/tests/sql.py
@@ -177,7 +177,7 @@ def setUp(self):
 class PostgresTests(SQLTests):
     @classmethod
     def setUpClass(self):
-        self.db_url = "postgresql://root:test@localhost/test"
+        self.db_url = "postgresql://postgres@localhost/test"
         self.db = SQL(self.db_url)
         print("\nPOSTGRES tests")
 

From ee4128311e8c3c6962d76f0b4c718b5eaecc5530 Mon Sep 17 00:00:00 2001
From: Joshua Archibald <jarchibald121@gmail.com>
Date: Fri, 5 Jun 2020 17:23:25 -0500
Subject: [PATCH 4/6] Requested changes to code design, including some
 renaming, a new instance variable to track transaction status, and retaining
 session between calls to execute, among other things.

---
 src/cs50/sql.py | 49 ++++++++++++++++++++++++++-----------------------
 tests/sql.py    | 14 +++++++++++++-
 2 files changed, 39 insertions(+), 24 deletions(-)

diff --git a/src/cs50/sql.py b/src/cs50/sql.py
index 8acf194..d8af011 100644
--- a/src/cs50/sql.py
+++ b/src/cs50/sql.py
@@ -63,6 +63,7 @@ def __init__(self, url, **kwargs):
         # Create a variable to hold the session. If None, autocommit is on.
         self._Session = sqlalchemy.orm.session.sessionmaker(bind=self._engine)
         self._session = None
+        self._in_transaction = False
 
         # Listener for connections
         def connect(dbapi_connection, connection_record):
@@ -96,9 +97,7 @@ def connect(dbapi_connection, connection_record):
 
     def __del__(self):
         """Close database session and connection."""
-        if self._session is not None:
-            self._session.close()
-            self._session = None
+        self._close_session()
 
     @_enable_logging
     def execute(self, sql, *args, **kwargs):
@@ -134,9 +133,9 @@ def execute(self, sql, *args, **kwargs):
 
             # Begin a new session, if transaction started by caller (not using autocommit)
             elif token.value.upper() in ["BEGIN", "START"]:
-                if self._session is not None:
-                    self._session.close()
-                self._session = self._Session()
+                if self._in_transaction:
+                    raise RuntimeError("transaction already open")
+                self._in_transaction = True
         else:
             command = None
 
@@ -284,9 +283,8 @@ def execute(self, sql, *args, **kwargs):
         statement = "".join([str(token) for token in tokens])
 
         # Connect to database (for transactions' sake)
-        session = self._session
-        if session is None:
-            session = self._Session()
+        if self._session is None:
+            self._session = self._Session()
 
         # Set up a Flask app teardown function to close session at teardown
         try:
@@ -304,9 +302,7 @@ def execute(self, sql, *args, **kwargs):
                 @flask.current_app.teardown_appcontext
                 def shutdown_session(exception=None):
                     """Close any existing session on app context teardown."""
-                    if self._session is not None:
-                        self._session.close()
-                        self._session = None
+                    self._close_session()
 
         except (ModuleNotFoundError, AssertionError):
             pass
@@ -323,8 +319,14 @@ def shutdown_session(exception=None):
                 # Join tokens into statement, abbreviating binary data as <class 'bytes'>
                 _statement = "".join([str(bytes) if token.ttype == sqlparse.tokens.Other else str(token) for token in tokens])
 
+                # If COMMIT or ROLLBACK, turn on autocommit mode
+                if command in ["COMMIT", "ROLLBACK"] and "TO" not in statement:
+                    if not self._in_transaction:
+                        raise RuntimeError("transactions must be initiated with BEGIN or START TRANSACTION")
+                    self._in_transaction = False
+
                 # Execute statement
-                result = session.execute(sqlalchemy.text(statement))
+                result = self._session.execute(sqlalchemy.text(statement))
 
                 # Return value
                 ret = True
@@ -353,7 +355,7 @@ def shutdown_session(exception=None):
                 elif command == "INSERT":
                     if self._engine.url.get_backend_name() in ["postgres", "postgresql"]:
                         try:
-                            result = session.execute("SELECT LASTVAL()")
+                            result = self._session.execute("SELECT LASTVAL()")
                             ret = result.first()[0]
                         except sqlalchemy.exc.OperationalError:  # If lastval is not yet defined in this session
                             ret = None
@@ -364,15 +366,9 @@ def shutdown_session(exception=None):
                 elif command in ["DELETE", "UPDATE"]:
                     ret = result.rowcount
 
-                # If COMMIT or ROLLBACK, turn on autocommit mode
-                elif command in ["COMMIT", "ROLLBACK"] and "TO" not in statement:
-                    session.close()
-                    self._session = None
-
-                # If autocommit is on, commit and close
-                if self._session is None and command not in ["COMMIT", "ROLLBACK"]:
-                    session.commit()
-                    session.close()
+                # If autocommit is on, commit
+                if not self._in_transaction:
+                    self._session.commit()
 
             # If constraint violated, return None
             except sqlalchemy.exc.IntegrityError as e:
@@ -393,6 +389,13 @@ def shutdown_session(exception=None):
                 self._logger.debug(termcolor.colored(_statement, "green"))
                 return ret
 
+    def _close_session(self):
+        """Closes any existing session and resets instance variables."""
+        if self._session is not None:
+            self._session.close()
+        self._session = None
+        self._in_transaction = False
+
     def _escape(self, value):
         """
         Escapes value using engine's conversion function.
diff --git a/tests/sql.py b/tests/sql.py
index 8742702..661920e 100644
--- a/tests/sql.py
+++ b/tests/sql.py
@@ -123,6 +123,12 @@ def test_autocommit(self):
         db2 = SQL(self.db_url)
         self.assertEqual(db2.execute("DELETE FROM cs50 WHERE id < 3"), 2)
 
+    def test_commit_no_transaction(self):
+        with self.assertRaises(RuntimeError):
+            self.db.execute("COMMIT")
+        with self.assertRaises(RuntimeError):
+            self.db.execute("ROLLBACK")
+
     def test_commit(self):
         self.db.execute("BEGIN")
         self.db.execute("INSERT INTO cs50 (val) VALUES('foo')")
@@ -132,6 +138,12 @@ def test_commit(self):
         db2 = SQL(self.db_url)
         self.assertEqual(db2.execute("SELECT val FROM cs50"), [{"val": "foo"}])
 
+    def test_double_begin(self):
+        self.db.execute("BEGIN")
+        with self.assertRaises(RuntimeError):
+            self.db.execute("BEGIN")
+        self.db.execute("ROLLBACK")
+
     def test_rollback(self):
         self.db.execute("BEGIN")
         self.db.execute("INSERT INTO cs50 (val) VALUES('foo')")
@@ -177,7 +189,7 @@ def setUp(self):
 class PostgresTests(SQLTests):
     @classmethod
     def setUpClass(self):
-        self.db_url = "postgresql://postgres@localhost/test"
+        self.db_url = "postgresql://root:test@localhost/test"
         self.db = SQL(self.db_url)
         print("\nPOSTGRES tests")
 

From c60d67d908e507c2fedfb3fb44829e07571ca7a6 Mon Sep 17 00:00:00 2001
From: Joshua Archibald <jarchibald121@gmail.com>
Date: Fri, 5 Jun 2020 17:24:25 -0500
Subject: [PATCH 5/6] Messed up the tests for Travis CI again. Fixed.

---
 tests/sql.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/sql.py b/tests/sql.py
index 661920e..95301eb 100644
--- a/tests/sql.py
+++ b/tests/sql.py
@@ -189,7 +189,7 @@ def setUp(self):
 class PostgresTests(SQLTests):
     @classmethod
     def setUpClass(self):
-        self.db_url = "postgresql://root:test@localhost/test"
+        self.db_url = "postgresql://postgres@localhost/test"
         self.db = SQL(self.db_url)
         print("\nPOSTGRES tests")
 

From dac2ae88c7533e5aad4e9671395c591692694bed Mon Sep 17 00:00:00 2001
From: Joshua Archibald <jarchibald121@gmail.com>
Date: Wed, 10 Jun 2020 23:32:18 -0500
Subject: [PATCH 6/6] Stylistic changes.

---
 src/cs50/sql.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/src/cs50/sql.py b/src/cs50/sql.py
index d8af011..f47e2b6 100644
--- a/src/cs50/sql.py
+++ b/src/cs50/sql.py
@@ -135,6 +135,7 @@ def execute(self, sql, *args, **kwargs):
             elif token.value.upper() in ["BEGIN", "START"]:
                 if self._in_transaction:
                     raise RuntimeError("transaction already open")
+
                 self._in_transaction = True
         else:
             command = None
@@ -296,8 +297,8 @@ def execute(self, sql, *args, **kwargs):
             assert flask.current_app
 
             # Disconnect later - but only once
-            if not hasattr(self, "teardown_appcontext_added"):
-                self.teardown_appcontext_added = True
+            if not hasattr(self, "_teardown_appcontext_added"):
+                self._teardown_appcontext_added = True
 
                 @flask.current_app.teardown_appcontext
                 def shutdown_session(exception=None):
@@ -320,9 +321,10 @@ def shutdown_session(exception=None):
                 _statement = "".join([str(bytes) if token.ttype == sqlparse.tokens.Other else str(token) for token in tokens])
 
                 # If COMMIT or ROLLBACK, turn on autocommit mode
-                if command in ["COMMIT", "ROLLBACK"] and "TO" not in statement:
+                if command in ["COMMIT", "ROLLBACK"] and "TO" not in (token.value for token in tokens):
                     if not self._in_transaction:
                         raise RuntimeError("transactions must be initiated with BEGIN or START TRANSACTION")
+
                     self._in_transaction = False
 
                 # Execute statement
@@ -393,6 +395,7 @@ def _close_session(self):
         """Closes any existing session and resets instance variables."""
         if self._session is not None:
             self._session.close()
+
         self._session = None
         self._in_transaction = False