)
def commit(self) -> None:
- self._exec_commit_rollback(b"commit")
+ with self.lock:
+ self._exec_commit_rollback(b"commit")
def rollback(self) -> None:
- self._exec_commit_rollback(b"rollback")
+ with self.lock:
+ self._exec_commit_rollback(b"rollback")
def _exec_commit_rollback(self, command: bytes) -> None:
- with self.lock:
- status = self.pgconn.transaction_status
- if status == TransactionStatus.IDLE:
- return
-
- self.pgconn.send_query(command)
- (pgres,) = self.wait(execute(self.pgconn))
- if pgres.status != ExecStatus.COMMAND_OK:
- raise e.OperationalError(
- f"error on {command.decode('utf8')}:"
- f" {pq.error_message(pgres, encoding=self.codec.name)}"
- )
+ # Caller must hold self.lock
+ status = self.pgconn.transaction_status
+ if status == TransactionStatus.IDLE:
+ return
+
+ self.pgconn.send_query(command)
+ (pgres,) = self.wait(execute(self.pgconn))
+ if pgres.status != ExecStatus.COMMAND_OK:
+ raise e.OperationalError(
+ f"error on {command.decode('utf8')}:"
+ f" {pq.error_message(pgres, encoding=self.codec.name)}"
+ )
@classmethod
def wait(
)
async def commit(self) -> None:
- await self._exec_commit_rollback(b"commit")
+ async with self.lock:
+ await self._exec_commit_rollback(b"commit")
async def rollback(self) -> None:
- await self._exec_commit_rollback(b"rollback")
+ async with self.lock:
+ await self._exec_commit_rollback(b"rollback")
async def _exec_commit_rollback(self, command: bytes) -> None:
- async with self.lock:
- status = self.pgconn.transaction_status
- if status == TransactionStatus.IDLE:
- return
-
- self.pgconn.send_query(command)
- (pgres,) = await self.wait(execute(self.pgconn))
- if pgres.status != ExecStatus.COMMAND_OK:
- raise e.OperationalError(
- f"error on {command.decode('utf8')}:"
- f" {pq.error_message(pgres, encoding=self.codec.name)}"
- )
+ # Caller must hold self.lock
+ status = self.pgconn.transaction_status
+ if status == TransactionStatus.IDLE:
+ return
+
+ self.pgconn.send_query(command)
+ (pgres,) = await self.wait(execute(self.pgconn))
+ if pgres.status != ExecStatus.COMMAND_OK:
+ raise e.OperationalError(
+ f"error on {command.decode('utf8')}:"
+ f" {pq.error_message(pgres, encoding=self.codec.name)}"
+ )
@classmethod
async def wait(cls, gen: proto.PQGen[proto.RV]) -> proto.RV: