# Copyright (C) 2020 The Psycopg Team
import logging
+from threading import Lock
from . import pq
from . import exceptions as exc
from .conninfo import make_conninfo
-from .waiting import wait_select, wait_async, WAIT_R, WAIT_W
+from .waiting import wait_select, wait_async, Wait, Ready
logger = logging.getLogger(__name__)
def __init__(self, pgconn):
self.pgconn = pgconn
+ self.lock = Lock()
@classmethod
def _connect_gen(cls, conninfo):
"""
- Generator yielding connection states and returning a done connection.
+ Generator to create a database connection using without blocking.
+
+ Yield pairs (fileno, `Wait`) whenever an operation would block. The
+ generator can be restarted sending the appropriate `Ready` state when
+ the file descriptor is ready.
"""
conninfo = conninfo.encode("utf8")
if status == pq.PollingStatus.PGRES_POLLING_OK:
break
elif status == pq.PollingStatus.PGRES_POLLING_READING:
- yield conn.socket, WAIT_R
+ yield conn.socket, Wait.R
elif status == pq.PollingStatus.PGRES_POLLING_WRITING:
- yield conn.socket, WAIT_W
+ yield conn.socket, Wait.W
elif status == pq.PollingStatus.PGRES_POLLING_FAILED:
raise exc.OperationalError(
f"connection failed: {pq.error_message(conn)}"
conn.nonblocking = 1
return conn
+ @classmethod
+ def _exec_gen(cls, pgconn):
+ """
+ Generator returning query results without blocking.
+
+ The query must have already been sent using `pgconn.send_query()` or
+ similar. Flush the query and then return the result using nonblocking
+ functions.
+
+ Yield pairs (fileno, `Wait`) whenever an operation would block. The
+ generator can be restarted sending the appropriate `Ready` state when
+ the file descriptor is ready.
+
+ Return the list of results returned by the database (whether success
+ or error).
+ """
+ results = []
+
+ while 1:
+ f = pgconn.flush()
+ if f == 0:
+ break
+
+ ready = yield pgconn.socket, Wait.RW
+ if ready is Ready.R:
+ pgconn.consume_input()
+ continue
+
+ while 1:
+ pgconn.consume_input()
+ if pgconn.is_busy():
+ ready = yield pgconn.socket, Wait.R
+ res = pgconn.get_result()
+ if res is None:
+ break
+ results.append(res)
+
+ return results
+
class Connection(BaseConnection):
"""
"""
@classmethod
- def connect(cls, conninfo, **kwargs):
+ def connect(cls, conninfo, connection_factory=None, **kwargs):
+ if connection_factory is not None:
+ raise NotImplementedError()
conninfo = make_conninfo(conninfo, **kwargs)
gen = cls._connect_gen(conninfo)
pgconn = wait_select(gen)
return cls(pgconn)
+ def commit(self):
+ self._exec_commit_rollback(b"commit")
+
+ def rollback(self):
+ self._exec_commit_rollback(b"rollback")
+
+ def _exec_commit_rollback(self, command):
+ with self.lock:
+ status = self.pgconn.transaction_status
+ if status == pq.TransactionStatus.PQTRANS_IDLE:
+ return
+
+ self.pgconn.send_query(command)
+ (pgres,) = wait_select(self._exec_gen(self.pgconn))
+ if pgres.status != pq.ExecStatus.PGRES_COMMAND_OK:
+ raise exc.OperationalError(
+ f"error on {command.decode('utf8')}:"
+ f" {pq.error_message(pgres)}"
+ )
+
class AsyncConnection(BaseConnection):
"""
gen = cls._connect_gen(conninfo)
pgconn = await wait_async(gen)
return cls(pgconn)
+
+ async def commit(self):
+ await self._exec_commit_rollback(b"commit")
+
+ async def rollback(self):
+ await self._exec_commit_rollback(b"rollback")
+
+ async def _exec_commit_rollback(self, command):
+ with self.lock:
+ status = self.pgconn.transaction_status
+ if status == pq.TransactionStatus.PQTRANS_IDLE:
+ return
+
+ self.pgconn.send_query(command)
+ (pgres,) = await wait_async(self._exec_gen(self.pgconn))
+ if pgres.status != pq.ExecStatus.PGRES_COMMAND_OK:
+ raise exc.OperationalError(
+ f"error on {command.decode('utf8')}:"
+ f" {pq.error_message(pgres)}"
+ )
impl.PQreset(self.pgconn_ptr)
def reset_start(self):
- rv = impl.PQresetStart(self.pgconn_ptr)
- if rv == 0:
+ if not impl.PQresetStart(self.pgconn_ptr):
raise PQerror("couldn't reset connection")
def reset_poll(self):
raise TypeError(
"bytes expected, got %s instead" % type(command).__name__
)
- return impl.PQsendQuery(self.pgconn_ptr, command)
+ if not impl.PQsendQuery(self.pgconn_ptr, command):
+ raise PQerror(f"sending query failed: {error_message(self)}")
def exec_params(
self,
# Copyright (C) 2020 The Psycopg Team
+from enum import Enum
from select import select
from asyncio import get_event_loop
from asyncio.queues import Queue
from . import exceptions as exc
-WAIT_R = "WAIT_R"
-WAIT_W = "WAIT_W"
-WAIT_RW = "WAIT_RW"
-READY_R = "READY_R"
-READY_W = "READY_W"
+Wait = Enum("Wait", "R W RW")
+Ready = Enum("Ready", "R W")
def wait_select(gen):
try:
while 1:
fd, s = next(gen)
- if s == WAIT_R:
+ if s is Wait.R:
rf, wf, xf = select([fd], [], [])
assert rf
- gen.send(READY_R)
- elif s == WAIT_W:
+ gen.send(Ready.R)
+ elif s is Wait.W:
rf, wf, xf = select([], [fd], [])
assert wf
- gen.send(READY_W)
- elif s == WAIT_RW:
+ gen.send(Ready.W)
+ elif s is Wait.RW:
rf, wf, xf = select([fd], [fd], [])
assert rf or wf
assert not (rf and wf)
if rf:
- gen.send(READY_R)
+ gen.send(Ready.R)
else:
- gen.send(READY_W)
+ gen.send(Ready.W)
else:
raise exc.InternalError("bad poll status: %s")
except StopIteration as e:
try:
while 1:
fd, s = next(gen)
- if s == WAIT_R:
- loop.add_reader(fd, q.put_nowait, READY_R)
+ if s is Wait.R:
+ loop.add_reader(fd, q.put_nowait, Ready.R)
ready = await q.get()
loop.remove_reader(fd)
gen.send(ready)
- elif s == WAIT_W:
- loop.add_writer(fd, q.put_nowait, READY_W)
+ elif s is Wait.W:
+ loop.add_writer(fd, q.put_nowait, Ready.W)
ready = await q.get()
loop.remove_writer(fd)
gen.send(ready)
- elif s == WAIT_RW:
- loop.add_reader(fd, q.put_nowait, READY_R)
- loop.add_writer(fd, q.put_nowait, READY_W)
+ elif s is Wait.RW:
+ loop.add_reader(fd, q.put_nowait, Ready.R)
+ loop.add_writer(fd, q.put_nowait, Ready.W)
ready = await q.get()
loop.remove_reader(fd)
loop.remove_writer(fd)
def loop():
"""Return the async loop to test coroutines."""
return asyncio.get_event_loop()
+
+
+@pytest.fixture
+def aconn(loop, dsn):
+ """Return an `AsyncConnection` connected to the ``--test-dsn`` database."""
+ from psycopg3 import AsyncConnection
+
+ conn = loop.run_until_complete(AsyncConnection.connect(dsn))
+ return conn
def pgconn(pq, dsn):
"""Return a PGconn connection open to `--test-dsn`."""
return pq.PGconn.connect(dsn.encode("utf8"))
+
+
+@pytest.fixture
+def conn(dsn):
+ """Return a `Connection` connected to the ``--test-dsn`` database."""
+ from psycopg3 import Connection
+
+ return Connection.connect(dsn)
def test_connect_bad(loop):
with pytest.raises(psycopg3.OperationalError):
loop.run_until_complete(AsyncConnection.connect("dbname=nosuchdb"))
+
+
+def test_commit(loop, pq, aconn):
+ aconn.pgconn.exec_(b"drop table if exists foo")
+ aconn.pgconn.exec_(b"create table foo (id int primary key)")
+ aconn.pgconn.exec_(b"begin")
+ assert (
+ aconn.pgconn.transaction_status == pq.TransactionStatus.PQTRANS_INTRANS
+ )
+ res = aconn.pgconn.exec_(b"insert into foo values (1)")
+ loop.run_until_complete(aconn.commit())
+ assert aconn.pgconn.transaction_status == pq.TransactionStatus.PQTRANS_IDLE
+ res = aconn.pgconn.exec_(b"select id from foo where id = 1")
+ assert res.get_value(0, 0) == b"1"
+
+
+def test_rollback(loop, pq, aconn):
+ aconn.pgconn.exec_(b"drop table if exists foo")
+ aconn.pgconn.exec_(b"create table foo (id int primary key)")
+ aconn.pgconn.exec_(b"begin")
+ assert (
+ aconn.pgconn.transaction_status == pq.TransactionStatus.PQTRANS_INTRANS
+ )
+ res = aconn.pgconn.exec_(b"insert into foo values (1)")
+ loop.run_until_complete(aconn.rollback())
+ assert aconn.pgconn.transaction_status == pq.TransactionStatus.PQTRANS_IDLE
+ res = aconn.pgconn.exec_(b"select id from foo where id = 1")
+ assert res.get_value(0, 0) is None
def test_connect_bad():
with pytest.raises(psycopg3.OperationalError):
Connection.connect("dbname=nosuchdb")
+
+
+def test_commit(pq, conn):
+ conn.pgconn.exec_(b"drop table if exists foo")
+ conn.pgconn.exec_(b"create table foo (id int primary key)")
+ conn.pgconn.exec_(b"begin")
+ assert (
+ conn.pgconn.transaction_status == pq.TransactionStatus.PQTRANS_INTRANS
+ )
+ res = conn.pgconn.exec_(b"insert into foo values (1)")
+ conn.commit()
+ assert conn.pgconn.transaction_status == pq.TransactionStatus.PQTRANS_IDLE
+ res = conn.pgconn.exec_(b"select id from foo where id = 1")
+ assert res.get_value(0, 0) == b"1"
+
+
+def test_rollback(pq, conn):
+ conn.pgconn.exec_(b"drop table if exists foo")
+ conn.pgconn.exec_(b"create table foo (id int primary key)")
+ conn.pgconn.exec_(b"begin")
+ assert (
+ conn.pgconn.transaction_status == pq.TransactionStatus.PQTRANS_INTRANS
+ )
+ res = conn.pgconn.exec_(b"insert into foo values (1)")
+ conn.rollback()
+ assert conn.pgconn.transaction_status == pq.TransactionStatus.PQTRANS_IDLE
+ res = conn.pgconn.exec_(b"select id from foo where id = 1")
+ assert res.get_value(0, 0) is None