]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
connection: Move lock acquisition out of _exec_commit_rollback()
authorDaniel Fortunov <github@danielfortunov.com>
Sat, 25 Jul 2020 11:36:47 +0000 (12:36 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 26 Oct 2020 16:46:29 +0000 (17:46 +0100)
psycopg3/psycopg3/connection.py

index 71d564ef6aa2812b0322361f52363859f56fdabb..67f9df4c1325bd41844a3b66240629e266f40ed0 100644 (file)
@@ -265,24 +265,26 @@ class Connection(BaseConnection):
             )
 
     def commit(self) -> None:
-        self._exec_commit_rollback(b"commit")
+        with self.lock:
+            self._exec_commit_rollback(b"commit")
 
     def rollback(self) -> None:
-        self._exec_commit_rollback(b"rollback")
+        with self.lock:
+            self._exec_commit_rollback(b"rollback")
 
     def _exec_commit_rollback(self, command: bytes) -> None:
-        with self.lock:
-            status = self.pgconn.transaction_status
-            if status == TransactionStatus.IDLE:
-                return
-
-            self.pgconn.send_query(command)
-            (pgres,) = self.wait(execute(self.pgconn))
-            if pgres.status != ExecStatus.COMMAND_OK:
-                raise e.OperationalError(
-                    f"error on {command.decode('utf8')}:"
-                    f" {pq.error_message(pgres, encoding=self.codec.name)}"
-                )
+        # Caller must hold self.lock
+        status = self.pgconn.transaction_status
+        if status == TransactionStatus.IDLE:
+            return
+
+        self.pgconn.send_query(command)
+        (pgres,) = self.wait(execute(self.pgconn))
+        if pgres.status != ExecStatus.COMMAND_OK:
+            raise e.OperationalError(
+                f"error on {command.decode('utf8')}:"
+                f" {pq.error_message(pgres, encoding=self.codec.name)}"
+            )
 
     @classmethod
     def wait(
@@ -377,24 +379,26 @@ class AsyncConnection(BaseConnection):
             )
 
     async def commit(self) -> None:
-        await self._exec_commit_rollback(b"commit")
+        async with self.lock:
+            await self._exec_commit_rollback(b"commit")
 
     async def rollback(self) -> None:
-        await self._exec_commit_rollback(b"rollback")
+        async with self.lock:
+            await self._exec_commit_rollback(b"rollback")
 
     async def _exec_commit_rollback(self, command: bytes) -> None:
-        async with self.lock:
-            status = self.pgconn.transaction_status
-            if status == TransactionStatus.IDLE:
-                return
-
-            self.pgconn.send_query(command)
-            (pgres,) = await self.wait(execute(self.pgconn))
-            if pgres.status != ExecStatus.COMMAND_OK:
-                raise e.OperationalError(
-                    f"error on {command.decode('utf8')}:"
-                    f" {pq.error_message(pgres, encoding=self.codec.name)}"
-                )
+        # Caller must hold self.lock
+        status = self.pgconn.transaction_status
+        if status == TransactionStatus.IDLE:
+            return
+
+        self.pgconn.send_query(command)
+        (pgres,) = await self.wait(execute(self.pgconn))
+        if pgres.status != ExecStatus.COMMAND_OK:
+            raise e.OperationalError(
+                f"error on {command.decode('utf8')}:"
+                f" {pq.error_message(pgres, encoding=self.codec.name)}"
+            )
 
     @classmethod
     async def wait(cls, gen: proto.PQGen[proto.RV]) -> proto.RV: