]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added sketch of high-level connection objects
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 20 Mar 2020 11:30:08 +0000 (00:30 +1300)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 20 Mar 2020 11:35:38 +0000 (00:35 +1300)
Trying to create two similar Connection/AsyncConnection classes, a DBAPI
compliant and an async DBAPI-inspired.

The classes share the same implementation of asynchronous connection,
which is independent from the waiting method, and offer different ways
to wait: one blocking the thread, the other as async coroutine.

psycopg3/__init__.py
psycopg3/connection.py [new file with mode: 0644]
psycopg3/waiting.py [new file with mode: 0644]
tests/conftest.py
tests/fix_async.py [new file with mode: 0644]
tests/pq/test_pgconn.py
tests/test_async_connection.py [new file with mode: 0644]
tests/test_connection.py [new file with mode: 0644]

index 05475ae53596cea873f1b4de17dc1b60e6b1fbca..c147bf1c7c75c90c87367fce929288692bb393f0 100644 (file)
@@ -5,6 +5,7 @@ psycopg3 -- PostgreSQL database adapter for Python
 # Copyright (C) 2020 The Psycopg Team
 
 from .consts import VERSION as __version__  # noqa
+from .connection import AsyncConnection, Connection
 
 from .exceptions import (
     Warning,
@@ -19,6 +20,8 @@ from .exceptions import (
     NotSupportedError,
 )
 
+connect = Connection.connect
+
 __all__ = [
     "Warning",
     "Error",
@@ -30,4 +33,7 @@ __all__ = [
     "InternalError",
     "ProgrammingError",
     "NotSupportedError",
+    "AsyncConnection",
+    "Connection",
+    "connect",
 ]
diff --git a/psycopg3/connection.py b/psycopg3/connection.py
new file mode 100644 (file)
index 0000000..f96b77f
--- /dev/null
@@ -0,0 +1,90 @@
+"""
+psycopg3 connection objects
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import logging
+
+from . import pq
+from . import exceptions as exc
+from .conninfo import make_conninfo
+from .waiting import wait_select, wait_async, WAIT_R, WAIT_W
+
+logger = logging.getLogger(__name__)
+
+
+class BaseConnection:
+    """
+    Base class for different types of connections.
+
+    Share common functionalities such as access to the wrapped PGconn, but
+    allow different interfaces (sync/async).
+    """
+
+    def __init__(self, pgconn):
+        self.pgconn = pgconn
+
+    @classmethod
+    def _connect_gen(cls, conninfo):
+        """
+        Generator yielding connection states and returning a done connection.
+        """
+        conninfo = conninfo.encode("utf8")
+
+        conn = pq.PGconn.connect_start(conninfo)
+        logger.debug("connection started, status %s", conn.status.name)
+        while 1:
+            if conn.status == pq.ConnStatus.CONNECTION_BAD:
+                raise exc.OperationalError(
+                    f"connection is bad: {pq.error_message(conn)}"
+                )
+
+            status = conn.connect_poll()
+            logger.debug("connection polled, status %s", conn.status.name)
+            if status == pq.PollingStatus.PGRES_POLLING_OK:
+                break
+            elif status == pq.PollingStatus.PGRES_POLLING_READING:
+                yield conn.socket, WAIT_R
+            elif status == pq.PollingStatus.PGRES_POLLING_WRITING:
+                yield conn.socket, WAIT_W
+            elif status == pq.PollingStatus.PGRES_POLLING_FAILED:
+                raise exc.OperationalError(
+                    f"connection failed: {pq.error_message(conn)}"
+                )
+            else:
+                raise exc.InternalError(f"unexpected poll status: {status}")
+
+        conn.nonblocking = 1
+        return conn
+
+
+class Connection(BaseConnection):
+    """
+    Wrap a connection to the database.
+
+    This class implements a DBAPI-compliant interface.
+    """
+
+    @classmethod
+    def connect(cls, conninfo, **kwargs):
+        conninfo = make_conninfo(conninfo, **kwargs)
+        gen = cls._connect_gen(conninfo)
+        pgconn = wait_select(gen)
+        return cls(pgconn)
+
+
+class AsyncConnection(BaseConnection):
+    """
+    Wrap an asynchronous connection to the database.
+
+    This class implements a DBAPI-inspired interface, with all the blocking
+    methods implemented as coroutines.
+    """
+
+    @classmethod
+    async def connect(cls, conninfo, **kwargs):
+        conninfo = make_conninfo(conninfo, **kwargs)
+        gen = cls._connect_gen(conninfo)
+        pgconn = await wait_async(gen)
+        return cls(pgconn)
diff --git a/psycopg3/waiting.py b/psycopg3/waiting.py
new file mode 100644 (file)
index 0000000..3aca5ba
--- /dev/null
@@ -0,0 +1,94 @@
+"""
+Code concerned with waiting in different contexts (blocking, async, etc).
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+
+from select import select
+from asyncio import get_event_loop
+from asyncio.queues import Queue
+
+from . import exceptions as exc
+
+
+WAIT_R = "WAIT_R"
+WAIT_W = "WAIT_W"
+WAIT_RW = "WAIT_RW"
+READY_R = "READY_R"
+READY_W = "READY_W"
+
+
+def wait_select(gen):
+    """
+    Wait on the behalf of a generator using select().
+
+    *gen* is expected to generate tuples (fd, status). consume it and block
+    according to the status until fd is ready. Send back the ready state
+    to the generator.
+
+    Return what the generator eventually returned.
+    """
+    try:
+        while 1:
+            fd, s = next(gen)
+            if s == WAIT_R:
+                rf, wf, xf = select([fd], [], [])
+                assert rf
+                gen.send(READY_R)
+            elif s == WAIT_W:
+                rf, wf, xf = select([], [fd], [])
+                assert wf
+                gen.send(READY_W)
+            elif s == WAIT_RW:
+                rf, wf, xf = select([fd], [fd], [])
+                assert rf or wf
+                assert not (rf and wf)
+                if rf:
+                    gen.send(READY_R)
+                else:
+                    gen.send(READY_W)
+            else:
+                raise exc.InternalError("bad poll status: %s")
+    except StopIteration as e:
+        return e.args[0]
+
+
+async def wait_async(gen):
+    """
+    Coroutine waiting for a generator to complete.
+
+    *gen* is expected to generate tuples (fd, status). consume it and block
+    according to the status until fd is ready. Send back the ready state
+    to the generator.
+
+    Return what the generator eventually returned.
+    """
+    # Use a queue to block and restart after the fd state changes.
+    # Not sure this is the best implementation but it's a start.
+    q = Queue()
+    loop = get_event_loop()
+    try:
+        while 1:
+            fd, s = next(gen)
+            if s == WAIT_R:
+                loop.add_reader(fd, q.put_nowait, READY_R)
+                ready = await q.get()
+                loop.remove_reader(fd)
+                gen.send(ready)
+            elif s == WAIT_W:
+                loop.add_writer(fd, q.put_nowait, READY_W)
+                ready = await q.get()
+                loop.remove_writer(fd)
+                gen.send(ready)
+            elif s == WAIT_RW:
+                loop.add_reader(fd, q.put_nowait, READY_R)
+                loop.add_writer(fd, q.put_nowait, READY_W)
+                ready = await q.get()
+                loop.remove_reader(fd)
+                loop.remove_writer(fd)
+                gen.send(ready)
+            else:
+                raise exc.InternalError("bad poll status: %s")
+    except StopIteration as e:
+        return e.args[0]
index bf263193beb2d1e2450d2e7db78c1a0cfb43e30f..a53ed1bb17369b249e93bcd431b0a8b485ab681d 100644 (file)
@@ -1 +1 @@
-pytest_plugins = ("tests.fix_db", "tests.fix_tempenv")
+pytest_plugins = ("tests.fix_async", "tests.fix_db", "tests.fix_tempenv")
diff --git a/tests/fix_async.py b/tests/fix_async.py
new file mode 100644 (file)
index 0000000..3a29d1d
--- /dev/null
@@ -0,0 +1,9 @@
+import asyncio
+
+import pytest
+
+
+@pytest.fixture
+def loop():
+    """Return the async loop to test coroutines."""
+    return asyncio.get_event_loop()
index b93ebb29a572b4b1ca6468e3904560d315798518..3d65d0e4cf73027e64c47c478b62600bbbfcd2be 100644 (file)
@@ -21,6 +21,7 @@ def test_connectdb_badtype(pq, baddsn):
 
 def test_connect_async(pq, dsn):
     conn = pq.PGconn.connect_start(dsn.encode("utf8"))
+    conn.nonblocking = 1
     while 1:
         assert conn.status != pq.ConnStatus.CONNECTION_BAD
         rv = conn.connect_poll()
diff --git a/tests/test_async_connection.py b/tests/test_async_connection.py
new file mode 100644 (file)
index 0000000..ba42440
--- /dev/null
@@ -0,0 +1,14 @@
+import pytest
+
+import psycopg3
+from psycopg3 import AsyncConnection
+
+
+def test_connect(pq, dsn, loop):
+    conn = loop.run_until_complete(AsyncConnection.connect(dsn))
+    assert conn.pgconn.status == pq.ConnStatus.CONNECTION_OK
+
+
+def test_connect_bad(loop):
+    with pytest.raises(psycopg3.OperationalError):
+        loop.run_until_complete(AsyncConnection.connect("dbname=nosuchdb"))
diff --git a/tests/test_connection.py b/tests/test_connection.py
new file mode 100644 (file)
index 0000000..74706ea
--- /dev/null
@@ -0,0 +1,14 @@
+import pytest
+
+import psycopg3
+from psycopg3 import Connection
+
+
+def test_connect(pq, dsn):
+    conn = Connection.connect(dsn)
+    assert conn.pgconn.status == pq.ConnStatus.CONNECTION_OK
+
+
+def test_connect_bad():
+    with pytest.raises(psycopg3.OperationalError):
+        Connection.connect("dbname=nosuchdb")