]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added commit/rollback connection methods
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 20 Mar 2020 13:40:19 +0000 (02:40 +1300)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 20 Mar 2020 13:40:19 +0000 (02:40 +1300)
psycopg3/connection.py
psycopg3/pq/pq_ctypes.py
psycopg3/waiting.py
tests/fix_async.py
tests/fix_db.py
tests/test_async_connection.py
tests/test_connection.py

index f96b77f3cc2e0128811c4309cd3b038c39d62777..5a5c5d59b9ec5ca7a851d97849daedb95579a9ee 100644 (file)
@@ -5,11 +5,12 @@ psycopg3 connection objects
 # Copyright (C) 2020 The Psycopg Team
 
 import logging
+from threading import Lock
 
 from . import pq
 from . import exceptions as exc
 from .conninfo import make_conninfo
-from .waiting import wait_select, wait_async, WAIT_R, WAIT_W
+from .waiting import wait_select, wait_async, Wait, Ready
 
 logger = logging.getLogger(__name__)
 
@@ -24,11 +25,16 @@ class BaseConnection:
 
     def __init__(self, pgconn):
         self.pgconn = pgconn
+        self.lock = Lock()
 
     @classmethod
     def _connect_gen(cls, conninfo):
         """
-        Generator yielding connection states and returning a done connection.
+        Generator to create a database connection using without blocking.
+
+        Yield pairs (fileno, `Wait`) whenever an operation would block. The
+        generator can be restarted sending the appropriate `Ready` state when
+        the file descriptor is ready.
         """
         conninfo = conninfo.encode("utf8")
 
@@ -45,9 +51,9 @@ class BaseConnection:
             if status == pq.PollingStatus.PGRES_POLLING_OK:
                 break
             elif status == pq.PollingStatus.PGRES_POLLING_READING:
-                yield conn.socket, WAIT_R
+                yield conn.socket, Wait.R
             elif status == pq.PollingStatus.PGRES_POLLING_WRITING:
