From d16294a186a8cc872480708fae56142938f9a231 Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Wed, 20 May 2020 17:22:56 +1200 Subject: [PATCH] Added automatic transaction start --- psycopg3/connection.py | 39 +++++++++++++++++++++++++----- psycopg3/cursor.py | 14 +++++++---- tests/fix_async.py | 3 ++- tests/fix_db.py | 11 ++++++--- tests/test_async_connection.py | 44 +++++++++++++++++++++++++++++++--- tests/test_async_cursor.py | 2 ++ tests/test_connection.py | 42 +++++++++++++++++++++++++++++--- tests/test_cursor.py | 2 ++ tests/types/test_composite.py | 1 + 9 files changed, 138 insertions(+), 20 deletions(-) diff --git a/psycopg3/connection.py b/psycopg3/connection.py index bcf45dd15..522fa1ce3 100644 --- a/psycopg3/connection.py +++ b/psycopg3/connection.py @@ -14,6 +14,7 @@ from . import pq from . import errors as e from . import cursor from . import proto +from .pq import TransactionStatus, ExecStatus from .conninfo import make_conninfo from .waiting import wait, wait_async @@ -153,6 +154,19 @@ class Connection(BaseConnection): cur = super().cursor(name, binary) return cast(cursor.Cursor, cur) + def _start_query(self) -> None: + # the function is meant to be called by a cursor once the lock is taken + status = self.pgconn.transaction_status + if status == TransactionStatus.INTRANS: + return + + self.pgconn.send_query(b"begin") + (pgres,) = self.wait(execute(self.pgconn)) + if pgres.status != ExecStatus.COMMAND_OK: + raise e.OperationalError( + f"error on begin: {pq.error_message(pgres)}" + ) + def commit(self) -> None: self._exec_commit_rollback(b"commit") @@ -162,12 +176,12 @@ class Connection(BaseConnection): def _exec_commit_rollback(self, command: bytes) -> None: with self.lock: status = self.pgconn.transaction_status - if status == pq.TransactionStatus.IDLE: + if status == TransactionStatus.IDLE: return self.pgconn.send_query(command) (pgres,) = self.wait(execute(self.pgconn)) - if pgres.status != pq.ExecStatus.COMMAND_OK: + if pgres.status != ExecStatus.COMMAND_OK: raise e.OperationalError( f"error on {command.decode('utf8')}:" f" {pq.error_message(pgres)}" @@ -187,7 +201,7 @@ class Connection(BaseConnection): ) gen = execute(self.pgconn) (result,) = self.wait(gen) - if result.status != pq.ExecStatus.TUPLES_OK: + if result.status != ExecStatus.TUPLES_OK: raise e.error_from_result(result) @@ -226,6 +240,19 @@ class AsyncConnection(BaseConnection): cur = super().cursor(name, binary) return cast(cursor.AsyncCursor, cur) + async def _start_query(self) -> None: + # the function is meant to be called by a cursor once the lock is taken + status = self.pgconn.transaction_status + if status == TransactionStatus.INTRANS: + return + + self.pgconn.send_query(b"begin") + (pgres,) = await self.wait(execute(self.pgconn)) + if pgres.status != ExecStatus.COMMAND_OK: + raise e.OperationalError( + f"error on begin: {pq.error_message(pgres)}" + ) + async def commit(self) -> None: await self._exec_commit_rollback(b"commit") @@ -235,12 +262,12 @@ class AsyncConnection(BaseConnection): async def _exec_commit_rollback(self, command: bytes) -> None: async with self.lock: status = self.pgconn.transaction_status - if status == pq.TransactionStatus.IDLE: + if status == TransactionStatus.IDLE: return self.pgconn.send_query(command) (pgres,) = await self.wait(execute(self.pgconn)) - if pgres.status != pq.ExecStatus.COMMAND_OK: + if pgres.status != ExecStatus.COMMAND_OK: raise e.OperationalError( f"error on {command.decode('utf8')}:" f" {pq.error_message(pgres)}" @@ -258,5 +285,5 @@ class AsyncConnection(BaseConnection): ) gen = execute(self.pgconn) (result,) = await self.wait(gen) - if result.status != pq.ExecStatus.TUPLES_OK: + if result.status != ExecStatus.TUPLES_OK: raise e.error_from_result(result) diff --git a/psycopg3/cursor.py b/psycopg3/cursor.py index 3582bccca..867515e39 100644 --- a/psycopg3/cursor.py +++ b/psycopg3/cursor.py @@ -263,6 +263,7 @@ class Cursor(BaseCursor): def execute(self, query: Query, vars: Optional[Params] = None) -> "Cursor": with self.connection.lock: self._start_query() + self.connection._start_query() self._execute_send(query, vars) gen = execute(self.connection.pgconn) results = self.connection.wait(gen) @@ -274,8 +275,10 @@ class Cursor(BaseCursor): ) -> "Cursor": with self.connection.lock: self._start_query() - for i, vars in enumerate(vars_seq): - if i == 0: + self.connection._start_query() + first = True + for vars in vars_seq: + if first: pgq = self._send_prepare(b"", query, vars) gen = execute(self.connection.pgconn) (result,) = self.connection.wait(gen) @@ -350,6 +353,7 @@ class AsyncCursor(BaseCursor): ) -> "AsyncCursor": async with self.connection.lock: self._start_query() + await self.connection._start_query() self._execute_send(query, vars) gen = execute(self.connection.pgconn) results = await self.connection.wait(gen) @@ -361,8 +365,10 @@ class AsyncCursor(BaseCursor): ) -> "AsyncCursor": async with self.connection.lock: self._start_query() - for i, vars in enumerate(vars_seq): - if i == 0: + await self.connection._start_query() + first = True + for vars in vars_seq: + if first: pgq = self._send_prepare(b"", query, vars) gen = execute(self.connection.pgconn) (result,) = await self.connection.wait(gen) diff --git a/tests/fix_async.py b/tests/fix_async.py index 5bd4c0b70..9b119a3ba 100644 --- a/tests/fix_async.py +++ b/tests/fix_async.py @@ -15,4 +15,5 @@ def aconn(loop, dsn, pq): from psycopg3 import AsyncConnection conn = loop.run_until_complete(AsyncConnection.connect(dsn)) - return conn + yield conn + loop.run_until_complete(conn.close()) diff --git a/tests/fix_db.py b/tests/fix_db.py index 8ac7cf3c8..6f0f06449 100644 --- a/tests/fix_db.py +++ b/tests/fix_db.py @@ -30,7 +30,8 @@ def pgconn(pq, dsn): pytest.fail( f"bad connection: {conn.error_message.decode('utf8', 'replace')}" ) - return conn + yield conn + conn.finish() @pytest.fixture @@ -38,7 +39,9 @@ def conn(dsn): """Return a `Connection` connected to the ``--test-dsn`` database.""" from psycopg3 import Connection - return Connection.connect(dsn) + conn = Connection.connect(dsn) + yield conn + conn.close() @pytest.fixture(scope="session") @@ -48,4 +51,6 @@ def svcconn(dsn): """ from psycopg3 import Connection - return Connection.connect(dsn) + conn = Connection.connect(dsn) + yield conn + conn.close() diff --git a/tests/test_async_connection.py b/tests/test_async_connection.py index f5a8875e9..8192153b9 100644 --- a/tests/test_async_connection.py +++ b/tests/test_async_connection.py @@ -30,7 +30,7 @@ def test_commit(loop, aconn): aconn.pgconn.exec_(b"create table foo (id int primary key)") aconn.pgconn.exec_(b"begin") assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS - res = aconn.pgconn.exec_(b"insert into foo values (1)") + aconn.pgconn.exec_(b"insert into foo values (1)") loop.run_until_complete(aconn.commit()) assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE res = aconn.pgconn.exec_(b"select id from foo where id = 1") @@ -46,17 +46,55 @@ def test_rollback(loop, aconn): aconn.pgconn.exec_(b"create table foo (id int primary key)") aconn.pgconn.exec_(b"begin") assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS - res = aconn.pgconn.exec_(b"insert into foo values (1)") + aconn.pgconn.exec_(b"insert into foo values (1)") loop.run_until_complete(aconn.rollback()) assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE res = aconn.pgconn.exec_(b"select id from foo where id = 1") - assert res.get_value(0, 0) is None + assert res.ntuples == 0 loop.run_until_complete(aconn.close()) with pytest.raises(psycopg3.OperationalError): loop.run_until_complete(aconn.rollback()) +def test_auto_transaction(loop, aconn): + aconn.pgconn.exec_(b"drop table if exists foo") + aconn.pgconn.exec_(b"create table foo (id int primary key)") + + cur = aconn.cursor() + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE + + loop.run_until_complete(cur.execute("insert into foo values (1)")) + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS + + loop.run_until_complete(aconn.commit()) + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE + loop.run_until_complete(cur.execute("select * from foo")) + assert loop.run_until_complete(cur.fetchone()) == (1,) + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS + + +def test_auto_transaction_fail(loop, aconn): + aconn.pgconn.exec_(b"drop table if exists foo") + aconn.pgconn.exec_(b"create table foo (id int primary key)") + + cur = aconn.cursor() + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE + + loop.run_until_complete(cur.execute("insert into foo values (1)")) + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS + + with pytest.raises(psycopg3.DatabaseError): + loop.run_until_complete(cur.execute("meh")) + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR + + loop.run_until_complete(aconn.commit()) + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE + loop.run_until_complete(cur.execute("select * from foo")) + assert loop.run_until_complete(cur.fetchone()) is None + assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS + + def test_get_encoding(aconn, loop): cur = aconn.cursor() loop.run_until_complete(cur.execute("show client_encoding")) diff --git a/tests/test_async_cursor.py b/tests/test_async_cursor.py index aa749ee4b..5c428853a 100644 --- a/tests/test_async_cursor.py +++ b/tests/test_async_cursor.py @@ -114,12 +114,14 @@ def _execmany(svcconn): create table execmany (id serial primary key, num integer, data text) """ ) + svcconn.commit() @pytest.fixture(scope="function") def execmany(svcconn, _execmany): cur = svcconn.cursor() cur.execute("truncate table execmany") + svcconn.commit() def test_executemany(aconn, loop, execmany): diff --git a/tests/test_connection.py b/tests/test_connection.py index b5643ce70..cfa0d9192 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -30,7 +30,7 @@ def test_commit(conn): conn.pgconn.exec_(b"create table foo (id int primary key)") conn.pgconn.exec_(b"begin") assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS - res = conn.pgconn.exec_(b"insert into foo values (1)") + conn.pgconn.exec_(b"insert into foo values (1)") conn.commit() assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE res = conn.pgconn.exec_(b"select id from foo where id = 1") @@ -46,17 +46,53 @@ def test_rollback(conn): conn.pgconn.exec_(b"create table foo (id int primary key)") conn.pgconn.exec_(b"begin") assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS - res = conn.pgconn.exec_(b"insert into foo values (1)") + conn.pgconn.exec_(b"insert into foo values (1)") conn.rollback() assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE res = conn.pgconn.exec_(b"select id from foo where id = 1") - assert res.get_value(0, 0) is None + assert res.ntuples == 0 conn.close() with pytest.raises(psycopg3.OperationalError): conn.rollback() +def test_auto_transaction(conn): + conn.pgconn.exec_(b"drop table if exists foo") + conn.pgconn.exec_(b"create table foo (id int primary key)") + + cur = conn.cursor() + assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE + + cur.execute("insert into foo values (1)") + assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS + + conn.commit() + assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE + assert cur.execute("select * from foo").fetchone() == (1,) + assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS + + +def test_auto_transaction_fail(conn): + conn.pgconn.exec_(b"drop table if exists foo") + conn.pgconn.exec_(b"create table foo (id int primary key)") + + cur = conn.cursor() + assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE + + cur.execute("insert into foo values (1)") + assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS + + with pytest.raises(psycopg3.DatabaseError): + cur.execute("meh") + assert conn.pgconn.transaction_status == conn.TransactionStatus.INERROR + + conn.commit() + assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE + assert cur.execute("select * from foo").fetchone() is None + assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS + + def test_get_encoding(conn): (enc,) = conn.cursor().execute("show client_encoding").fetchone() assert enc == conn.encoding diff --git a/tests/test_cursor.py b/tests/test_cursor.py index 86b589cbe..29db7007d 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -111,12 +111,14 @@ def _execmany(svcconn): create table execmany (id serial primary key, num integer, data text) """ ) + svcconn.commit() @pytest.fixture(scope="function") def execmany(svcconn, _execmany): cur = svcconn.cursor() cur.execute("truncate table execmany") + svcconn.commit() def test_executemany(conn, execmany): diff --git a/tests/types/test_composite.py b/tests/types/test_composite.py index 1f1e88352..f1277486b 100644 --- a/tests/types/test_composite.py +++ b/tests/types/test_composite.py @@ -99,6 +99,7 @@ def testcomp(svcconn): create type testcomp as (foo text, bar int8, baz float8); """ ) + svcconn.commit() def test_fetch_info(conn, testcomp): -- 2.47.2