]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added automatic transaction start
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 20 May 2020 05:22:56 +0000 (17:22 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 20 May 2020 05:22:56 +0000 (17:22 +1200)
psycopg3/connection.py
psycopg3/cursor.py
tests/fix_async.py
tests/fix_db.py
tests/test_async_connection.py
tests/test_async_cursor.py
tests/test_connection.py
tests/test_cursor.py
tests/types/test_composite.py

index bcf45dd153734ab8eca99665754809a787bc3d6e..522fa1ce362a38a906ee6684eb73ebc61437abbb 100644 (file)
@@ -14,6 +14,7 @@ from . import pq
 from . import errors as e
 from . import cursor
 from . import proto
+from .pq import TransactionStatus, ExecStatus
 from .conninfo import make_conninfo
 from .waiting import wait, wait_async
 
@@ -153,6 +154,19 @@ class Connection(BaseConnection):
         cur = super().cursor(name, binary)
         return cast(cursor.Cursor, cur)
 
+    def _start_query(self) -> None:
+        # the function is meant to be called by a cursor once the lock is taken
+        status = self.pgconn.transaction_status
+        if status == TransactionStatus.INTRANS:
+            return
+
+        self.pgconn.send_query(b"begin")
+        (pgres,) = self.wait(execute(self.pgconn))
+        if pgres.status != ExecStatus.COMMAND_OK:
+            raise e.OperationalError(
+                f"error on begin: {pq.error_message(pgres)}"
+            )
+
     def commit(self) -> None:
         self._exec_commit_rollback(b"commit")
 
@@ -162,12 +176,12 @@ class Connection(BaseConnection):
     def _exec_commit_rollback(self, command: bytes) -> None:
         with self.lock:
             status = self.pgconn.transaction_status
-            if status == pq.TransactionStatus.IDLE:
+            if status == TransactionStatus.IDLE:
                 return
 
             self.pgconn.send_query(command)
             (pgres,) = self.wait(execute(self.pgconn))
-            if pgres.status != pq.ExecStatus.COMMAND_OK:
+            if pgres.status != ExecStatus.COMMAND_OK:
                 raise e.OperationalError(
                     f"error on {command.decode('utf8')}:"
                     f" {pq.error_message(pgres)}"
@@ -187,7 +201,7 @@ class Connection(BaseConnection):
             )
             gen = execute(self.pgconn)
             (result,) = self.wait(gen)
-            if result.status != pq.ExecStatus.TUPLES_OK:
+            if result.status != ExecStatus.TUPLES_OK:
                 raise e.error_from_result(result)
 
 
@@ -226,6 +240,19 @@ class AsyncConnection(BaseConnection):
         cur = super().cursor(name, binary)
         return cast(cursor.AsyncCursor, cur)
 
+    async def _start_query(self) -> None:
+        # the function is meant to be called by a cursor once the lock is taken
+        status = self.pgconn.transaction_status
+        if status == TransactionStatus.INTRANS:
+            return
+
+        self.pgconn.send_query(b"begin")
+        (pgres,) = await self.wait(execute(self.pgconn))
+        if pgres.status != ExecStatus.COMMAND_OK:
+            raise e.OperationalError(
+                f"error on begin: {pq.error_message(pgres)}"
+            )
+
     async def commit(self) -> None:
         await self._exec_commit_rollback(b"commit")
 
@@ -235,12 +262,12 @@ class AsyncConnection(BaseConnection):
     async def _exec_commit_rollback(self, command: bytes) -> None:
         async with self.lock:
             status = self.pgconn.transaction_status
-            if status == pq.TransactionStatus.IDLE:
+            if status == TransactionStatus.IDLE:
                 return
 
             self.pgconn.send_query(command)
             (pgres,) = await self.wait(execute(self.pgconn))
-            if pgres.status != pq.ExecStatus.COMMAND_OK:
+            if pgres.status != ExecStatus.COMMAND_OK:
                 raise e.OperationalError(
                     f"error on {command.decode('utf8')}:"
                     f" {pq.error_message(pgres)}"
@@ -258,5 +285,5 @@ class AsyncConnection(BaseConnection):
             )
             gen = execute(self.pgconn)
             (result,) = await self.wait(gen)
-            if result.status != pq.ExecStatus.TUPLES_OK:
+            if result.status != ExecStatus.TUPLES_OK:
                 raise e.error_from_result(result)
index 3582bccca640d53f54703cb8c0bab14f41a09d87..867515e39b71183fc5b9ecb5e6e91139146cba67 100644 (file)
@@ -263,6 +263,7 @@ class Cursor(BaseCursor):
     def execute(self, query: Query, vars: Optional[Params] = None) -> "Cursor":
         with self.connection.lock:
             self._start_query()
+            self.connection._start_query()
             self._execute_send(query, vars)
             gen = execute(self.connection.pgconn)
             results = self.connection.wait(gen)
@@ -274,8 +275,10 @@ class Cursor(BaseCursor):
     ) -> "Cursor":
         with self.connection.lock:
             self._start_query()
