]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added generators implementation in cython
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 16 May 2020 13:50:31 +0000 (01:50 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 17 May 2020 09:29:34 +0000 (21:29 +1200)
psycopg3/_psycopg3.pyi
psycopg3/_psycopg3.pyx
psycopg3/connection.py
psycopg3/cursor.py
psycopg3/generators.py
psycopg3/generators.pyx [new file with mode: 0644]
tests/pq/test_async.py
tests/test_async_connection.py
tests/test_connection.py
tests/test_psycopg3_dbapi20.py

index 10eb17cdecb8d55ea9b843750bfa2d20d94ccad8..ca26da420f831f419639e22c428650c65f7a9a76 100644 (file)
@@ -11,7 +11,7 @@ import codecs
 from typing import Any, Iterable, List, Optional, Sequence, Tuple
 
 from .proto import AdaptContext, DumpFunc, DumpersMap, DumperType
-from .proto import LoadFunc, LoadersMap, LoaderType, MaybeOid
+from .proto import LoadFunc, LoadersMap, LoaderType, MaybeOid, PQGen
 from .connection import BaseConnection
 from . import pq
 
@@ -49,5 +49,7 @@ class Transformer:
     def lookup_loader(self, oid: int, format: Format) -> LoaderType: ...
 
 def register_builtin_c_loaders() -> None: ...
+def connect(conninfo: str) -> PQGen[pq.proto.PGconn]: ...
+def execute(pgconn: pq.proto.PGconn) -> PQGen[List[pq.proto.PGresult]]: ...
 
 # vim: set syntax=python:
index c7537e0eb91ed63bd00f5f16e1d75d20fa744a1b..6746187ac380617b21632eaa296120a75cbc9bb3 100644 (file)
@@ -9,5 +9,6 @@ if a compiler is available.
 
 include "types/numeric.pyx"
 include "types/text.pyx"
+include "generators.pyx"
 include "adapt.pyx"
 include "transform.pyx"
index b4b14c093546ef11a06d3b70abdec99a173f3e01..bcf45dd153734ab8eca99665754809a787bc3d6e 100644 (file)
@@ -8,21 +8,32 @@ import codecs
 import logging
 import asyncio
 import threading
-from typing import Any, Optional, Type
-from typing import cast, TYPE_CHECKING
+from typing import Any, Callable, List, Optional, Type, cast
 
 from . import pq
 from . import errors as e
 from . import cursor
-from . import generators
 from . import proto
 from .conninfo import make_conninfo
 from .waiting import wait, wait_async
 
 logger = logging.getLogger(__name__)
 
-if TYPE_CHECKING:
-    from .proto import PQGen, RV
+
+connect: Callable[[str], proto.PQGen[pq.proto.PGconn]]
+execute: Callable[[pq.proto.PGconn], proto.PQGen[List[pq.proto.PGresult]]]
+
+if pq.__impl__ == "c":
+    from . import _psycopg3
+
+    connect = _psycopg3.connect
+    execute = _psycopg3.execute
+
+else:
+    from . import generators
+
+    connect = generators.connect
+    execute = generators.execute
 
 
 class BaseConnection:
@@ -129,7 +140,7 @@ class Connection(BaseConnection):
         if conninfo is None and not kwargs:
             raise TypeError("missing conninfo and not parameters specified")
         conninfo = make_conninfo(conninfo or "", **kwargs)
-        gen = generators.connect(conninfo)
+        gen = connect(conninfo)
         pgconn = cls.wait(gen)
         return cls(pgconn)
 
@@ -155,7 +166,7 @@ class Connection(BaseConnection):
                 return
 
             self.pgconn.send_query(command)
-            (pgres,) = self.wait(generators.execute(self.pgconn))
+            (pgres,) = self.wait(execute(self.pgconn))
             if pgres.status != pq.ExecStatus.COMMAND_OK:
                 raise e.OperationalError(
                     f"error on {command.decode('utf8')}:"
@@ -163,7 +174,9 @@ class Connection(BaseConnection):
                 )
 
     @classmethod
-    def wait(cls, gen: "PQGen[RV]", timeout: Optional[float] = 0.1) -> "RV":
+    def wait(
+        cls, gen: proto.PQGen[proto.RV], timeout: Optional[float] = 0.1
+    ) -> proto.RV:
         return wait(gen, timeout=timeout)
 
     def set_client_encoding(self, value: str) -> None:
@@ -172,7 +185,7 @@ class Connection(BaseConnection):
                 b"select set_config('client_encoding', $1, false)",
                 [value.encode("ascii")],
             )
