]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
libpq protocol generators moved to their own module
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 11 Apr 2020 12:55:10 +0000 (00:55 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 11 Apr 2020 12:57:36 +0000 (00:57 +1200)
psycopg3/connection.py
psycopg3/cursor.py
psycopg3/generators.py [new file with mode: 0644]
tests/pq/test_async.py
tests/pq/test_pgconn.py

index 094f89695e33549aa8f1e49efc3a82831d575919..c3019029977bc81ce439d114543929067d8fb773 100644 (file)
@@ -8,19 +8,18 @@ import codecs
 import logging
 import asyncio
 import threading
-from typing import Any, Generator, List, Optional, Tuple, Type, TypeVar
+from typing import Any, Generator, Optional, Tuple, Type, TypeVar
 from typing import cast, TYPE_CHECKING
 
 from . import pq
 from . import errors as e
 from . import cursor
+from . import generators
 from .conninfo import make_conninfo
 from .waiting import wait, wait_async, Wait, Ready
 
 logger = logging.getLogger(__name__)
 
-ConnectGen = Generator[Tuple[int, Wait], Ready, pq.PGconn]
-QueryGen = Generator[Tuple[int, Wait], Ready, List[pq.PGresult]]
 RV = TypeVar("RV")
 
 if TYPE_CHECKING:
@@ -112,88 +111,6 @@ class BaseConnection:
         else:
             return "UTF8"
 
-    @classmethod
-    def _connect_gen(cls, conninfo: str) -> ConnectGen:
-        """
-        Generator to create a database connection without blocking.
-
-        Yield pairs (fileno, `Wait`) whenever an operation would block. The
-        generator can be restarted sending the appropriate `Ready` state when
-        the file descriptor is ready.
-        """
-        conn = pq.PGconn.connect_start(conninfo.encode("utf8"))
-        logger.debug("connection started, status %s", conn.status.name)
-        while 1:
-            if conn.status == cls.ConnStatus.BAD:
-                raise e.OperationalError(
-                    f"connection is bad: {pq.error_message(conn)}"
-                )
-
-            status = conn.connect_poll()
-            logger.debug("connection polled, status %s", conn.status.name)
-            if status == pq.PollingStatus.OK:
-                break
-            elif status == pq.PollingStatus.READING:
-                yield conn.socket, Wait.R
-            elif status == pq.PollingStatus.WRITING:
-                yield conn.socket, Wait.W
-            elif status == pq.PollingStatus.FAILED:
-                raise e.OperationalError(
-                    f"connection failed: {pq.error_message(conn)}"
-                )
-            else:
-                raise e.InternalError(f"unexpected poll status: {status}")
-
-        conn.nonblocking = 1
-        return conn
-
-    @classmethod
-    def _exec_gen(cls, pgconn: pq.PGconn) -> QueryGen:
-        """
-        Generator returning query results without blocking.
-
-        The query must have already been sent using `pgconn.send_query()` or
-        similar. Flush the query and then return the result using nonblocking
-        functions.
-
-        Yield pairs (fileno, `Wait`) whenever an operation would block. The
-        generator can be restarted sending the appropriate `Ready` state when
-        the file descriptor is ready.
-
-        Return the list of results returned by the database (whether success
-        or error).
-        """
-        results: List[pq.PGresult] = []
-
-        while 1:
-            f = pgconn.flush()
-            if f == 0:
-                break
-
-            ready = yield pgconn.socket, Wait.RW
-            if ready & Ready.R:
-                pgconn.consume_input()
-            continue
-
-        while 1:
-            pgconn.consume_input()
-            if pgconn.is_busy():
-                ready = yield pgconn.socket, Wait.R
-            res = pgconn.get_result()
-            if res is None:
-                break
-            results.append(res)
-            if res.status in (
-                pq.ExecStatus.COPY_IN,
-                pq.ExecStatus.COPY_OUT,
-                pq.ExecStatus.COPY_BOTH,
-            ):
-                # After entering copy mode the libpq will create a phony result
-                # for every request so let's break the endless loop.
-                break
-
-        return results
-
 
 class Connection(BaseConnection):
     """
@@ -216,7 +133,7 @@ class Connection(BaseConnection):
         if connection_factory is not None:
             raise NotImplementedError()
         conninfo = make_conninfo(conninfo, **kwargs)
-        gen = cls._connect_gen(conninfo)
+        gen = generators.connect(conninfo)
         pgconn = cls.wait(gen)
         return cls(pgconn)
 
@@ -239,7 +156,7 @@ class Connection(BaseConnection):
                 return
 
             self.pgconn.send_query(command)
-            (pgres,) = self.wait(self._exec_gen(self.pgconn))
+            (pgres,) = self.wait(generators.execute(self.pgconn))
             if pgres.status != pq.ExecStatus.COMMAND_OK:
                 raise e.OperationalError(
                     f"error on {command.decode('utf8')}:"
@@ -260,7 +177,7 @@ class Connection(BaseConnection):
                 b"select set_config('client_encoding', $1, false)",
                 [value.encode("ascii")],
             )
-            gen = self._exec_gen(self.pgconn)
+            gen = generators.execute(self.pgconn)
             (result,) = self.wait(gen)
             if result.status != pq.ExecStatus.TUPLES_OK:
                 raise e.error_from_result(result)
@@ -284,7 +201,7 @@ class AsyncConnection(BaseConnection):
     @classmethod
     async def connect(cls, conninfo: str, **kwargs: Any) -> "AsyncConnection":
         conninfo = make_conninfo(conninfo, **kwargs)
-        gen = cls._connect_gen(conninfo)
+        gen = generators.connect(conninfo)
         pgconn = await cls.wait(gen)
         return cls(pgconn)
 
@@ -307,7 +224,7 @@ class AsyncConnection(BaseConnection):
                 return
 
             self.pgconn.send_query(command)
-            (pgres,) = await self.wait(self._exec_gen(self.pgconn))
+            (pgres,) = await self.wait(generators.execute(self.pgconn))
             if pgres.status != pq.ExecStatus.COMMAND_OK:
                 raise e.OperationalError(
                     f"error on {command.decode('utf8')}:"
@@ -324,7 +241,7 @@ class AsyncConnection(BaseConnection):
                 b"select set_config('client_encoding', $1, false)",
                 [value.encode("ascii")],
             )
-            gen = self._exec_gen(self.pgconn)
+            gen = generators.execute(self.pgconn)
             (result,) = await self.wait(gen)
             if result.status != pq.ExecStatus.TUPLES_OK:
                 raise e.error_from_result(result)
index c2ac2da8a03643df1d568d88eb4631377c07e74d..71a4cf9e1c5033799157796912234dd8d014d09a 100644 (file)
@@ -10,13 +10,14 @@ from typing import Any, List, Mapping, Optional, Sequence, Tuple, TYPE_CHECKING
 
 from . import errors as e
 from . import pq
+from . import generators
 from .utils.queries import query2pg, reorder_params
 from .utils.typing import Query, Params
 
 if TYPE_CHECKING:
-    from .connection import BaseConnection, Connection, AsyncConnection
-    from .connection import QueryGen
     from .adapt import DumpersMap, LoadersMap, Transformer
+    from .connection import BaseConnection, Connection, AsyncConnection
+    from .generators import QueryGen
 
 
 class Column(Sequence[Any]):
@@ -186,7 +187,7 @@ class BaseCursor:
             else:
                 self.conn.pgconn.send_query(query)
 
-        return self.conn._exec_gen(self.conn.pgconn)
+        return generators.execute(self.conn.pgconn)
 
     def _execute_results(self, results: List[pq.PGresult]) -> None:
         """
diff --git a/psycopg3/generators.py b/psycopg3/generators.py
new file mode 100644 (file)
index 0000000..1899eaa
--- /dev/null
@@ -0,0 +1,105 @@
+"""
+Generators implementing communication protocols with the libpq
+
+Certain operations (connection, querying) are an interleave of libpq calls and
+waiting for the socket to be ready. This module contains the code to execute
+the operations, yielding a polling state whenever there is to wait. The
+functions in the `waiting` module are the ones who wait more or less
+cooperatively for the socket to be ready and make these generators continue.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import logging
+from typing import Generator, List, Tuple
+from .waiting import Wait, Ready
+
+from . import pq
+from . import errors as e
+
+ConnectGen = Generator[Tuple[int, Wait], Ready, pq.PGconn]
+QueryGen = Generator[Tuple[int, Wait], Ready, List[pq.PGresult]]
+
+logger = logging.getLogger(__name__)
+
+
+def connect(conninfo: str) -> ConnectGen:
+    """
+    Generator to create a database connection without blocking.
+
+    Yield pairs (fileno, `Wait`) whenever an operation would block. The
+    generator can be restarted sending the appropriate `Ready` state when
+    the file descriptor is ready.
+    """
+    conn = pq.PGconn.connect_start(conninfo.encode("utf8"))
+    logger.debug("connection started, status %s", conn.status.name)
+    while 1:
+        if conn.status == pq.ConnStatus.BAD:
+            raise e.OperationalError(
+                f"connection is bad: {pq.error_message(conn)}"
+            )
+
+        status = conn.connect_poll()
+        logger.debug("connection polled, status %s", conn.status.name)
+        if status == pq.PollingStatus.OK:
+            break
+        elif status == pq.PollingStatus.READING:
+            yield conn.socket, Wait.R
+        elif status == pq.PollingStatus.WRITING:
+            yield conn.socket, Wait.W
+        elif status == pq.PollingStatus.FAILED:
+            raise e.OperationalError(
+                f"connection failed: {pq.error_message(conn)}"
+            )
+        else:
+            raise e.InternalError(f"unexpected poll status: {status}")
+
+    conn.nonblocking = 1
+    return conn
+
+
+def execute(pgconn: pq.PGconn) -> QueryGen:
+    """
+    Generator returning query results without blocking.
+
+    The query must have already been sent using `pgconn.send_query()` or
+    similar. Flush the query and then return the result using nonblocking
+    functions.
+
+    Yield pairs (fileno, `Wait`) whenever an operation would block. The
+    generator can be restarted sending the appropriate `Ready` state when
+    the file descriptor is ready.
+
+    Return the list of results returned by the database (whether success
+    or error).
+    """
+    results: List[pq.PGresult] = []
+
+    while 1:
+        f = pgconn.flush()
+        if f == 0:
+            break
+
+        ready = yield pgconn.socket, Wait.RW
+        if ready & Ready.R:
+            pgconn.consume_input()
+        continue
+
+    while 1:
+        pgconn.consume_input()
+        if pgconn.is_busy():
+            ready = yield pgconn.socket, Wait.R
+        res = pgconn.get_result()
+        if res is None:
+            break
+        results.append(res)
+        if res.status in (
+            pq.ExecStatus.COPY_IN,
+            pq.ExecStatus.COPY_OUT,
+            pq.ExecStatus.COPY_BOTH,
+        ):
+            # After entering copy mode the libpq will create a phony result
+            # for every request so let's break the endless loop.
+            break
+
+    return results
index d5a0e185d16ec6accc22abb69ec48041c863a431..7cedc9574ec5ccd703c81376792b163160c24034 100644 (file)
@@ -55,13 +55,13 @@ def test_send_query(pq, pgconn):
     assert results[1].get_value(0, 0) == b"1"
 
 
-def test_send_query_compact_test(pq, conn):
+def test_send_query_compact_test(pq, pgconn):
     # Like the above test but use psycopg3 facilities for compactness
-    conn.pgconn.send_query(
+    pgconn.send_query(
         b"/* %s */ select pg_sleep(0.01); select 1 as foo;"
         % (b"x" * 1_000_000)
     )
-    results = conn.wait(conn._exec_gen(conn.pgconn))
+    results = psycopg3.waiting.wait(psycopg3.generators.execute(pgconn))
 
     assert len(results) == 2
     assert results[0].nfields == 1
@@ -71,17 +71,17 @@ def test_send_query_compact_test(pq, conn):
     assert results[1].fname(0) == b"foo"
     assert results[1].get_value(0, 0) == b"1"
 
-    conn.pgconn.finish()
+    pgconn.finish()
     with pytest.raises(psycopg3.OperationalError):
-        conn.pgconn.send_query(b"select 1")
+        pgconn.send_query(b"select 1")
 
 
-def test_send_query_params(pq, conn):
-    res = conn.pgconn.send_query_params(b"select $1::int + $2", [b"5", b"3"])
-    (res,) = conn.wait(conn._exec_gen(conn.pgconn))
+def test_send_query_params(pq, pgconn):
+    res = pgconn.send_query_params(b"select $1::int + $2", [b"5", b"3"])
+    (res,) = psycopg3.waiting.wait(psycopg3.generators.execute(pgconn))
     assert res.status == pq.ExecStatus.TUPLES_OK
     assert res.get_value(0, 0) == b"8"
 
-    conn.pgconn.finish()
+    pgconn.finish()
     with pytest.raises(psycopg3.OperationalError):
-        conn.pgconn.send_query_params(b"select $1", [b"1"])
+        pgconn.send_query_params(b"select $1", [b"1"])
index 65b16b4c43d73b8473a9904d72b61502b996f1ae..9e8ea4fd3c76ae033abb2f38ee51211189103595 100644 (file)
@@ -202,7 +202,7 @@ def test_transaction_status(pq, pgconn):
     assert pgconn.transaction_status == pq.TransactionStatus.INTRANS
     pgconn.send_query(b"select 1")
     assert pgconn.transaction_status == pq.TransactionStatus.ACTIVE
-    psycopg3.waiting.wait(psycopg3.Connection._exec_gen(pgconn))
+    psycopg3.waiting.wait(psycopg3.generators.execute(pgconn))
     assert pgconn.transaction_status == pq.TransactionStatus.INTRANS
     pgconn.finish()
     assert pgconn.transaction_status == pq.TransactionStatus.UNKNOWN