-                yield conn.socket, WAIT_W
+                yield conn.socket, Wait.W
             elif status == pq.PollingStatus.PGRES_POLLING_FAILED:
                 raise exc.OperationalError(
                     f"connection failed: {pq.error_message(conn)}"
@@ -58,6 +64,45 @@ class BaseConnection:
         conn.nonblocking = 1
         return conn
 
+    @classmethod
+    def _exec_gen(cls, pgconn):
+        """
+        Generator returning query results without blocking.
+
+        The query must have already been sent using `pgconn.send_query()` or
+        similar. Flush the query and then return the result using nonblocking
+        functions.
+
+        Yield pairs (fileno, `Wait`) whenever an operation would block. The
+        generator can be restarted sending the appropriate `Ready` state when
+        the file descriptor is ready.
+
+        Return the list of results returned by the database (whether success
+        or error).
+        """
+        results = []
+
+        while 1:
+            f = pgconn.flush()
+            if f == 0:
+                break
+
+            ready = yield pgconn.socket, Wait.RW
+            if ready is Ready.R:
+                pgconn.consume_input()
+            continue
+
+        while 1:
+            pgconn.consume_input()
+            if pgconn.is_busy():
+                ready = yield pgconn.socket, Wait.R
+            res = pgconn.get_result()
+            if res is None:
+                break
+            results.append(res)
+
+        return results
+
 
 class Connection(BaseConnection):
     """
@@ -67,12 +112,34 @@ class Connection(BaseConnection):
     """
 
     @classmethod
-    def connect(cls, conninfo, **kwargs):
+    def connect(cls, conninfo, connection_factory=None, **kwargs):
+        if connection_factory is not None:
+            raise NotImplementedError()
         conninfo = make_conninfo(conninfo, **kwargs)
         gen = cls._connect_gen(conninfo)
         pgconn = wait_select(gen)
         return cls(pgconn)
 
+    def commit(self):
+        self._exec_commit_rollback(b"commit")
+
+    def rollback(self):
+        self._exec_commit_rollback(b"rollback")
+
+    def _exec_commit_rollback(self, command):
+        with self.lock:
+            status = self.pgconn.transaction_status
+            if status == pq.TransactionStatus.PQTRANS_IDLE:
+                return
+
+            self.pgconn.send_query(command)
+            (pgres,) = wait_select(self._exec_gen(self.pgconn))
+            if pgres.status != pq.ExecStatus.PGRES_COMMAND_OK:
+                raise exc.OperationalError(
+                    f"error on {command.decode('utf8')}:"
+                    f" {pq.error_message(pgres)}"
+                )
+
 
 class AsyncConnection(BaseConnection):
     """
@@ -88,3 +155,23 @@ class AsyncConnection(BaseConnection):
         gen = cls._connect_gen(conninfo)
         pgconn = await wait_async(gen)
         return cls(pgconn)
+
+    async def commit(self):
+        await self._exec_commit_rollback(b"commit")
+
+    async def rollback(self):
+        await self._exec_commit_rollback(b"rollback")
+
+    async def _exec_commit_rollback(self, command):
+        with self.lock:
+            status = self.pgconn.transaction_status
+            if status == pq.TransactionStatus.PQTRANS_IDLE:
+                return
+
+            self.pgconn.send_query(command)
+            (pgres,) = await wait_async(self._exec_gen(self.pgconn))
+            if pgres.status != pq.ExecStatus.PGRES_COMMAND_OK:
+                raise exc.OperationalError(
+                    f"error on {command.decode('utf8')}:"
+                    f" {pq.error_message(pgres)}"
+                )
index e6822796dd3bffd59a80d649d3bc16482580feaa..b2aaf3e961a83e7fe476253e971280774b0f81b4 100644 (file)
@@ -84,8 +84,7 @@ class PGconn:
         impl.PQreset(self.pgconn_ptr)
 
     def reset_start(self):
-        rv = impl.PQresetStart(self.pgconn_ptr)
-        if rv == 0:
+        if not impl.PQresetStart(self.pgconn_ptr):
             raise PQerror("couldn't reset connection")
 
     def reset_poll(self):
@@ -194,7 +193,8 @@ class PGconn:
             raise TypeError(
                 "bytes expected, got %s instead" % type(command).__name__
             )
-        return impl.PQsendQuery(self.pgconn_ptr, command)
+        if not impl.PQsendQuery(self.pgconn_ptr, command):
+            raise PQerror(f"sending query failed: {error_message(self)}")
 
     def exec_params(
         self,
index 3aca5ba903442c5d742750b09b16cc2aa98e641a..5023a36b86e0212d10ff3b749f7426133e537a05 100644 (file)
@@ -5,6 +5,7 @@ Code concerned with waiting in different contexts (blocking, async, etc).
 # Copyright (C) 2020 The Psycopg Team
 
 
+from enum import Enum
 from select import select
 from asyncio import get_event_loop
 from asyncio.queues import Queue
@@ -12,11 +13,8 @@ from asyncio.queues import Queue
 from . import exceptions as exc
 
 
-WAIT_R = "WAIT_R"
-WAIT_W = "WAIT_W"
-WAIT_RW = "WAIT_RW"
-READY_R = "READY_R"
-READY_W = "READY_W"
+Wait = Enum("Wait", "R W RW")
+Ready = Enum("Ready", "R W")
 
 
 def wait_select(gen):
@@ -32,22 +30,22 @@ def wait_select(gen):
     try:
         while 1:
             fd, s = next(gen)
-            if s == WAIT_R:
+            if s is Wait.R:
                 rf, wf, xf = select([fd], [], [])
                 assert rf
-                gen.send(READY_R)
-            elif s == WAIT_W:
+                gen.send(Ready.R)
+            elif s is Wait.W:
                 rf, wf, xf = select([], [fd], [])
                 assert wf
-                gen.send(READY_W)
-            elif s == WAIT_RW:
+                gen.send(Ready.W)
+            elif s is Wait.RW:
                 rf, wf, xf = select([fd], [fd], [])
                 assert rf or wf
                 assert not (rf and wf)
                 if rf:
-                    gen.send(READY_R)
+                    gen.send(Ready.R)
                 else:
-                    gen.send(READY_W)
+                    gen.send(Ready.W)
             else:
                 raise exc.InternalError("bad poll status: %s")
     except StopIteration as e:
@@ -71,19 +69,19 @@ async def wait_async(gen):
     try:
         while 1:
             fd, s = next(gen)
-            if s == WAIT_R:
-                loop.add_reader(fd, q.put_nowait, READY_R)
+            if s is Wait.R:
+                loop.add_reader(fd, q.put_nowait, Ready.R)
                 ready = await q.get()
                 loop.remove_reader(fd)
                 gen.send(ready)
-            elif s == WAIT_W:
-                loop.add_writer(fd, q.put_nowait, READY_W)
+            elif s is Wait.W:
+                loop.add_writer(fd, q.put_nowait, Ready.W)
                 ready = await q.get()
                 loop.remove_writer(fd)
                 gen.send(ready)
-            elif s == WAIT_RW:
-                loop.add_reader(fd, q.put_nowait, READY_R)
-                loop.add_writer(fd, q.put_nowait, READY_W)
+            elif s is Wait.RW:
+                loop.add_reader(fd, q.put_nowait, Ready.R)
+                loop.add_writer(fd, q.put_nowait, Ready.W)
                 ready = await q.get()
                 loop.remove_reader(fd)
                 loop.remove_writer(fd)
index 3a29d1d4777daab6a4e0157d499f1218a52fd968..fb7c0ac0c93e649d370e3b64cddfbb4a5223e6ee 100644 (file)
@@ -7,3 +7,12 @@ import pytest
 def loop():
     """Return the async loop to test coroutines."""
     return asyncio.get_event_loop()
+
+
+@pytest.fixture
+def aconn(loop, dsn):
+    """Return an `AsyncConnection` connected to the ``--test-dsn`` database."""
+    from psycopg3 import AsyncConnection
+
+    conn = loop.run_until_complete(AsyncConnection.connect(dsn))
+    return conn
index 481dad724d637d6fca7d03b5477d77c07d6d3fdf..80d4f419c4c577b171750691d789eebfa8b1ea52 100644 (file)
@@ -33,3 +33,11 @@ def dsn(request):
 def pgconn(pq, dsn):
     """Return a PGconn connection open to `--test-dsn`."""
     return pq.PGconn.connect(dsn.encode("utf8"))
+
+
+@pytest.fixture
+def conn(dsn):
+    """Return a `Connection` connected to the ``--test-dsn`` database."""
+    from psycopg3 import Connection
+
+    return Connection.connect(dsn)
index ba42440f83f102c8868d7742053290e148091b10..992760d089eddf57b16d24de396e0f63c2dcfc49 100644 (file)
@@ -12,3 +12,31 @@ def test_connect(pq, dsn, loop):
 def test_connect_bad(loop):
     with pytest.raises(psycopg3.OperationalError):
         loop.run_until_complete(AsyncConnection.connect("dbname=nosuchdb"))
+
+
+def test_commit(loop, pq, aconn):
+    aconn.pgconn.exec_(b"drop table if exists foo")
+    aconn.pgconn.exec_(b"create table foo (id int primary key)")
+    aconn.pgconn.exec_(b"begin")
+    assert (
+        aconn.pgconn.transaction_status == pq.TransactionStatus.PQTRANS_INTRANS
+    )
+    res = aconn.pgconn.exec_(b"insert into foo values (1)")
+    loop.run_until_complete(aconn.commit())
+    assert aconn.pgconn.transaction_status == pq.TransactionStatus.PQTRANS_IDLE
+    res = aconn.pgconn.exec_(b"select id from foo where id = 1")
+    assert res.get_value(0, 0) == b"1"
+
+
+def test_rollback(loop, pq, aconn):
+    aconn.pgconn.exec_(b"drop table if exists foo")
+    aconn.pgconn.exec_(b"create table foo (id int primary key)")
+    aconn.pgconn.exec_(b"begin")
+    assert (
+        aconn.pgconn.transaction_status == pq.TransactionStatus.PQTRANS_INTRANS
+    )
+    res = aconn.pgconn.exec_(b"insert into foo values (1)")
+    loop.run_until_complete(aconn.rollback())
+    assert aconn.pgconn.transaction_status == pq.TransactionStatus.PQTRANS_IDLE
+    res = aconn.pgconn.exec_(b"select id from foo where id = 1")
+    assert res.get_value(0, 0) is None
index 74706eafb1a256024f6a7fa761f370ef716f7c70..69e0532c6c5eeacc545d34248d8cebe595cb983c 100644 (file)
@@ -12,3 +12,31 @@ def test_connect(pq, dsn):
 def test_connect_bad():
     with pytest.raises(psycopg3.OperationalError):
         Connection.connect("dbname=nosuchdb")
+
+
+def test_commit(pq, conn):
+    conn.pgconn.exec_(b"drop table if exists foo")
+    conn.pgconn.exec_(b"create table foo (id int primary key)")
+    conn.pgconn.exec_(b"begin")
+    assert (
+        conn.pgconn.transaction_status == pq.TransactionStatus.PQTRANS_INTRANS
+    )
+    res = conn.pgconn.exec_(b"insert into foo values (1)")
+    conn.commit()
+    assert conn.pgconn.transaction_status == pq.TransactionStatus.PQTRANS_IDLE
+    res = conn.pgconn.exec_(b"select id from foo where id = 1")
+    assert res.get_value(0, 0) == b"1"
+
+
+def test_rollback(pq, conn):
+    conn.pgconn.exec_(b"drop table if exists foo")
+    conn.pgconn.exec_(b"create table foo (id int primary key)")
+    conn.pgconn.exec_(b"begin")
+    assert (
+        conn.pgconn.transaction_status == pq.TransactionStatus.PQTRANS_INTRANS
+    )
+    res = conn.pgconn.exec_(b"insert into foo values (1)")
+    conn.rollback()
+    assert conn.pgconn.transaction_status == pq.TransactionStatus.PQTRANS_IDLE
+    res = conn.pgconn.exec_(b"select id from foo where id = 1")
+    assert res.get_value(0, 0) is None