-            gen = generators.execute(self.pgconn)
+            gen = execute(self.pgconn)
             (result,) = self.wait(gen)
             if result.status != pq.ExecStatus.TUPLES_OK:
                 raise e.error_from_result(result)
@@ -200,7 +213,7 @@ class AsyncConnection(BaseConnection):
         if conninfo is None and not kwargs:
             raise TypeError("missing conninfo and not parameters specified")
         conninfo = make_conninfo(conninfo or "", **kwargs)
-        gen = generators.connect(conninfo)
+        gen = connect(conninfo)
         pgconn = await cls.wait(gen)
         return cls(pgconn)
 
@@ -226,7 +239,7 @@ class AsyncConnection(BaseConnection):
                 return
 
             self.pgconn.send_query(command)
-            (pgres,) = await self.wait(generators.execute(self.pgconn))
+            (pgres,) = await self.wait(execute(self.pgconn))
             if pgres.status != pq.ExecStatus.COMMAND_OK:
                 raise e.OperationalError(
                     f"error on {command.decode('utf8')}:"
@@ -234,7 +247,7 @@ class AsyncConnection(BaseConnection):
                 )
 
     @classmethod
-    async def wait(cls, gen: "PQGen[RV]") -> "RV":
+    async def wait(cls, gen: proto.PQGen[proto.RV]) -> proto.RV:
         return await wait_async(gen)
 
     async def set_client_encoding(self, value: str) -> None:
@@ -243,7 +256,7 @@ class AsyncConnection(BaseConnection):
                 b"select set_config('client_encoding', $1, false)",
                 [value.encode("ascii")],
             )
-            gen = generators.execute(self.pgconn)
+            gen = execute(self.pgconn)
             (result,) = await self.wait(gen)
             if result.status != pq.ExecStatus.TUPLES_OK:
                 raise e.error_from_result(result)
index 9f9efd605816b155c95582d94e38aadf3ff0daff..a41cd58501298a0461fe395c78d779a7fb52f448 100644 (file)
@@ -6,18 +6,29 @@ psycopg3 cursor objects
 
 import codecs
 from operator import attrgetter
-from typing import Any, List, Optional, Sequence, TYPE_CHECKING
+from typing import Any, Callable, List, Optional, Sequence, TYPE_CHECKING
 
 from . import errors as e
 from . import pq
-from . import generators
 from . import proto
-from .proto import Query, Params, DumpersMap, LoadersMap
+from .proto import Query, Params, DumpersMap, LoadersMap, PQGen
 from .utils.queries import PostgresQuery
 
 if TYPE_CHECKING:
     from .connection import BaseConnection, Connection, AsyncConnection
 