-            for i, vars in enumerate(vars_seq):
-                if i == 0:
+            self.connection._start_query()
+            first = True
+            for vars in vars_seq:
+                if first:
                     pgq = self._send_prepare(b"", query, vars)
                     gen = execute(self.connection.pgconn)
                     (result,) = self.connection.wait(gen)
@@ -350,6 +353,7 @@ class AsyncCursor(BaseCursor):
     ) -> "AsyncCursor":
         async with self.connection.lock:
             self._start_query()
+            await self.connection._start_query()
             self._execute_send(query, vars)
             gen = execute(self.connection.pgconn)
             results = await self.connection.wait(gen)
@@ -361,8 +365,10 @@ class AsyncCursor(BaseCursor):
     ) -> "AsyncCursor":
         async with self.connection.lock:
             self._start_query()
-            for i, vars in enumerate(vars_seq):
-                if i == 0:
+            await self.connection._start_query()
+            first = True
+            for vars in vars_seq:
+                if first:
                     pgq = self._send_prepare(b"", query, vars)
                     gen = execute(self.connection.pgconn)
                     (result,) = await self.connection.wait(gen)
index 5bd4c0b700f29dedbea022f176c8cba8ad204954..9b119a3babe14d1dec13020684a049b56874a498 100644 (file)
@@ -15,4 +15,5 @@ def aconn(loop, dsn, pq):
     from psycopg3 import AsyncConnection
 
     conn = loop.run_until_complete(AsyncConnection.connect(dsn))
-    return conn
+    yield conn
+    loop.run_until_complete(conn.close())
index 8ac7cf3c868dacfc2258b21d75877d370b64bcde..6f0f06449a7f27e0871a9749a2008f357d258d65 100644 (file)
@@ -30,7 +30,8 @@ def pgconn(pq, dsn):
         pytest.fail(
             f"bad connection: {conn.error_message.decode('utf8', 'replace')}"
         )
-    return conn
+    yield conn
+    conn.finish()
 
 
 @pytest.fixture
@@ -38,7 +39,9 @@ def conn(dsn):
     """Return a `Connection` connected to the ``--test-dsn`` database."""
     from psycopg3 import Connection
 
-    return Connection.connect(dsn)
+    conn = Connection.connect(dsn)
+    yield conn
+    conn.close()
 
 
 @pytest.fixture(scope="session")
@@ -48,4 +51,6 @@ def svcconn(dsn):
     """
     from psycopg3 import Connection
 
-    return Connection.connect(dsn)
+    conn = Connection.connect(dsn)
+    yield conn
+    conn.close()
index f5a8875e90bdc4857611bc9e34e45ae046de244f..8192153b9b6599d5ef73aad2ddb24804c2fc9a26 100644 (file)
@@ -30,7 +30,7 @@ def test_commit(loop, aconn):
     aconn.pgconn.exec_(b"create table foo (id int primary key)")
     aconn.pgconn.exec_(b"begin")
     assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
-    res = aconn.pgconn.exec_(b"insert into foo values (1)")
+    aconn.pgconn.exec_(b"insert into foo values (1)")
     loop.run_until_complete(aconn.commit())
     assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE
     res = aconn.pgconn.exec_(b"select id from foo where id = 1")
@@ -46,17 +46,55 @@ def test_rollback(loop, aconn):
     aconn.pgconn.exec_(b"create table foo (id int primary key)")
     aconn.pgconn.exec_(b"begin")
     assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
-    res = aconn.pgconn.exec_(b"insert into foo values (1)")
+    aconn.pgconn.exec_(b"insert into foo values (1)")
     loop.run_until_complete(aconn.rollback())
     assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE
     res = aconn.pgconn.exec_(b"select id from foo where id = 1")
-    assert res.get_value(0, 0) is None
+    assert res.ntuples == 0
 
     loop.run_until_complete(aconn.close())
     with pytest.raises(psycopg3.OperationalError):
         loop.run_until_complete(aconn.rollback())
 
 
+def test_auto_transaction(loop, aconn):
+    aconn.pgconn.exec_(b"drop table if exists foo")
+    aconn.pgconn.exec_(b"create table foo (id int primary key)")
+
+    cur = aconn.cursor()
+    assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE
+
+    loop.run_until_complete(cur.execute("insert into foo values (1)"))
+    assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
+
+    loop.run_until_complete(aconn.commit())
+    assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE
+    loop.run_until_complete(cur.execute("select * from foo"))
+    assert loop.run_until_complete(cur.fetchone()) == (1,)
+    assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
+
+
+def test_auto_transaction_fail(loop, aconn):
+    aconn.pgconn.exec_(b"drop table if exists foo")
+    aconn.pgconn.exec_(b"create table foo (id int primary key)")
+
+    cur = aconn.cursor()
+    assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE
+
+    loop.run_until_complete(cur.execute("insert into foo values (1)"))
+    assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
+
+    with pytest.raises(psycopg3.DatabaseError):
+        loop.run_until_complete(cur.execute("meh"))
+    assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR
+
+    loop.run_until_complete(aconn.commit())
+    assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE
+    loop.run_until_complete(cur.execute("select * from foo"))
+    assert loop.run_until_complete(cur.fetchone()) is None
+    assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
+
+
 def test_get_encoding(aconn, loop):
     cur = aconn.cursor()
     loop.run_until_complete(cur.execute("show client_encoding"))
index aa749ee4b055316eef399e25123d3c317b00ec51..5c428853aeef79fe569ce55f99ff3826ffa61b5f 100644 (file)
@@ -114,12 +114,14 @@ def _execmany(svcconn):
         create table execmany (id serial primary key, num integer, data text)
         """
     )
