diff --git a/src/cs50/sql.py b/src/cs50/sql.py
index cd8ae88..f47e2b6 100644
--- a/src/cs50/sql.py
+++ b/src/cs50/sql.py
@@ -43,6 +43,7 @@ def __init__(self, url, **kwargs):
         import os
         import re
         import sqlalchemy
+        import sqlalchemy.orm
         import sqlite3
 
         # Get logger
@@ -59,6 +60,11 @@ def __init__(self, url, **kwargs):
         # Create engine, disabling SQLAlchemy's own autocommit mode, raising exception if back end's module not installed
         self._engine = sqlalchemy.create_engine(url, **kwargs).execution_options(autocommit=False)
 
+        # Create a variable to hold the session. If None, autocommit is on.
+        self._Session = sqlalchemy.orm.session.sessionmaker(bind=self._engine)
+        self._session = None
+        self._in_transaction = False
+
         # Listener for connections
         def connect(dbapi_connection, connection_record):
 
@@ -90,9 +96,8 @@ def connect(dbapi_connection, connection_record):
             self._logger.disabled = disabled
 
     def __del__(self):
-        """Close database connection."""
-        if hasattr(self, "_connection"):
-            self._connection.close()
+        """Close database session and connection."""
+        self._close_session()
 
     @_enable_logging
     def execute(self, sql, *args, **kwargs):
@@ -125,6 +130,13 @@ def execute(self, sql, *args, **kwargs):
             if token.ttype in [sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML]:
                 command = token.value.upper()
                 break
+
+            # Begin a new session, if transaction started by caller (not using autocommit)
+            elif token.value.upper() in ["BEGIN", "START"]:
+                if self._in_transaction:
+                    raise RuntimeError("transaction already open")
+
+                self._in_transaction = True
         else:
             command = None
 
@@ -272,6 +284,10 @@ def execute(self, sql, *args, **kwargs):
         statement = "".join([str(token) for token in tokens])
 
         # Connect to database (for transactions' sake)
+        if self._session is None:
+            self._session = self._Session()
+
+        # Set up a Flask app teardown function to close session at teardown
         try:
 
             # Infer whether Flask is installed
@@ -280,29 +296,17 @@ def execute(self, sql, *args, **kwargs):
             # Infer whether app is defined
             assert flask.current_app
 
-            # If no connection for app's current request yet
-            if not hasattr(flask.g, "_connection"):
+            # Disconnect later - but only once
+            if not hasattr(self, "_teardown_appcontext_added"):
+                self._teardown_appcontext_added = True
 
-                # Connect now
-                flask.g._connection = self._engine.connect()
-
-                # Disconnect later
                 @flask.current_app.teardown_appcontext
                 def shutdown_session(exception=None):
-                    if hasattr(flask.g, "_connection"):
-                        flask.g._connection.close()
-
-            # Use this connection
-            connection = flask.g._connection
+                    """Close any existing session on app context teardown."""
+                    self._close_session()
 
         except (ModuleNotFoundError, AssertionError):
-
-            # If no connection yet
-            if not hasattr(self, "_connection"):
-                self._connection = self._engine.connect()
-
-            # Use this connection
-            connection = self._connection
+            pass
 
         # Catch SQLAlchemy warnings
         with warnings.catch_warnings():
@@ -316,8 +320,15 @@ def shutdown_session(exception=None):
                 # Join tokens into statement, abbreviating binary data as <class 'bytes'>
                 _statement = "".join([str(bytes) if token.ttype == sqlparse.tokens.Other else str(token) for token in tokens])
 
+                # If COMMIT or ROLLBACK, turn on autocommit mode
+                if command in ["COMMIT", "ROLLBACK"] and "TO" not in (token.value for token in tokens):
+                    if not self._in_transaction:
+                        raise RuntimeError("transactions must be initiated with BEGIN or START TRANSACTION")
+
+                    self._in_transaction = False
+
                 # Execute statement
-                result = connection.execute(sqlalchemy.text(statement))
+                result = self._session.execute(sqlalchemy.text(statement))
 
                 # Return value
                 ret = True
@@ -346,7 +357,7 @@ def shutdown_session(exception=None):
                 elif command == "INSERT":
                     if self._engine.url.get_backend_name() in ["postgres", "postgresql"]:
                         try:
-                            result = connection.execute("SELECT LASTVAL()")
+                            result = self._session.execute("SELECT LASTVAL()")
                             ret = result.first()[0]
                         except sqlalchemy.exc.OperationalError:  # If lastval is not yet defined in this session
                             ret = None
@@ -357,6 +368,10 @@ def shutdown_session(exception=None):
                 elif command in ["DELETE", "UPDATE"]:
                     ret = result.rowcount
 