+execute: Callable[[pq.proto.PGconn], PQGen[List[pq.proto.PGresult]]]
+
+if pq.__impl__ == "c":
+    from . import _psycopg3
+
+    execute = _psycopg3.execute
+
+else:
+    from . import generators
+
+    execute = generators.execute
+
 
 class Column(Sequence[Any]):
     def __init__(
@@ -253,7 +264,7 @@ class Cursor(BaseCursor):
         with self.connection.lock:
             self._start_query()
             self._execute_send(query, vars)
-            gen = generators.execute(self.connection.pgconn)
+            gen = execute(self.connection.pgconn)
             results = self.connection.wait(gen)
             self._execute_results(results)
         return self
@@ -266,7 +277,7 @@ class Cursor(BaseCursor):
             for i, vars in enumerate(vars_seq):
                 if i == 0:
                     pgq = self._send_prepare(b"", query, vars)
-                    gen = generators.execute(self.connection.pgconn)
+                    gen = execute(self.connection.pgconn)
                     (result,) = self.connection.wait(gen)
                     if result.status == self.ExecStatus.FATAL_ERROR:
                         raise e.error_from_result(result)
@@ -274,7 +285,7 @@ class Cursor(BaseCursor):
                     pgq.dump(vars)
 
                 self._send_query_prepared(b"", pgq)
-                gen = generators.execute(self.connection.pgconn)
+                gen = execute(self.connection.pgconn)
                 (result,) = self.connection.wait(gen)
                 self._execute_results((result,))
 
@@ -331,7 +342,7 @@ class AsyncCursor(BaseCursor):
         async with self.connection.lock:
             self._start_query()
             self._execute_send(query, vars)
-            gen = generators.execute(self.connection.pgconn)
+            gen = execute(self.connection.pgconn)
             results = await self.connection.wait(gen)
             self._execute_results(results)
         return self
@@ -344,7 +355,7 @@ class AsyncCursor(BaseCursor):
             for i, vars in enumerate(vars_seq):
                 if i == 0:
                     pgq = self._send_prepare(b"", query, vars)
-                    gen = generators.execute(self.connection.pgconn)
+                    gen = execute(self.connection.pgconn)
                     (result,) = await self.connection.wait(gen)
                     if result.status == self.ExecStatus.FATAL_ERROR:
                         raise e.error_from_result(result)
@@ -352,7 +363,7 @@ class AsyncCursor(BaseCursor):
                     pgq.dump(vars)
 
                 self._send_query_prepared(b"", pgq)
-                gen = generators.execute(self.connection.pgconn)
+                gen = execute(self.connection.pgconn)
                 (result,) = await self.connection.wait(gen)
                 self._execute_results((result,))
 
index 6b89bfc332730501a9d98cb0eaa1c140568dbc1d..3deed3fbf8c159ff7197aca34d1e7639776f8c4c 100644 (file)
@@ -17,11 +17,11 @@ when the file descriptor is ready.
 
 import logging
 from typing import List
-from .waiting import Wait, Ready
 
 from . import pq
 from . import errors as e
 from .proto import PQGen
+from .waiting import Wait, Ready
 
 logger = logging.getLogger(__name__)
 
diff --git a/psycopg3/generators.pyx b/psycopg3/generators.pyx
new file mode 100644 (file)
index 0000000..4eb9a56
--- /dev/null
@@ -0,0 +1,107 @@
+"""
+C implementation of generators for the communication protocols with the libpq
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import logging
+from typing import List
+
+from . import errors as e
+from .proto import PQGen
+from .waiting import Wait, Ready
+from psycopg3 import pq
+from psycopg3.pq cimport libpq
+from psycopg3.pq.pq_cython cimport PGconn, PGresult
+
+cdef object WAIT_W = Wait.W
+cdef object WAIT_R = Wait.R
+cdef object WAIT_RW = Wait.RW
+cdef int READY_R = Ready.R
+
+def connect(conninfo: str) -> PQGen[pq.proto.PGconn]:
+    """
+    Generator to create a database connection without blocking.
+
+    """
+    cdef PGconn conn = PGconn.connect_start(conninfo.encode("utf8"))
+    logger.debug("connection started, status %s", conn.status.name)
+    cdef libpq.PGconn *pgconn_ptr = conn.pgconn_ptr
+    cdef int conn_status = libpq.PQstatus(pgconn_ptr)
+    cdef int poll_status
+
+    while 1:
+        if conn_status == libpq.CONNECTION_BAD:
+            raise e.OperationalError(
+                f"connection is bad: {pq.error_message(conn)}"
+            )
+
+        poll_status = libpq.PQconnectPoll(pgconn_ptr)
+        logger.debug("connection polled, status %s", conn.status.name)
+        if poll_status == libpq.PGRES_POLLING_OK:
+            break
+        elif poll_status == libpq.PGRES_POLLING_READING:
+            yield (libpq.PQsocket(pgconn_ptr), WAIT_R)
+        elif poll_status == libpq.PGRES_POLLING_WRITING:
+            yield (libpq.PQsocket(pgconn_ptr), WAIT_W)
+        elif poll_status == libpq.PGRES_POLLING_FAILED:
+            raise e.OperationalError(
+                f"connection failed: {pq.error_message(conn)}"
+            )
+        else:
+            raise e.InternalError(f"unexpected poll status: {poll_status}")
+
+    conn.nonblocking = 1
+    return conn
+
+
+def execute(PGconn pgconn) -> PQGen[List[pq.proto.PGresult]]:
+    """
+    Generator sending a query and returning results without blocking.
+
+    The query must have already been sent using `pgconn.send_query()` or
+    similar. Flush the query and then return the result using nonblocking
+    functions.
+
+    Return the list of results returned by the database (whether success
+    or error).
+    """
+    results: List[pq.proto.PGresult] = []
+    cdef libpq.PGconn *pgconn_ptr = pgconn.pgconn_ptr
+    cdef int status
+
+    # Sending the query
+    while 1:
+        if libpq.PQflush(pgconn_ptr) == 0:
+            break
+
+        status = yield libpq.PQsocket(pgconn_ptr), WAIT_RW
+        if status & READY_R:
+            if 1 != libpq.PQconsumeInput(pgconn_ptr):
+                raise pq.PQerror(
+                    f"consuming input failed: {pq.error_message(pgconn)}")
+        continue
+
+    wr = (libpq.PQsocket(pgconn_ptr), WAIT_R)
+
+    # Fetching the result
+    while 1:
+        if 1 != libpq.PQconsumeInput(pgconn_ptr):
+            raise pq.PQerror(
+                f"consuming input failed: {pq.error_message(pgconn)}")
+        if libpq.PQisBusy(pgconn_ptr):
+            yield wr
+            continue
+
+        res = libpq.PQgetResult(pgconn_ptr)
+        if res is NULL:
+            break
+        results.append(PGresult._from_ptr(res))
+
+        status = libpq.PQresultStatus(res)
+        if status in (libpq.PGRES_COPY_IN, libpq.PGRES_COPY_OUT, libpq.PGRES_COPY_BOTH):
+            # After entering copy mode the libpq will create a phony result
+            # for every request so let's break the endless loop.
+            break
+
+    return results
index 969e406c725f0c05ae35cef9981b30f3dcfeaee5..47378e464498d5492394906cc6fdef5bd174ae5f 100644 (file)
@@ -1,6 +1,7 @@
 import pytest
 from select import select
 import psycopg3
+from psycopg3.generators import execute
 
 
 def test_send_query(pq, pgconn):
@@ -61,7 +62,7 @@ def test_send_query_compact_test(pq, pgconn):
         b"/* %s */ select pg_sleep(0.01); select 1 as foo;"
         % (b"x" * 1_000_000)
     )
-    results = psycopg3.waiting.wait(psycopg3.generators.execute(pgconn))
+    results = psycopg3.waiting.wait(execute(pgconn))
 
     assert len(results) == 2
     assert results[0].nfields == 1
@@ -78,7 +79,7 @@ def test_send_query_compact_test(pq, pgconn):
 
 def test_send_query_params(pq, pgconn):
     pgconn.send_query_params(b"select $1::int + $2", [b"5", b"3"])
-    (res,) = psycopg3.waiting.wait(psycopg3.generators.execute(pgconn))
+    (res,) = psycopg3.waiting.wait(execute(pgconn))
     assert res.status == pq.ExecStatus.TUPLES_OK
     assert res.get_value(0, 0) == b"8"
 
@@ -89,11 +90,11 @@ def test_send_query_params(pq, pgconn):
 
 def test_send_prepare(pq, pgconn):
     pgconn.send_prepare(b"prep", b"select $1::int + $2::int")
-    (res,) = psycopg3.waiting.wait(psycopg3.generators.execute(pgconn))
+    (res,) = psycopg3.waiting.wait(execute(pgconn))
     assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
 
     pgconn.send_query_prepared(b"prep", [b"3", b"5"])
-    (res,) = psycopg3.waiting.wait(psycopg3.generators.execute(pgconn))
+    (res,) = psycopg3.waiting.wait(execute(pgconn))
     assert res.get_value(0, 0) == b"8"
 
     pgconn.finish()
@@ -105,22 +106,22 @@ def test_send_prepare(pq, pgconn):
 
 def test_send_prepare_types(pq, pgconn):
     pgconn.send_prepare(b"prep", b"select $1 + $2", [23, 23])
-    (res,) = psycopg3.waiting.wait(psycopg3.generators.execute(pgconn))
+    (res,) = psycopg3.waiting.wait(execute(pgconn))
     assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
 
     pgconn.send_query_prepared(b"prep", [b"3", b"5"])
-    (res,) = psycopg3.waiting.wait(psycopg3.generators.execute(pgconn))
+    (res,) = psycopg3.waiting.wait(execute(pgconn))
     assert res.get_value(0, 0) == b"8"
 
 
 def test_send_prepared_binary_in(pq, pgconn):
     val = b"foo\00bar"
     pgconn.send_prepare(b"", b"select length($1::bytea), length($2::bytea)")
-    (res,) = psycopg3.waiting.wait(psycopg3.generators.execute(pgconn))
+    (res,) = psycopg3.waiting.wait(execute(pgconn))
     assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
 
     pgconn.send_query_prepared(b"", [val, val], param_formats=[0, 1])
-    (res,) = psycopg3.waiting.wait(psycopg3.generators.execute(pgconn))
+    (res,) = psycopg3.waiting.wait(execute(pgconn))
     assert res.status == pq.ExecStatus.TUPLES_OK
     assert res.get_value(0, 0) == b"3"
     assert res.get_value(0, 1) == b"7"
@@ -135,12 +136,12 @@ def test_send_prepared_binary_in(pq, pgconn):
 def test_send_prepared_binary_out(pq, pgconn, fmt, out):
     val = b"foo\00bar"
     pgconn.send_prepare(b"", b"select $1::bytea")
-    (res,) = psycopg3.waiting.wait(psycopg3.generators.execute(pgconn))
+    (res,) = psycopg3.waiting.wait(execute(pgconn))
     assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
 
     pgconn.send_query_prepared(
         b"", [val], param_formats=[1], result_format=fmt
     )
-    (res,) = psycopg3.waiting.wait(psycopg3.generators.execute(pgconn))
+    (res,) = psycopg3.waiting.wait(execute(pgconn))
     assert res.status == pq.ExecStatus.TUPLES_OK
     assert res.get_value(0, 0) == out
index 81fc057150a43d6f1280428878a9843a79b52dfe..f5a8875e90bdc4857611bc9e34e45ae046de244f 100644 (file)
@@ -103,7 +103,7 @@ def test_connect_args(monkeypatch, pgconn, loop, testdsn, kwargs, want):
         return pgconn
         yield
 
-    monkeypatch.setattr(psycopg3.generators, "connect", fake_connect)
+    monkeypatch.setattr(psycopg3.connection, "connect", fake_connect)
     loop.run_until_complete(
         psycopg3.AsyncConnection.connect(testdsn, **kwargs)
     )
@@ -118,7 +118,7 @@ def test_connect_badargs(monkeypatch, pgconn, loop, args, kwargs):
         return pgconn
         yield
 
-    monkeypatch.setattr(psycopg3.generators, "connect", fake_connect)
+    monkeypatch.setattr(psycopg3.connection, "connect", fake_connect)
     with pytest.raises((TypeError, psycopg3.ProgrammingError)):
         loop.run_until_complete(
             psycopg3.AsyncConnection.connect(*args, **kwargs)
index 774a979bd243cc9b8c718571cb94632acf57c645..b5643ce705bf7076c1035e0c0cd586e8ce0d39da 100644 (file)
@@ -105,7 +105,7 @@ def test_connect_args(monkeypatch, pgconn, testdsn, kwargs, want):
         return pgconn
         yield
 
-    monkeypatch.setattr(psycopg3.generators, "connect", fake_connect)
+    monkeypatch.setattr(psycopg3.connection, "connect", fake_connect)
     psycopg3.Connection.connect(testdsn, **kwargs)
     assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want)
 
@@ -118,6 +118,6 @@ def test_connect_badargs(monkeypatch, pgconn, args, kwargs):
         return pgconn
         yield
 
-    monkeypatch.setattr(psycopg3.generators, "connect", fake_connect)
+    monkeypatch.setattr(psycopg3.connection, "connect", fake_connect)
     with pytest.raises((TypeError, psycopg3.ProgrammingError)):
         psycopg3.Connection.connect(*args, **kwargs)
index c2e6727c8b8d1cd6e03fa7fa0703be987fb19dcb..2e91c2224442fe56d0e6f7cbaed65bec0f418b1e 100644 (file)
@@ -127,7 +127,7 @@ def test_connect_args(monkeypatch, pgconn, testdsn, kwargs, want):
         return pgconn
         yield
 
-    monkeypatch.setattr(psycopg3.generators, "connect", fake_connect)
+    monkeypatch.setattr(psycopg3.connection, "connect", fake_connect)
     psycopg3.connect(testdsn, **kwargs)
     assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want)
 
@@ -140,6 +140,6 @@ def test_connect_badargs(monkeypatch, pgconn, args, kwargs):
         return pgconn
         yield
 
-    monkeypatch.setattr(psycopg3.generators, "connect", fake_connect)
+    monkeypatch.setattr(psycopg3.connection, "connect", fake_connect)
     with pytest.raises((TypeError, psycopg3.ProgrammingError)):
         psycopg3.connect(*args, **kwargs)