+    svcconn.commit()
 
 
 @pytest.fixture(scope="function")
 def execmany(svcconn, _execmany):
     cur = svcconn.cursor()
     cur.execute("truncate table execmany")
+    svcconn.commit()
 
 
 def test_executemany(aconn, loop, execmany):
index b5643ce705bf7076c1035e0c0cd586e8ce0d39da..cfa0d9192db1c1df20819c35587020f16fb2ddec 100644 (file)
@@ -30,7 +30,7 @@ def test_commit(conn):
     conn.pgconn.exec_(b"create table foo (id int primary key)")
     conn.pgconn.exec_(b"begin")
     assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS
-    res = conn.pgconn.exec_(b"insert into foo values (1)")
+    conn.pgconn.exec_(b"insert into foo values (1)")
     conn.commit()
     assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE
     res = conn.pgconn.exec_(b"select id from foo where id = 1")
@@ -46,17 +46,53 @@ def test_rollback(conn):
     conn.pgconn.exec_(b"create table foo (id int primary key)")
     conn.pgconn.exec_(b"begin")
     assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS
-    res = conn.pgconn.exec_(b"insert into foo values (1)")
+    conn.pgconn.exec_(b"insert into foo values (1)")
     conn.rollback()
     assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE
     res = conn.pgconn.exec_(b"select id from foo where id = 1")
-    assert res.get_value(0, 0) is None
+    assert res.ntuples == 0
 
     conn.close()
     with pytest.raises(psycopg3.OperationalError):
         conn.rollback()
 
 
+def test_auto_transaction(conn):
+    conn.pgconn.exec_(b"drop table if exists foo")
+    conn.pgconn.exec_(b"create table foo (id int primary key)")
+
+    cur = conn.cursor()
+    assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE
+
+    cur.execute("insert into foo values (1)")
+    assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS
+
+    conn.commit()
+    assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE
+    assert cur.execute("select * from foo").fetchone() == (1,)
+    assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS
+
+
+def test_auto_transaction_fail(conn):
+    conn.pgconn.exec_(b"drop table if exists foo")
+    conn.pgconn.exec_(b"create table foo (id int primary key)")
+
+    cur = conn.cursor()
+    assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE
+
+    cur.execute("insert into foo values (1)")
+    assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS
+
+    with pytest.raises(psycopg3.DatabaseError):
+        cur.execute("meh")
+    assert conn.pgconn.transaction_status == conn.TransactionStatus.INERROR
+
+    conn.commit()
+    assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE
+    assert cur.execute("select * from foo").fetchone() is None
+    assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS
+
+
 def test_get_encoding(conn):
     (enc,) = conn.cursor().execute("show client_encoding").fetchone()
     assert enc == conn.encoding
index 86b589cbe6c9b8b28166faa0341507a08f1debc2..29db7007deb8c06f2216d581556ea5cce154b0c7 100644 (file)
@@ -111,12 +111,14 @@ def _execmany(svcconn):
         create table execmany (id serial primary key, num integer, data text)
         """
     )
+    svcconn.commit()
 
 
 @pytest.fixture(scope="function")
 def execmany(svcconn, _execmany):
     cur = svcconn.cursor()
     cur.execute("truncate table execmany")
+    svcconn.commit()
 
 
 def test_executemany(conn, execmany):
index 1f1e88352bffb85ea0022524be6118d0599a2d4b..f1277486b659353f8cb688d1f501d2f7c3fa9b78 100644 (file)
@@ -99,6 +99,7 @@ def testcomp(svcconn):
         create type testcomp as (foo text, bar int8, baz float8);
         """
     )
+    svcconn.commit()
 
 
 def test_fetch_info(conn, testcomp):