]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Setting autocommit made thread safe
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 25 Jul 2020 10:52:07 +0000 (11:52 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 26 Oct 2020 16:18:18 +0000 (17:18 +0100)
Using a setter on sync connections, a function on async connections.
Should review why this approach was abandoned for the client encoding,
because it seems working alright.

psycopg3/psycopg3/connection.py
tests/test_connection_async.py

index 8529b44462f83b06bc329410192c2cf5f2a15dc5..bac9a8a0abbb3fef86ced7bed635e8fb02c0547b 100644 (file)
@@ -105,10 +105,15 @@ class BaseConnection:
 
     @autocommit.setter
     def autocommit(self, value: bool) -> None:
+        self._set_autocommit(value)
+
+    def _set_autocommit(self, value: bool) -> None:
+        # Base implementation, not thread safe
+        # subclasses must call it holding a lock
         status = self.pgconn.transaction_status
         if status != TransactionStatus.IDLE:
             raise e.ProgrammingError(
-                "can't change autocommit state: connection in"
+                "couldn't change autocommit state: connection in"
                 f" transaction status {TransactionStatus(status).name}"
             )
         self._autocommit = value
@@ -302,6 +307,10 @@ class Connection(BaseConnection):
                     yield None  # for the send who stopped us
                     return
 
+    def _set_autocommit(self, value: bool) -> None:
+        with self.lock:
+            super()._set_autocommit(value)
+
 
 class AsyncConnection(BaseConnection):
     """
@@ -407,3 +416,15 @@ class AsyncConnection(BaseConnection):
                 if (yield n):
                     yield None
                     return
+
+    def _set_autocommit(self, value: bool) -> None:
+        raise AttributeError(
+            "autocommit is read-only on async connections:"
+            " please use await connection.set_autocommit() instead."
+            " Note that you can pass an 'autocommit' value to 'connect()'"
+            " if it doesn't need to change during the connection's lifetime."
+        )
+
+    async def set_autocommit(self, value: bool) -> None:
+        async with self.lock:
+            super()._set_autocommit(value)
index ca4ae75cfbdf3d7572c9abc915bb667c500f76b4..e62451b946a8e4100747e6c4d4229ce339d2f80c 100644 (file)
@@ -120,7 +120,11 @@ async def test_auto_transaction_fail(aconn):
 
 async def test_autocommit(aconn):
     assert aconn.autocommit is False
-    aconn.autocommit = True
+    with pytest.raises(TypeError):
+        aconn.autocommit = True
+    assert not aconn.autocommit
+
+    await aconn.set_autocommit(True)
     assert aconn.autocommit
     cur = aconn.cursor()
     await cur.execute("select 1")
@@ -139,7 +143,7 @@ async def test_autocommit_intrans(aconn):
     assert await cur.fetchone() == (1,)
     assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
     with pytest.raises(psycopg3.ProgrammingError):
-        aconn.autocommit = True
+        await aconn.set_autocommit(True)
     assert not aconn.autocommit
 
 
@@ -149,7 +153,7 @@ async def test_autocommit_inerror(aconn):
         await cur.execute("meh")
     assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR
     with pytest.raises(psycopg3.ProgrammingError):
-        aconn.autocommit = True
+        await aconn.set_autocommit(True)
     assert not aconn.autocommit
 
 
@@ -157,7 +161,7 @@ async def test_autocommit_unknown(aconn):
     await aconn.close()
     assert aconn.pgconn.transaction_status == aconn.TransactionStatus.UNKNOWN
     with pytest.raises(psycopg3.ProgrammingError):
-        aconn.autocommit = True
+        await aconn.set_autocommit(True)
     assert not aconn.autocommit
 
 
@@ -326,7 +330,7 @@ async def test_notify_handlers(aconn):
     aconn.add_notify_handler(cb1)
     aconn.add_notify_handler(lambda n: nots2.append(n))
 
-    aconn.autocommit = True
+    await aconn.set_autocommit(True)
     cur = aconn.cursor()
     await cur.execute("listen foo")
     await cur.execute("notify foo, 'n1'")