+                # If autocommit is on, commit
+                if not self._in_transaction:
+                    self._session.commit()
+
             # If constraint violated, return None
             except sqlalchemy.exc.IntegrityError as e:
                 self._logger.debug(termcolor.colored(statement, "yellow"))
@@ -376,6 +391,14 @@ def shutdown_session(exception=None):
                 self._logger.debug(termcolor.colored(_statement, "green"))
                 return ret
 
+    def _close_session(self):
+        """Closes any existing session and resets instance variables."""
+        if self._session is not None:
+            self._session.close()
+
+        self._session = None
+        self._in_transaction = False
+
     def _escape(self, value):
         """
         Escapes value using engine's conversion function.
diff --git a/tests/flask/application.py b/tests/flask/application.py
index 939a8f9..404b1d4 100644
--- a/tests/flask/application.py
+++ b/tests/flask/application.py
@@ -1,22 +1,76 @@
+import logging
+import os
 import requests
 import sys
-from flask import Flask, render_template
 
 sys.path.insert(0, "../../src")
 
 import cs50
 import cs50.flask
 
+from flask import Flask, render_template
+
 app = Flask(__name__)
 
-db = cs50.SQL("sqlite:///../sqlite.db")
+logging.disable(logging.CRITICAL)
+os.environ["WERKZEUG_RUN_MAIN"] = "true"
+
+db_url = "sqlite:///../test.db"
+db = cs50.SQL(db_url)
 
 @app.route("/")
 def index():
-    db.execute("SELECT 1")
     """
     def f():
         res = requests.get("cs50.harvard.edu")
     f()
     """
     return render_template("index.html")
+
+@app.route("/autocommit")
+def autocommit():
+    db.execute("INSERT INTO test (val) VALUES (?)", "def")
+    db2 = cs50.SQL(db_url)
+    ret = db2.execute("SELECT val FROM test WHERE val=?", "def")
+    return str(ret == [{"val": "def"}])
+
+@app.route("/create")
+def create():
+    ret = db.execute("CREATE TABLE test (id INTEGER PRIMARY KEY AUTOINCREMENT, val VARCHAR(16))")
+    return str(ret)
+
+@app.route("/delete")
+def delete():
+    ret = db.execute("DELETE FROM test")
+    return str(ret > 0)
+
+@app.route("/drop")
+def drop():
+    ret = db.execute("DROP TABLE test")
+    return str(ret)
+
+@app.route("/insert")
+def insert():
+    ret = db.execute("INSERT INTO test (val) VALUES (?)", "abc")
+    return str(ret > 0)
+
+@app.route("/multiple_connections")
+def multiple_connections():
+    ctx = len(app.teardown_appcontext_funcs)
+    db1 = cs50.SQL(db_url)
+    td1 = (len(app.teardown_appcontext_funcs) == ctx + 1)
+    db2 = cs50.SQL(db_url)
+    td2 = (len(app.teardown_appcontext_funcs) == ctx + 2)
+    return str(td1 and td2)
+
+@app.route("/select")
+def select():
+    ret = db.execute("SELECT val FROM test")
+    return str(ret == [{"val": "abc"}])
+
+@app.route("/single_teardown")
+def single_teardown():
+    db.execute("SELECT * FROM test")
+    ctx = len(app.teardown_appcontext_funcs)
+    db.execute("SELECT COUNT(id) FROM test")
+    return str(ctx == len(app.teardown_appcontext_funcs))
diff --git a/tests/flask/test.py b/tests/flask/test.py
new file mode 100644
index 0000000..0b084d6
--- /dev/null
+++ b/tests/flask/test.py
@@ -0,0 +1,49 @@
+import logging
+import requests
+import sys
+import threading
+import time
+import unittest
+
+from application import app
+
+def request(route):
+    r = requests.get("http://localhost:5000/{}".format(route))
+    return r.text == "True"
+
+class FlaskTests(unittest.TestCase):
+
+    def test__create(self):
+        self.assertTrue(request("create"))
+ 
+    def test_autocommit(self):
+        self.assertTrue(request("autocommit"))
+
+    def test_delete(self):
+        self.assertTrue(request("delete"))
+
+    def test_insert(self):
+        self.assertTrue(request("insert"))
+
+    def test_multiple_connections(self):
+        self.assertTrue(request("multiple_connections"))
+
+    def test_select(self):
+        self.assertTrue(request("select"))
+
+    def test_single_teardown(self):
+        self.assertTrue(request("single_teardown"))
+
+    def test_zdrop(self):
+        self.assertTrue(request("drop"))
+
+
+if __name__ == "__main__":
+    t = threading.Thread(target=app.run, daemon=True)
+    t.start()
+
+    suite = unittest.TestSuite([
+        unittest.TestLoader().loadTestsFromTestCase(FlaskTests)
+    ])
+
+    sys.exit(not unittest.TextTestRunner(verbosity=2).run(suite).wasSuccessful())
diff --git a/tests/sql.py b/tests/sql.py
index 9ad463f..95301eb 100644
--- a/tests/sql.py
+++ b/tests/sql.py
@@ -115,11 +115,34 @@ def test_blob(self):
             self.db.execute("INSERT INTO cs50(bin) VALUES(:bin)", bin=row["bin"])
         self.assertEqual(self.db.execute("SELECT id, bin FROM cs50"), rows)
 
+    def test_autocommit(self):
+        self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('foo')"), 1)
+        self.assertEqual(self.db.execute("INSERT INTO cs50(val) VALUES('bar')"), 2)
+
+        # Load a new database instance to confirm the INSERTs were committed
+        db2 = SQL(self.db_url)
+        self.assertEqual(db2.execute("DELETE FROM cs50 WHERE id < 3"), 2)
+
+    def test_commit_no_transaction(self):
+        with self.assertRaises(RuntimeError):
+            self.db.execute("COMMIT")
+        with self.assertRaises(RuntimeError):
+            self.db.execute("ROLLBACK")
+
     def test_commit(self):
         self.db.execute("BEGIN")
         self.db.execute("INSERT INTO cs50 (val) VALUES('foo')")
         self.db.execute("COMMIT")
-        self.assertEqual(self.db.execute("SELECT val FROM cs50"), [{"val": "foo"}])
+
+        # Load a new database instance to confirm the INSERT was committed
+        db2 = SQL(self.db_url)
+        self.assertEqual(db2.execute("SELECT val FROM cs50"), [{"val": "foo"}])
+
+    def test_double_begin(self):
+        self.db.execute("BEGIN")
+        with self.assertRaises(RuntimeError):
+            self.db.execute("BEGIN")
+        self.db.execute("ROLLBACK")
 
     def test_rollback(self):
         self.db.execute("BEGIN")
@@ -128,6 +151,17 @@ def test_rollback(self):
         self.db.execute("ROLLBACK")
         self.assertEqual(self.db.execute("SELECT val FROM cs50"), [])
 
+    def test_savepoint(self):
+        self.db.execute("BEGIN")
+        self.db.execute("INSERT INTO cs50 (val) VALUES('foo')")
+        self.db.execute("SAVEPOINT sp1")
+        self.db.execute("INSERT INTO cs50 (val) VALUES('bar')")
+        self.assertEqual(self.db.execute("SELECT val FROM cs50"), [{"val": "foo"}, {"val": "bar"}])
+        self.db.execute("ROLLBACK TO sp1")
+        self.assertEqual(self.db.execute("SELECT val FROM cs50"), [{"val": "foo"}])
+        self.db.execute("ROLLBACK")
+        self.assertEqual(self.db.execute("SELECT val FROM cs50"), [])
+
     def tearDown(self):
         self.db.execute("DROP TABLE cs50")
         self.db.execute("DROP TABLE IF EXISTS foo")
@@ -145,7 +179,9 @@ def tearDownClass(self):
 class MySQLTests(SQLTests):
     @classmethod
     def setUpClass(self):
-        self.db = SQL("mysql://root@localhost/test")
+        self.db_url = "mysql://root@localhost/test"
+        self.db = SQL(self.db_url)
+        print("\nMySQL tests")
 
     def setUp(self):
         self.db.execute("CREATE TABLE cs50 (id INTEGER NOT NULL AUTO_INCREMENT, val VARCHAR(16), bin BLOB, PRIMARY KEY (id))")
@@ -153,7 +189,9 @@ def setUp(self):
 class PostgresTests(SQLTests):
     @classmethod
     def setUpClass(self):
-        self.db = SQL("postgresql://postgres@localhost/test")
+        self.db_url = "postgresql://postgres@localhost/test"
+        self.db = SQL(self.db_url)
+        print("\nPOSTGRES tests")
 
     def setUp(self):
         self.db.execute("CREATE TABLE cs50 (id SERIAL PRIMARY KEY, val VARCHAR(16), bin BYTEA)")
@@ -165,7 +203,9 @@ class SQLiteTests(SQLTests):
     @classmethod
     def setUpClass(self):
         open("test.db", "w").close()
-        self.db = SQL("sqlite:///test.db")
+        self.db_url = "sqlite:///test.db"
+        self.db = SQL(self.db_url)
+        print("\nSQLite tests")
 
     def setUp(self):
         self.db.execute("CREATE TABLE cs50(id INTEGER PRIMARY KEY, val TEXT, bin BLOB)")