]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Fixed async connection lock to be awaitable
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 21 Mar 2020 12:46:20 +0000 (01:46 +1300)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 21 Mar 2020 12:46:20 +0000 (01:46 +1300)
psycopg3/connection.py

index 12fbfcbe7b953d0c8394b9594c8ab58801427d06..fd14c30525045106c765ffdbe0ce3f7e213f217b 100644 (file)
@@ -5,7 +5,8 @@ psycopg3 connection objects
 # Copyright (C) 2020 The Psycopg Team
 
 import logging
-from threading import Lock
+import asyncio
+import threading
 
 from . import pq
 from . import exceptions as exc
@@ -25,7 +26,6 @@ class BaseConnection:
 
     def __init__(self, pgconn):
         self.pgconn = pgconn
-        self.lock = Lock()
 
     @classmethod
     def _connect_gen(cls, conninfo):
@@ -111,6 +111,10 @@ class Connection(BaseConnection):
     This class implements a DBAPI-compliant interface.
     """
 
+    def __init__(self, pgconn):
+        super().__init__(pgconn)
+        self.lock = threading.Lock()
+
     @classmethod
     def connect(cls, conninfo, connection_factory=None, **kwargs):
         if connection_factory is not None:
@@ -153,6 +157,10 @@ class AsyncConnection(BaseConnection):
     methods implemented as coroutines.
     """
 
+    def __init__(self, pgconn):
+        super().__init__(pgconn)
+        self.lock = asyncio.Lock()
+
     @classmethod
     async def connect(cls, conninfo, **kwargs):
         conninfo = make_conninfo(conninfo, **kwargs)
@@ -167,7 +175,7 @@ class AsyncConnection(BaseConnection):
         await self._exec_commit_rollback(b"rollback")
 
     async def _exec_commit_rollback(self, command):
-        with self.lock:
+        with await self.lock:
             status = self.pgconn.transaction_status
             if status == pq.TransactionStatus.IDLE:
                 return