From: Daniele Varrazzo Date: Fri, 20 Mar 2020 13:40:19 +0000 (+1300) Subject: Added commit/rollback connection methods X-Git-Tag: 3.0.dev0~690 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=a27dad3937497902903649351969dc7ad6572406;p=thirdparty%2Fpsycopg.git Added commit/rollback connection methods --- diff --git a/psycopg3/connection.py b/psycopg3/connection.py index f96b77f3c..5a5c5d59b 100644 --- a/psycopg3/connection.py +++ b/psycopg3/connection.py @@ -5,11 +5,12 @@ psycopg3 connection objects # Copyright (C) 2020 The Psycopg Team import logging +from threading import Lock from . import pq from . import exceptions as exc from .conninfo import make_conninfo -from .waiting import wait_select, wait_async, WAIT_R, WAIT_W +from .waiting import wait_select, wait_async, Wait, Ready logger = logging.getLogger(__name__) @@ -24,11 +25,16 @@ class BaseConnection: def __init__(self, pgconn): self.pgconn = pgconn + self.lock = Lock() @classmethod def _connect_gen(cls, conninfo): """ - Generator yielding connection states and returning a done connection. + Generator to create a database connection using without blocking. + + Yield pairs (fileno, `Wait`) whenever an operation would block. The + generator can be restarted sending the appropriate `Ready` state when + the file descriptor is ready. """ conninfo = conninfo.encode("utf8") @@ -45,9 +51,9 @@ class BaseConnection: if status == pq.PollingStatus.PGRES_POLLING_OK: break elif status == pq.PollingStatus.PGRES_POLLING_READING: - yield conn.socket, WAIT_R + yield conn.socket, Wait.R elif status == pq.PollingStatus.PGRES_POLLING_WRITING: - yield conn.socket, WAIT_W + yield conn.socket, Wait.W elif status == pq.PollingStatus.PGRES_POLLING_FAILED: raise exc.OperationalError( f"connection failed: {pq.error_message(conn)}" @@ -58,6 +64,45 @@ class BaseConnection: conn.nonblocking = 1 return conn + @classmethod + def _exec_gen(cls, pgconn): + """ + Generator returning query results without blocking. + + The query must have already been sent using `pgconn.send_query()` or + similar. Flush the query and then return the result using nonblocking + functions. + + Yield pairs (fileno, `Wait`) whenever an operation would block. The + generator can be restarted sending the appropriate `Ready` state when + the file descriptor is ready. + + Return the list of results returned by the database (whether success + or error). + """ + results = [] + + while 1: + f = pgconn.flush() + if f == 0: + break + + ready = yield pgconn.socket, Wait.RW + if ready is Ready.R: + pgconn.consume_input() + continue + + while 1: + pgconn.consume_input() + if pgconn.is_busy(): + ready = yield pgconn.socket, Wait.R + res = pgconn.get_result() + if res is None: + break + results.append(res) + + return results + class Connection(BaseConnection): """ @@ -67,12 +112,34 @@ class Connection(BaseConnection): """ @classmethod - def connect(cls, conninfo, **kwargs): + def connect(cls, conninfo, connection_factory=None, **kwargs): + if connection_factory is not None: + raise NotImplementedError() conninfo = make_conninfo(conninfo, **kwargs) gen = cls._connect_gen(conninfo) pgconn = wait_select(gen) return cls(pgconn) + def commit(self): + self._exec_commit_rollback(b"commit") + + def rollback(self): + self._exec_commit_rollback(b"rollback") + + def _exec_commit_rollback(self, command): + with self.lock: + status = self.pgconn.transaction_status + if status == pq.TransactionStatus.PQTRANS_IDLE: + return + + self.pgconn.send_query(command) + (pgres,) = wait_select(self._exec_gen(self.pgconn)) + if pgres.status != pq.ExecStatus.PGRES_COMMAND_OK: + raise exc.OperationalError( + f"error on {command.decode('utf8')}:" + f" {pq.error_message(pgres)}" + ) + class AsyncConnection(BaseConnection): """ @@ -88,3 +155,23 @@ class AsyncConnection(BaseConnection): gen = cls._connect_gen(conninfo) pgconn = await wait_async(gen) return cls(pgconn) + + async def commit(self): + await self._exec_commit_rollback(b"commit") + + async def rollback(self): + await self._exec_commit_rollback(b"rollback") + + async def _exec_commit_rollback(self, command): + with self.lock: + status = self.pgconn.transaction_status + if status == pq.TransactionStatus.PQTRANS_IDLE: + return + + self.pgconn.send_query(command) + (pgres,) = await wait_async(self._exec_gen(self.pgconn)) + if pgres.status != pq.ExecStatus.PGRES_COMMAND_OK: + raise exc.OperationalError( + f"error on {command.decode('utf8')}:" + f" {pq.error_message(pgres)}" + ) diff --git a/psycopg3/pq/pq_ctypes.py b/psycopg3/pq/pq_ctypes.py index e6822796d..b2aaf3e96 100644 --- a/psycopg3/pq/pq_ctypes.py +++ b/psycopg3/pq/pq_ctypes.py @@ -84,8 +84,7 @@ class PGconn: impl.PQreset(self.pgconn_ptr) def reset_start(self): - rv = impl.PQresetStart(self.pgconn_ptr) - if rv == 0: + if not impl.PQresetStart(self.pgconn_ptr): raise PQerror("couldn't reset connection") def reset_poll(self): @@ -194,7 +193,8 @@ class PGconn: raise TypeError( "bytes expected, got %s instead" % type(command).__name__ ) - return impl.PQsendQuery(self.pgconn_ptr, command) + if not impl.PQsendQuery(self.pgconn_ptr, command): + raise PQerror(f"sending query failed: {error_message(self)}") def exec_params( self, diff --git a/psycopg3/waiting.py b/psycopg3/waiting.py index 3aca5ba90..5023a36b8 100644 --- a/psycopg3/waiting.py +++ b/psycopg3/waiting.py @@ -5,6 +5,7 @@ Code concerned with waiting in different contexts (blocking, async, etc). # Copyright (C) 2020 The Psycopg Team +from enum import Enum from select import select from asyncio import get_event_loop from asyncio.queues import Queue @@ -12,11 +13,8 @@ from asyncio.queues import Queue from . import exceptions as exc -WAIT_R = "WAIT_R" -WAIT_W = "WAIT_W" -WAIT_RW = "WAIT_RW" -READY_R = "READY_R" -READY_W = "READY_W" +Wait = Enum("Wait", "R W RW") +Ready = Enum("Ready", "R W") def wait_select(gen): @@ -32,22 +30,22 @@ def wait_select(gen): try: while 1: fd, s = next(gen) - if s == WAIT_R: + if s is Wait.R: rf, wf, xf = select([fd], [], []) assert rf - gen.send(READY_R) - elif s == WAIT_W: + gen.send(Ready.R) + elif s is Wait.W: rf, wf, xf = select([], [fd], []) assert wf - gen.send(READY_W) - elif s == WAIT_RW: + gen.send(Ready.W) + elif s is Wait.RW: rf, wf, xf = select([fd], [fd], []) assert rf or wf assert not (rf and wf) if rf: - gen.send(READY_R) + gen.send(Ready.R) else: - gen.send(READY_W) + gen.send(Ready.W) else: raise exc.InternalError("bad poll status: %s") except StopIteration as e: @@ -71,19 +69,19 @@ async def wait_async(gen): try: while 1: fd, s = next(gen) - if s == WAIT_R: - loop.add_reader(fd, q.put_nowait, READY_R) + if s is Wait.R: + loop.add_reader(fd, q.put_nowait, Ready.R) ready = await q.get() loop.remove_reader(fd) gen.send(ready) - elif s == WAIT_W: - loop.add_writer(fd, q.put_nowait, READY_W) + elif s is Wait.W: + loop.add_writer(fd, q.put_nowait, Ready.W) ready = await q.get() loop.remove_writer(fd) gen.send(ready) - elif s == WAIT_RW: - loop.add_reader(fd, q.put_nowait, READY_R) - loop.add_writer(fd, q.put_nowait, READY_W) + elif s is Wait.RW: + loop.add_reader(fd, q.put_nowait, Ready.R) + loop.add_writer(fd, q.put_nowait, Ready.W) ready = await q.get() loop.remove_reader(fd) loop.remove_writer(fd) diff --git a/tests/fix_async.py b/tests/fix_async.py index 3a29d1d47..fb7c0ac0c 100644 --- a/tests/fix_async.py +++ b/tests/fix_async.py @@ -7,3 +7,12 @@ import pytest def loop(): """Return the async loop to test coroutines.""" return asyncio.get_event_loop() + + +@pytest.fixture +def aconn(loop, dsn): + """Return an `AsyncConnection` connected to the ``--test-dsn`` database.""" + from psycopg3 import AsyncConnection + + conn = loop.run_until_complete(AsyncConnection.connect(dsn)) + return conn diff --git a/tests/fix_db.py b/tests/fix_db.py index 481dad724..80d4f419c 100644 --- a/tests/fix_db.py +++ b/tests/fix_db.py @@ -33,3 +33,11 @@ def dsn(request): def pgconn(pq, dsn): """Return a PGconn connection open to `--test-dsn`.""" return pq.PGconn.connect(dsn.encode("utf8")) + + +@pytest.fixture +def conn(dsn): + """Return a `Connection` connected to the ``--test-dsn`` database.""" + from psycopg3 import Connection + + return Connection.connect(dsn) diff --git a/tests/test_async_connection.py b/tests/test_async_connection.py index ba42440f8..992760d08 100644 --- a/tests/test_async_connection.py +++ b/tests/test_async_connection.py @@ -12,3 +12,31 @@ def test_connect(pq, dsn, loop): def test_connect_bad(loop): with pytest.raises(psycopg3.OperationalError): loop.run_until_complete(AsyncConnection.connect("dbname=nosuchdb")) + + +def test_commit(loop, pq, aconn): + aconn.pgconn.exec_(b"drop table if exists foo") + aconn.pgconn.exec_(b"create table foo (id int primary key)") + aconn.pgconn.exec_(b"begin") + assert ( + aconn.pgconn.transaction_status == pq.TransactionStatus.PQTRANS_INTRANS + ) + res = aconn.pgconn.exec_(b"insert into foo values (1)") + loop.run_until_complete(aconn.commit()) + assert aconn.pgconn.transaction_status == pq.TransactionStatus.PQTRANS_IDLE + res = aconn.pgconn.exec_(b"select id from foo where id = 1") + assert res.get_value(0, 0) == b"1" + + +def test_rollback(loop, pq, aconn): + aconn.pgconn.exec_(b"drop table if exists foo") + aconn.pgconn.exec_(b"create table foo (id int primary key)") + aconn.pgconn.exec_(b"begin") + assert ( + aconn.pgconn.transaction_status == pq.TransactionStatus.PQTRANS_INTRANS + ) + res = aconn.pgconn.exec_(b"insert into foo values (1)") + loop.run_until_complete(aconn.rollback()) + assert aconn.pgconn.transaction_status == pq.TransactionStatus.PQTRANS_IDLE + res = aconn.pgconn.exec_(b"select id from foo where id = 1") + assert res.get_value(0, 0) is None diff --git a/tests/test_connection.py b/tests/test_connection.py index 74706eafb..69e0532c6 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -12,3 +12,31 @@ def test_connect(pq, dsn): def test_connect_bad(): with pytest.raises(psycopg3.OperationalError): Connection.connect("dbname=nosuchdb") + + +def test_commit(pq, conn): + conn.pgconn.exec_(b"drop table if exists foo") + conn.pgconn.exec_(b"create table foo (id int primary key)") + conn.pgconn.exec_(b"begin") + assert ( + conn.pgconn.transaction_status == pq.TransactionStatus.PQTRANS_INTRANS + ) + res = conn.pgconn.exec_(b"insert into foo values (1)") + conn.commit() + assert conn.pgconn.transaction_status == pq.TransactionStatus.PQTRANS_IDLE + res = conn.pgconn.exec_(b"select id from foo where id = 1") + assert res.get_value(0, 0) == b"1" + + +def test_rollback(pq, conn): + conn.pgconn.exec_(b"drop table if exists foo") + conn.pgconn.exec_(b"create table foo (id int primary key)") + conn.pgconn.exec_(b"begin") + assert ( + conn.pgconn.transaction_status == pq.TransactionStatus.PQTRANS_INTRANS + ) + res = conn.pgconn.exec_(b"insert into foo values (1)") + conn.rollback() + assert conn.pgconn.transaction_status == pq.TransactionStatus.PQTRANS_IDLE + res = conn.pgconn.exec_(b"select id from foo where id = 1") + assert res.get_value(0, 0) is None