]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Separate generators where the fd can change from the ones where not
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 2 Dec 2020 17:23:14 +0000 (17:23 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 2 Dec 2020 17:23:14 +0000 (17:23 +0000)
This should make the query operation marginally faster.

psycopg3/psycopg3/connection.py
psycopg3/psycopg3/generators.py
psycopg3/psycopg3/proto.py
psycopg3/psycopg3/waiting.py
psycopg3_c/psycopg3_c/_psycopg3.pyi
psycopg3_c/psycopg3_c/generators.pyx

index a11419fe28c8c37eaad2480da3aac0f7377d0d61..7d8adde0c48d5cfe8debf331b84f1cf04f9cf772 100644 (file)
@@ -10,7 +10,7 @@ import logging
 import threading
 from types import TracebackType
 from typing import Any, AsyncIterator, Callable, Iterator, List, NamedTuple
-from typing import Optional, Type, TYPE_CHECKING
+from typing import Optional, Type, TYPE_CHECKING, Union
 from weakref import ref, ReferenceType
 from functools import partial
 from contextlib import contextmanager
@@ -26,7 +26,7 @@ from . import errors as e
 from . import encodings
 from .pq import TransactionStatus, ExecStatus, Format
 from .sql import Composable
-from .proto import DumpersMap, LoadersMap, PQGen, RV, Query
+from .proto import DumpersMap, LoadersMap, PQGen, PQGenConn, RV, Query
 from .waiting import wait, wait_async
 from .conninfo import make_conninfo
 from .generators import notifies
@@ -35,7 +35,7 @@ from .transaction import Transaction, AsyncTransaction
 logger = logging.getLogger(__name__)
 package_logger = logging.getLogger("psycopg3")
 
-connect: Callable[[str], PQGen["PGconn"]]
+connect: Callable[[str], PQGenConn["PGconn"]]
 execute: Callable[["PGconn"], PQGen[List["PGresult"]]]
 
 if TYPE_CHECKING:
@@ -359,7 +359,11 @@ class Connection(BaseConnection):
             yield tx
 
     @classmethod
-    def wait(cls, gen: PQGen[RV], timeout: Optional[float] = 0.1) -> RV:
+    def wait(
+        cls,
+        gen: Union[PQGen[RV], PQGenConn[RV]],
+        timeout: Optional[float] = 0.1,
+    ) -> RV:
         return wait(gen, timeout=timeout)
 
     def _set_client_encoding(self, name: str) -> None:
@@ -518,7 +522,7 @@ class AsyncConnection(BaseConnection):
             yield tx
 
     @classmethod
-    async def wait(cls, gen: PQGen[RV]) -> RV:
+    async def wait(cls, gen: Union[PQGen[RV], PQGenConn[RV]]) -> RV:
         return await wait_async(gen)
 
     def _set_client_encoding(self, name: str) -> None:
index 933a9e21e8a8bdcdfdf3e3ac925eafb36db477bc..50a6c4e5bdc253158c66e96b6d39c2c4cef68794 100644 (file)
@@ -21,7 +21,7 @@ from typing import List, Optional, Union
 from . import pq
 from . import errors as e
 from .pq import ConnStatus, PollingStatus, ExecStatus
-from .proto import PQGen
+from .proto import PQGen, PQGenConn
 from .waiting import Wait, Ready
 from .encodings import py_codecs
 from .pq.proto import PGconn, PGresult
@@ -29,7 +29,7 @@ from .pq.proto import PGconn, PGresult
 logger = logging.getLogger(__name__)
 
 
-def connect(conninfo: str) -> PQGen[PGconn]:
+def connect(conninfo: str) -> PQGenConn[PGconn]:
     """
     Generator to create a database connection without blocking.
 
@@ -73,7 +73,7 @@ def execute(pgconn: PGconn) -> PQGen[List[PGresult]]:
     or error).
     """
     yield from send(pgconn)
-    rv = yield from fetch(pgconn)
+    rv = yield from _fetch(pgconn)
     return rv
 
 
@@ -85,23 +85,24 @@ def send(pgconn: PGconn) -> PQGen[None]:
     similar. Flush the query and then return the result using nonblocking
     functions.
 
-    After this generator has finished you may want to cycle using `fetch()`
+    After this generator has finished you may want to cycle using `_fetch()`
     to retrieve the results available.
     """
+    yield pgconn.socket
     while 1:
         f = pgconn.flush()
         if f == 0:
             break
 
-        ready = yield pgconn.socket, Wait.RW
+        ready = yield Wait.RW
         if ready & Ready.R:
             # This call may read notifies: they will be saved in the
-            # PGconn buffer and passed to Python later, in `fetch()`.
+            # PGconn buffer and passed to Python later, in `_fetch()`.
             pgconn.consume_input()
         continue
 
 
-def fetch(pgconn: PGconn) -> PQGen[List[PGresult]]:
+def _fetch(pgconn: PGconn) -> PQGen[List[PGresult]]:
     """
     Generator retrieving results from the database without blocking.
 
@@ -110,12 +111,15 @@ def fetch(pgconn: PGconn) -> PQGen[List[PGresult]]:
 
     Return the list of results returned by the database (whether success
     or error).
+
+    Note that this generator doesn't yield the socket number, which must have
+    been already sent in the sending part of the cycle.
     """
     results: List[PGresult] = []
     while 1:
         pgconn.consume_input()
         if pgconn.is_busy():
-            yield pgconn.socket, Wait.R
+            yield Wait.R
             continue
 
         # Consume notifies
@@ -146,7 +150,8 @@ _copy_statuses = (
 
 
 def notifies(pgconn: PGconn) -> PQGen[List[pq.PGnotify]]:
-    yield pgconn.socket, Wait.R
+    yield pgconn.socket
+    yield Wait.R
     pgconn.consume_input()
 
     ns = []
@@ -161,13 +166,14 @@ def notifies(pgconn: PGconn) -> PQGen[List[pq.PGnotify]]:
 
 
 def copy_from(pgconn: PGconn) -> PQGen[Union[bytes, PGresult]]:
+    yield pgconn.socket
     while 1:
         nbytes, data = pgconn.get_copy_data(1)
         if nbytes != 0:
             break
 
         # would block
-        yield pgconn.socket, Wait.R
+        yield Wait.R
         pgconn.consume_input()
 
     if nbytes > 0:
@@ -175,7 +181,7 @@ def copy_from(pgconn: PGconn) -> PQGen[Union[bytes, PGresult]]:
         return data
 
     # Retrieve the final result of copy
-    (result,) = yield from fetch(pgconn)
+    (result,) = yield from _fetch(pgconn)
     if result.status != ExecStatus.COMMAND_OK:
         encoding = py_codecs.get(
             pgconn.parameter_status(b"client_encoding") or "", "utf-8"
@@ -186,25 +192,27 @@ def copy_from(pgconn: PGconn) -> PQGen[Union[bytes, PGresult]]:
 
 
 def copy_to(pgconn: PGconn, buffer: bytes) -> PQGen[None]:
+    yield pgconn.socket
     # Retry enqueuing data until successful
     while pgconn.put_copy_data(buffer) == 0:
-        yield pgconn.socket, Wait.W
+        yield Wait.W
 
 
 def copy_end(pgconn: PGconn, error: Optional[bytes]) -> PQGen[PGresult]:
+    yield pgconn.socket
     # Retry enqueuing end copy message until successful
     while pgconn.put_copy_end(error) == 0:
-        yield pgconn.socket, Wait.W
+        yield Wait.W
 
     # Repeat until it the message is flushed to the server
     while 1:
-        yield pgconn.socket, Wait.W
+        yield Wait.W
         f = pgconn.flush()
         if f == 0:
             break
 
     # Retrieve the final result of copy
-    (result,) = yield from fetch(pgconn)
+    (result,) = yield from _fetch(pgconn)
     if result.status != ExecStatus.COMMAND_OK:
         encoding = py_codecs.get(
             pgconn.parameter_status(b"client_encoding") or "", "utf-8"
index 1e26e528af49afe207a6353a96f847cf1e5c0d64..328f716bce11a17bdaeb3ab41ccfa2cb798e355d 100644 (file)
@@ -27,7 +27,18 @@ ConnectionType = TypeVar("ConnectionType", bound="BaseConnection")
 # Waiting protocol types
 
 RV = TypeVar("RV")
-PQGen = Generator[Tuple[int, "Wait"], "Ready", RV]
+PQGenConn = Generator[Tuple[int, "Wait"], "Ready", RV]
+"""Generator for processes where the connection file number can change.
+
+This can happen in connection and reset, but not in normal querying.
+"""
+
+PQGen = Generator[Union[int, "Wait"], "Ready", RV]
+"""Generator for processes where the connection file number won't change.
+
+The first item generated is the file descriptor; following items are be the
+Wait states.
+"""
 
 
 # Adaptation types
index 67ac85280779e772afbfe813c7a3e15303180e4f..85e1119662079c06cc37949c99c3dc401c69d8a6 100644 (file)
@@ -10,12 +10,12 @@ These functions are designed to consume the generators returned by the
 
 
 from enum import IntEnum
-from typing import Optional
+from typing import Optional, Union
 from asyncio import get_event_loop, Event
 from selectors import DefaultSelector, EVENT_READ, EVENT_WRITE
 
 from . import errors as e
-from .proto import PQGen, RV
+from .proto import PQGen, PQGenConn, RV
 
 
 class Wait(IntEnum):
@@ -29,7 +29,9 @@ class Ready(IntEnum):
     W = EVENT_WRITE
 
 
-def wait(gen: PQGen[RV], timeout: Optional[float] = None) -> RV:
+def wait(
+    gen: Union[PQGen[RV], PQGenConn[RV]], timeout: Optional[float] = None
+) -> RV:
     """
     Wait for a generator using the best option available on the platform.
 
@@ -40,25 +42,43 @@ def wait(gen: PQGen[RV], timeout: Optional[float] = None) -> RV:
     :type timeout: float
     :return: whatever *gen* returns on completion.
     """
+    fd: int
+    s: Wait
     sel = DefaultSelector()
     try:
-        fd, s = next(gen)
-        while 1:
-            sel.register(fd, s)
-            ready = None
-            while not ready:
-                ready = sel.select(timeout=timeout)
-            sel.unregister(fd)
-
-            assert len(ready) == 1
-            fd, s = gen.send(ready[0][1])
+        # Use the first generated item to tell if it's a PQgen or PQgenConn.
+        # Note: mypy gets confused by the behaviour of this generator.
+        item = next(gen)
+        if isinstance(item, tuple):
+            fd, s = item
+            while 1:
+                sel.register(fd, s)
+                ready = None
+                while not ready:
+                    ready = sel.select(timeout=timeout)
+                sel.unregister(fd)
+
+                assert len(ready) == 1
+                fd, s = gen.send(ready[0][1])
+        else:
+            fd = item  # type: ignore[assignment]
+            s = next(gen)  # type: ignore[assignment]
+            while 1:
+                sel.register(fd, s)
+                ready = None
+                while not ready:
+                    ready = sel.select(timeout=timeout)
+                sel.unregister(fd)
+
+                assert len(ready) == 1
+                s = gen.send(ready[0][1])  # type: ignore[arg-type,assignment]
 
     except StopIteration as ex:
         rv: RV = ex.args[0] if ex.args else None
         return rv
 
 
-async def wait_async(gen: PQGen[RV]) -> RV:
+async def wait_async(gen: Union[PQGen[RV], PQGenConn[RV]]) -> RV:
     """
     Coroutine waiting for a generator to complete.
 
@@ -73,6 +93,8 @@ async def wait_async(gen: PQGen[RV]) -> RV:
     ev = Event()
     loop = get_event_loop()
     ready: Ready
+    fd: int
+    s: Wait
 
     def wakeup(state: Ready) -> None:
         nonlocal ready
@@ -80,26 +102,52 @@ async def wait_async(gen: PQGen[RV]) -> RV:
         ev.set()
 
     try:
-        fd, s = next(gen)
-        while 1:
-            ev.clear()
-            if s == Wait.R:
-                loop.add_reader(fd, wakeup, Ready.R)
-                await ev.wait()
-                loop.remove_reader(fd)
-            elif s == Wait.W:
-                loop.add_writer(fd, wakeup, Ready.W)
-                await ev.wait()
-                loop.remove_writer(fd)
-            elif s == Wait.RW:
-                loop.add_reader(fd, wakeup, Ready.R)
-                loop.add_writer(fd, wakeup, Ready.W)
-                await ev.wait()
-                loop.remove_reader(fd)
-                loop.remove_writer(fd)
-            else:
-                raise e.InternalError("bad poll status: %s")
-            fd, s = gen.send(ready)
+        # Use the first generated item to tell if it's a PQgen or PQgenConn.
+        # Note: mypy gets confused by the behaviour of this generator.
+        item = next(gen)
+        if isinstance(item, tuple):
+            fd, s = item
+            while 1:
+                ev.clear()
+                if s == Wait.R:
+                    loop.add_reader(fd, wakeup, Ready.R)
+                    await ev.wait()
+                    loop.remove_reader(fd)
+                elif s == Wait.W:
+                    loop.add_writer(fd, wakeup, Ready.W)
+                    await ev.wait()
+                    loop.remove_writer(fd)
+                elif s == Wait.RW:
+                    loop.add_reader(fd, wakeup, Ready.R)
+                    loop.add_writer(fd, wakeup, Ready.W)
+                    await ev.wait()
+                    loop.remove_reader(fd)
+                    loop.remove_writer(fd)
+                else:
+                    raise e.InternalError("bad poll status: %s")
+                fd, s = gen.send(ready)  # type: ignore[misc]
+        else:
+            fd = item  # type: ignore[assignment]
+            s = next(gen)  # type: ignore[assignment]
+            while 1:
+                ev.clear()
+                if s == Wait.R:
+                    loop.add_reader(fd, wakeup, Ready.R)
+                    await ev.wait()
+                    loop.remove_reader(fd)
+                elif s == Wait.W:
+                    loop.add_writer(fd, wakeup, Ready.W)
+                    await ev.wait()
+                    loop.remove_writer(fd)
+                elif s == Wait.RW:
+                    loop.add_reader(fd, wakeup, Ready.R)
+                    loop.add_writer(fd, wakeup, Ready.W)
+                    await ev.wait()
+                    loop.remove_reader(fd)
+                    loop.remove_writer(fd)
+                else:
+                    raise e.InternalError("bad poll status: %s")
+                s = gen.send(ready)  # type: ignore[arg-type,assignment]
 
     except StopIteration as ex:
         rv: RV = ex.args[0] if ex.args else None
index 886e4627f2885bfd83d22d5588dcfb57415b7480..d765fa41cb8e61daf306480b907d18054d02152d 100644 (file)
@@ -11,7 +11,7 @@ from typing import Any, Iterable, List, Optional, Sequence, Tuple
 
 from psycopg3.adapt import Dumper, Loader
 from psycopg3.proto import AdaptContext, DumpFunc, DumpersMap, DumperType
-from psycopg3.proto import LoadFunc, LoadersMap, LoaderType, PQGen
+from psycopg3.proto import LoadFunc, LoadersMap, LoaderType, PQGen, PQGenConn
 from psycopg3.connection import BaseConnection
 from psycopg3 import pq
 
@@ -40,7 +40,7 @@ class Transformer:
     def get_loader(self, oid: int, format: pq.Format) -> Loader: ...
 
 def register_builtin_c_adapters() -> None: ...
-def connect(conninfo: str) -> PQGen[pq.proto.PGconn]: ...
+def connect(conninfo: str) -> PQGenConn[pq.proto.PGconn]: ...
 def execute(pgconn: pq.proto.PGconn) -> PQGen[List[pq.proto.PGresult]]: ...
 
 # vim: set syntax=python:
index 58fffc95e83e8be7720fa05240156df6843b6c8a..ab840c6657a9096f0a4a457615b69795616b3450 100644 (file)
@@ -19,7 +19,7 @@ cdef object WAIT_R = Wait.R
 cdef object WAIT_RW = Wait.RW
 cdef int READY_R = Ready.R
 
-def connect(conninfo: str) -> PQGen[pq.proto.PGconn]:
+def connect(conninfo: str) -> PQGenConn[pq.proto.PGconn]:
     """
     Generator to create a database connection without blocking.
 
@@ -71,12 +71,16 @@ def execute(PGconn pgconn) -> PQGen[List[pq.proto.PGresult]]:
     cdef int status
     cdef libpq.PGnotify *notify
 
+    # Start the generator by sending the connection fd, which won't change
+    # during the query process.
+    yield libpq.PQsocket(pgconn_ptr)
+
     # Sending the query
     while 1:
         if libpq.PQflush(pgconn_ptr) == 0:
             break
 
-        status = yield libpq.PQsocket(pgconn_ptr), WAIT_RW
+        status = yield WAIT_RW
         if status & READY_R:
             # This call may read notifies which will be saved in the
             # PGconn buffer and passed to Python later.
@@ -85,15 +89,13 @@ def execute(PGconn pgconn) -> PQGen[List[pq.proto.PGresult]]:
                     f"consuming input failed: {pq.error_message(pgconn)}")
         continue
 
-    wr = (libpq.PQsocket(pgconn_ptr), WAIT_R)
-
     # Fetching the result
     while 1:
         if 1 != libpq.PQconsumeInput(pgconn_ptr):
             raise pq.PQerror(
                 f"consuming input failed: {pq.error_message(pgconn)}")
         if libpq.PQisBusy(pgconn_ptr):
-            yield wr
+            yield WAIT_R
             continue
 
         # Consume notifies