]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor(pool): give sync attribute the same life cycle of the async one
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 5 Oct 2023 00:43:50 +0000 (02:43 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 11 Oct 2023 21:45:39 +0000 (23:45 +0200)
Create possibly async objects only after we are sure that a loop is
running.

psycopg_pool/psycopg_pool/pool.py
psycopg_pool/psycopg_pool/pool_async.py

index b50e9b44e31fb3196e202552bed548d06e357df3..9159b85ca9871bf2af1e5c2e738c29b3535954b7 100644 (file)
@@ -107,15 +107,16 @@ class ConnectionPool(Generic[CT], BasePool):
 
         self._reconnect_failed = reconnect_failed
 
-        self._lock = Lock()
+        self._lock: Lock
+        self._sched: Scheduler
+        self._tasks: Queue[MaintenanceTask]
+
         self._waiting = Deque[WaitingClient[CT]]()
 
         # to notify that the pool is full
         self._pool_full_event: Optional[Event] = None
 
-        self._sched = Scheduler()
         self._sched_runner: Optional[threading.Thread] = None
-        self._tasks: Queue[MaintenanceTask] = Queue()
         self._workers: List[threading.Thread] = []
 
         super().__init__(
@@ -323,6 +324,8 @@ class ConnectionPool(Generic[CT], BasePool):
         because the pool was initialized with *open* = `!True`) but you cannot
         currently re-open a closed pool.
         """
+        self._ensure_lock()
+
         with self._lock:
             self._open()
 
@@ -335,12 +338,24 @@ class ConnectionPool(Generic[CT], BasePool):
 
         self._check_open()
 
+        # A lock has been most likely, but not necessarily, created in `open()`.
+        self._ensure_lock()
+
+        self._tasks = Queue()
+        self._sched = Scheduler()
+
         self._closed = False
         self._opened = True
 
         self._start_workers()
         self._start_initial_tasks()
 
+    def _ensure_lock(self) -> None:
+        try:
+            self._lock
+        except AttributeError:
+            self._lock = Lock()
+
     def _start_workers(self) -> None:
         self._sched_runner = spawn(self._sched.run, name=f"{self.name}-scheduler")
         assert not self._workers
index 1f8dc9bcc4808b7f69b2c446ead4eded5d5e188c..87b0a3d50bd8b8078084d052477b7486240f9c52 100644 (file)
@@ -317,10 +317,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
         currently re-open a closed pool.
         """
         # Make sure the lock is created after there is an event loop
-        try:
-            self._lock
-        except AttributeError:
-            self._lock = ALock()
+        self._ensure_lock()
 
         async with self._lock:
             self._open()
@@ -332,20 +329,15 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
         if not self._closed:
             return
 
-        # Throw a RuntimeError if the pool is open outside a running loop.
-        asyncio.get_running_loop()
-
         self._check_open()
 
+        # A lock has been most likely, but not necessarily, created in `open()`.
+        self._ensure_lock()
+
         # Create these objects now to attach them to the right loop.
         # See #219
         self._tasks = AQueue()
         self._sched = AsyncScheduler()
-        # This has been most likely, but not necessarily, created in `open()`.
-        try:
-            self._lock
-        except AttributeError:
-            self._lock = ALock()
 
         self._closed = False
         self._opened = True
@@ -353,6 +345,20 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
         self._start_workers()
         self._start_initial_tasks()
 
+    def _ensure_lock(self) -> None:
+        """Make sure the pool lock is created.
+
+        In async code, also make sure that the loop is running.
+        """
+
+        # Throw a RuntimeError if the pool is open outside a running loop.
+        asyncio.get_running_loop()
+
+        try:
+            self._lock
+        except AttributeError:
+            self._lock = ALock()
+
     def _start_workers(self) -> None:
         self._sched_runner = aspawn(self._sched.run, name=f"{self.name}-scheduler")
         for i in range(self.num_workers):