# Copyright (C) 2020 The Psycopg Team
from .consts import VERSION as __version__ # noqa
+from .connection import AsyncConnection, Connection
from .exceptions import (
Warning,
NotSupportedError,
)
+connect = Connection.connect
+
__all__ = [
"Warning",
"Error",
"InternalError",
"ProgrammingError",
"NotSupportedError",
+ "AsyncConnection",
+ "Connection",
+ "connect",
]
--- /dev/null
+"""
+psycopg3 connection objects
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import logging
+
+from . import pq
+from . import exceptions as exc
+from .conninfo import make_conninfo
+from .waiting import wait_select, wait_async, WAIT_R, WAIT_W
+
+logger = logging.getLogger(__name__)
+
+
+class BaseConnection:
+ """
+ Base class for different types of connections.
+
+ Share common functionalities such as access to the wrapped PGconn, but
+ allow different interfaces (sync/async).
+ """
+
+ def __init__(self, pgconn):
+ self.pgconn = pgconn
+
+ @classmethod
+ def _connect_gen(cls, conninfo):
+ """
+ Generator yielding connection states and returning a done connection.
+ """
+ conninfo = conninfo.encode("utf8")
+
+ conn = pq.PGconn.connect_start(conninfo)
+ logger.debug("connection started, status %s", conn.status.name)
+ while 1:
+ if conn.status == pq.ConnStatus.CONNECTION_BAD:
+ raise exc.OperationalError(
+ f"connection is bad: {pq.error_message(conn)}"
+ )
+
+ status = conn.connect_poll()
+ logger.debug("connection polled, status %s", conn.status.name)
+ if status == pq.PollingStatus.PGRES_POLLING_OK:
+ break
+ elif status == pq.PollingStatus.PGRES_POLLING_READING:
+ yield conn.socket, WAIT_R
+ elif status == pq.PollingStatus.PGRES_POLLING_WRITING:
+ yield conn.socket, WAIT_W
+ elif status == pq.PollingStatus.PGRES_POLLING_FAILED:
+ raise exc.OperationalError(
+ f"connection failed: {pq.error_message(conn)}"
+ )
+ else:
+ raise exc.InternalError(f"unexpected poll status: {status}")
+
+ conn.nonblocking = 1
+ return conn
+
+
+class Connection(BaseConnection):
+ """
+ Wrap a connection to the database.
+
+ This class implements a DBAPI-compliant interface.
+ """
+
+ @classmethod
+ def connect(cls, conninfo, **kwargs):
+ conninfo = make_conninfo(conninfo, **kwargs)
+ gen = cls._connect_gen(conninfo)
+ pgconn = wait_select(gen)
+ return cls(pgconn)
+
+
+class AsyncConnection(BaseConnection):
+ """
+ Wrap an asynchronous connection to the database.
+
+ This class implements a DBAPI-inspired interface, with all the blocking
+ methods implemented as coroutines.
+ """
+
+ @classmethod
+ async def connect(cls, conninfo, **kwargs):
+ conninfo = make_conninfo(conninfo, **kwargs)
+ gen = cls._connect_gen(conninfo)
+ pgconn = await wait_async(gen)
+ return cls(pgconn)
--- /dev/null
+"""
+Code concerned with waiting in different contexts (blocking, async, etc).
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+
+from select import select
+from asyncio import get_event_loop
+from asyncio.queues import Queue
+
+from . import exceptions as exc
+
+
+WAIT_R = "WAIT_R"
+WAIT_W = "WAIT_W"
+WAIT_RW = "WAIT_RW"
+READY_R = "READY_R"
+READY_W = "READY_W"
+
+
+def wait_select(gen):
+ """
+ Wait on the behalf of a generator using select().
+
+ *gen* is expected to generate tuples (fd, status). consume it and block
+ according to the status until fd is ready. Send back the ready state
+ to the generator.
+
+ Return what the generator eventually returned.
+ """
+ try:
+ while 1:
+ fd, s = next(gen)
+ if s == WAIT_R:
+ rf, wf, xf = select([fd], [], [])
+ assert rf
+ gen.send(READY_R)
+ elif s == WAIT_W:
+ rf, wf, xf = select([], [fd], [])
+ assert wf
+ gen.send(READY_W)
+ elif s == WAIT_RW:
+ rf, wf, xf = select([fd], [fd], [])
+ assert rf or wf
+ assert not (rf and wf)
+ if rf:
+ gen.send(READY_R)
+ else:
+ gen.send(READY_W)
+ else:
+ raise exc.InternalError("bad poll status: %s")
+ except StopIteration as e:
+ return e.args[0]
+
+
+async def wait_async(gen):
+ """
+ Coroutine waiting for a generator to complete.
+
+ *gen* is expected to generate tuples (fd, status). consume it and block
+ according to the status until fd is ready. Send back the ready state
+ to the generator.
+
+ Return what the generator eventually returned.
+ """
+ # Use a queue to block and restart after the fd state changes.
+ # Not sure this is the best implementation but it's a start.
+ q = Queue()
+ loop = get_event_loop()
+ try:
+ while 1:
+ fd, s = next(gen)
+ if s == WAIT_R:
+ loop.add_reader(fd, q.put_nowait, READY_R)
+ ready = await q.get()
+ loop.remove_reader(fd)
+ gen.send(ready)
+ elif s == WAIT_W:
+ loop.add_writer(fd, q.put_nowait, READY_W)
+ ready = await q.get()
+ loop.remove_writer(fd)
+ gen.send(ready)
+ elif s == WAIT_RW:
+ loop.add_reader(fd, q.put_nowait, READY_R)
+ loop.add_writer(fd, q.put_nowait, READY_W)
+ ready = await q.get()
+ loop.remove_reader(fd)
+ loop.remove_writer(fd)
+ gen.send(ready)
+ else:
+ raise exc.InternalError("bad poll status: %s")
+ except StopIteration as e:
+ return e.args[0]
-pytest_plugins = ("tests.fix_db", "tests.fix_tempenv")
+pytest_plugins = ("tests.fix_async", "tests.fix_db", "tests.fix_tempenv")
--- /dev/null
+import asyncio
+
+import pytest
+
+
+@pytest.fixture
+def loop():
+ """Return the async loop to test coroutines."""
+ return asyncio.get_event_loop()
def test_connect_async(pq, dsn):
conn = pq.PGconn.connect_start(dsn.encode("utf8"))
+ conn.nonblocking = 1
while 1:
assert conn.status != pq.ConnStatus.CONNECTION_BAD
rv = conn.connect_poll()
--- /dev/null
+import pytest
+
+import psycopg3
+from psycopg3 import AsyncConnection
+
+
+def test_connect(pq, dsn, loop):
+ conn = loop.run_until_complete(AsyncConnection.connect(dsn))
+ assert conn.pgconn.status == pq.ConnStatus.CONNECTION_OK
+
+
+def test_connect_bad(loop):
+ with pytest.raises(psycopg3.OperationalError):
+ loop.run_until_complete(AsyncConnection.connect("dbname=nosuchdb"))
--- /dev/null
+import pytest
+
+import psycopg3
+from psycopg3 import Connection
+
+
+def test_connect(pq, dsn):
+ conn = Connection.connect(dsn)
+ assert conn.pgconn.status == pq.ConnStatus.CONNECTION_OK
+
+
+def test_connect_bad():
+ with pytest.raises(psycopg3.OperationalError):
+ Connection.connect("dbname=nosuchdb")