From: Daniele Varrazzo Date: Fri, 20 Mar 2020 11:30:08 +0000 (+1300) Subject: Added sketch of high-level connection objects X-Git-Tag: 3.0.dev0~693 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=a6fc9a43d4fcfcbf9c76836f1a90346c10f9c0ab;p=thirdparty%2Fpsycopg.git Added sketch of high-level connection objects Trying to create two similar Connection/AsyncConnection classes, a DBAPI compliant and an async DBAPI-inspired. The classes share the same implementation of asynchronous connection, which is independent from the waiting method, and offer different ways to wait: one blocking the thread, the other as async coroutine. --- diff --git a/psycopg3/__init__.py b/psycopg3/__init__.py index 05475ae53..c147bf1c7 100644 --- a/psycopg3/__init__.py +++ b/psycopg3/__init__.py @@ -5,6 +5,7 @@ psycopg3 -- PostgreSQL database adapter for Python # Copyright (C) 2020 The Psycopg Team from .consts import VERSION as __version__ # noqa +from .connection import AsyncConnection, Connection from .exceptions import ( Warning, @@ -19,6 +20,8 @@ from .exceptions import ( NotSupportedError, ) +connect = Connection.connect + __all__ = [ "Warning", "Error", @@ -30,4 +33,7 @@ __all__ = [ "InternalError", "ProgrammingError", "NotSupportedError", + "AsyncConnection", + "Connection", + "connect", ] diff --git a/psycopg3/connection.py b/psycopg3/connection.py new file mode 100644 index 000000000..f96b77f3c --- /dev/null +++ b/psycopg3/connection.py @@ -0,0 +1,90 @@ +""" +psycopg3 connection objects +""" + +# Copyright (C) 2020 The Psycopg Team + +import logging + +from . import pq +from . import exceptions as exc +from .conninfo import make_conninfo +from .waiting import wait_select, wait_async, WAIT_R, WAIT_W + +logger = logging.getLogger(__name__) + + +class BaseConnection: + """ + Base class for different types of connections. + + Share common functionalities such as access to the wrapped PGconn, but + allow different interfaces (sync/async). + """ + + def __init__(self, pgconn): + self.pgconn = pgconn + + @classmethod + def _connect_gen(cls, conninfo): + """ + Generator yielding connection states and returning a done connection. + """ + conninfo = conninfo.encode("utf8") + + conn = pq.PGconn.connect_start(conninfo) + logger.debug("connection started, status %s", conn.status.name) + while 1: + if conn.status == pq.ConnStatus.CONNECTION_BAD: + raise exc.OperationalError( + f"connection is bad: {pq.error_message(conn)}" + ) + + status = conn.connect_poll() + logger.debug("connection polled, status %s", conn.status.name) + if status == pq.PollingStatus.PGRES_POLLING_OK: + break + elif status == pq.PollingStatus.PGRES_POLLING_READING: + yield conn.socket, WAIT_R + elif status == pq.PollingStatus.PGRES_POLLING_WRITING: + yield conn.socket, WAIT_W + elif status == pq.PollingStatus.PGRES_POLLING_FAILED: + raise exc.OperationalError( + f"connection failed: {pq.error_message(conn)}" + ) + else: + raise exc.InternalError(f"unexpected poll status: {status}") + + conn.nonblocking = 1 + return conn + + +class Connection(BaseConnection): + """ + Wrap a connection to the database. + + This class implements a DBAPI-compliant interface. + """ + + @classmethod + def connect(cls, conninfo, **kwargs): + conninfo = make_conninfo(conninfo, **kwargs) + gen = cls._connect_gen(conninfo) + pgconn = wait_select(gen) + return cls(pgconn) + + +class AsyncConnection(BaseConnection): + """ + Wrap an asynchronous connection to the database. + + This class implements a DBAPI-inspired interface, with all the blocking + methods implemented as coroutines. + """ + + @classmethod + async def connect(cls, conninfo, **kwargs): + conninfo = make_conninfo(conninfo, **kwargs) + gen = cls._connect_gen(conninfo) + pgconn = await wait_async(gen) + return cls(pgconn) diff --git a/psycopg3/waiting.py b/psycopg3/waiting.py new file mode 100644 index 000000000..3aca5ba90 --- /dev/null +++ b/psycopg3/waiting.py @@ -0,0 +1,94 @@ +""" +Code concerned with waiting in different contexts (blocking, async, etc). +""" + +# Copyright (C) 2020 The Psycopg Team + + +from select import select +from asyncio import get_event_loop +from asyncio.queues import Queue + +from . import exceptions as exc + + +WAIT_R = "WAIT_R" +WAIT_W = "WAIT_W" +WAIT_RW = "WAIT_RW" +READY_R = "READY_R" +READY_W = "READY_W" + + +def wait_select(gen): + """ + Wait on the behalf of a generator using select(). + + *gen* is expected to generate tuples (fd, status). consume it and block + according to the status until fd is ready. Send back the ready state + to the generator. + + Return what the generator eventually returned. + """ + try: + while 1: + fd, s = next(gen) + if s == WAIT_R: + rf, wf, xf = select([fd], [], []) + assert rf + gen.send(READY_R) + elif s == WAIT_W: + rf, wf, xf = select([], [fd], []) + assert wf + gen.send(READY_W) + elif s == WAIT_RW: + rf, wf, xf = select([fd], [fd], []) + assert rf or wf + assert not (rf and wf) + if rf: + gen.send(READY_R) + else: + gen.send(READY_W) + else: + raise exc.InternalError("bad poll status: %s") + except StopIteration as e: + return e.args[0] + + +async def wait_async(gen): + """ + Coroutine waiting for a generator to complete. + + *gen* is expected to generate tuples (fd, status). consume it and block + according to the status until fd is ready. Send back the ready state + to the generator. + + Return what the generator eventually returned. + """ + # Use a queue to block and restart after the fd state changes. + # Not sure this is the best implementation but it's a start. + q = Queue() + loop = get_event_loop() + try: + while 1: + fd, s = next(gen) + if s == WAIT_R: + loop.add_reader(fd, q.put_nowait, READY_R) + ready = await q.get() + loop.remove_reader(fd) + gen.send(ready) + elif s == WAIT_W: + loop.add_writer(fd, q.put_nowait, READY_W) + ready = await q.get() + loop.remove_writer(fd) + gen.send(ready) + elif s == WAIT_RW: + loop.add_reader(fd, q.put_nowait, READY_R) + loop.add_writer(fd, q.put_nowait, READY_W) + ready = await q.get() + loop.remove_reader(fd) + loop.remove_writer(fd) + gen.send(ready) + else: + raise exc.InternalError("bad poll status: %s") + except StopIteration as e: + return e.args[0] diff --git a/tests/conftest.py b/tests/conftest.py index bf263193b..a53ed1bb1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1 +1 @@ -pytest_plugins = ("tests.fix_db", "tests.fix_tempenv") +pytest_plugins = ("tests.fix_async", "tests.fix_db", "tests.fix_tempenv") diff --git a/tests/fix_async.py b/tests/fix_async.py new file mode 100644 index 000000000..3a29d1d47 --- /dev/null +++ b/tests/fix_async.py @@ -0,0 +1,9 @@ +import asyncio + +import pytest + + +@pytest.fixture +def loop(): + """Return the async loop to test coroutines.""" + return asyncio.get_event_loop() diff --git a/tests/pq/test_pgconn.py b/tests/pq/test_pgconn.py index b93ebb29a..3d65d0e4c 100644 --- a/tests/pq/test_pgconn.py +++ b/tests/pq/test_pgconn.py @@ -21,6 +21,7 @@ def test_connectdb_badtype(pq, baddsn): def test_connect_async(pq, dsn): conn = pq.PGconn.connect_start(dsn.encode("utf8")) + conn.nonblocking = 1 while 1: assert conn.status != pq.ConnStatus.CONNECTION_BAD rv = conn.connect_poll() diff --git a/tests/test_async_connection.py b/tests/test_async_connection.py new file mode 100644 index 000000000..ba42440f8 --- /dev/null +++ b/tests/test_async_connection.py @@ -0,0 +1,14 @@ +import pytest + +import psycopg3 +from psycopg3 import AsyncConnection + + +def test_connect(pq, dsn, loop): + conn = loop.run_until_complete(AsyncConnection.connect(dsn)) + assert conn.pgconn.status == pq.ConnStatus.CONNECTION_OK + + +def test_connect_bad(loop): + with pytest.raises(psycopg3.OperationalError): + loop.run_until_complete(AsyncConnection.connect("dbname=nosuchdb")) diff --git a/tests/test_connection.py b/tests/test_connection.py new file mode 100644 index 000000000..74706eafb --- /dev/null +++ b/tests/test_connection.py @@ -0,0 +1,14 @@ +import pytest + +import psycopg3 +from psycopg3 import Connection + + +def test_connect(pq, dsn): + conn = Connection.connect(dsn) + assert conn.pgconn.status == pq.ConnStatus.CONNECTION_OK + + +def test_connect_bad(): + with pytest.raises(psycopg3.OperationalError): + Connection.connect("dbname=nosuchdb")