]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Handle RW-ready in wait(,_conn)_async()
authorDenis Laxalde <denis@laxalde.org>
Thu, 11 Nov 2021 09:59:59 +0000 (10:59 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 16 Nov 2021 10:28:18 +0000 (11:28 +0100)
psycopg/psycopg/waiting.py

index b75dbe5b598de93a14a22c18bc82435d9fae2746..7563d4f444817bba91ba2092e968acfef51c0d65 100644 (file)
@@ -119,29 +119,27 @@ async def wait_async(gen: PQGen[RV], fileno: int) -> RV:
 
     def wakeup(state: Ready) -> None:
         nonlocal ready
-        ready = state
+        ready |= state  # type: ignore[assignment]
         ev.set()
 
     try:
         s = next(gen)
         while 1:
+            reader = s & Wait.R
+            writer = s & Wait.W
+            if not reader and not writer:
+                raise e.InternalError(f"bad poll status: {s}")
             ev.clear()
-            if s == Wait.R:
-                loop.add_reader(fileno, wakeup, Ready.R)
-                await ev.wait()
-                loop.remove_reader(fileno)
-            elif s == Wait.W:
-                loop.add_writer(fileno, wakeup, Ready.W)
-                await ev.wait()
-                loop.remove_writer(fileno)
-            elif s == Wait.RW:
+            ready = 0  # type: ignore[assignment]
+            if reader:
                 loop.add_reader(fileno, wakeup, Ready.R)
+            if writer:
                 loop.add_writer(fileno, wakeup, Ready.W)
-                await ev.wait()
+            await ev.wait()
+            if reader:
                 loop.remove_reader(fileno)
+            if writer:
                 loop.remove_writer(fileno)
-            else:
-                raise e.InternalError("bad poll status: %s")
             s = gen.send(ready)
 
     except StopIteration as ex:
@@ -180,23 +178,21 @@ async def wait_conn_async(
     try:
         fileno, s = next(gen)
         while 1:
+            reader = s & Wait.R
+            writer = s & Wait.W
+            if not reader and not writer:
+                raise e.InternalError(f"bad poll status: {s}")
             ev.clear()
-            if s == Wait.R:
-                loop.add_reader(fileno, wakeup, Ready.R)
-                await wait_for(ev.wait(), timeout)
-                loop.remove_reader(fileno)
-            elif s == Wait.W:
-                loop.add_writer(fileno, wakeup, Ready.W)
-                await wait_for(ev.wait(), timeout)
-                loop.remove_writer(fileno)
-            elif s == Wait.RW:
+            ready = 0  # type: ignore[assignment]
+            if reader:
                 loop.add_reader(fileno, wakeup, Ready.R)
+            if writer:
                 loop.add_writer(fileno, wakeup, Ready.W)
-                await wait_for(ev.wait(), timeout)
+            await wait_for(ev.wait(), timeout)
+            if reader:
                 loop.remove_reader(fileno)
+            if writer:
                 loop.remove_writer(fileno)
-            else:
-                raise e.InternalError("bad poll status: %s")
             fileno, s = gen.send(ready)
 
     except TimeoutError: