from . import errors as e
from . import cursor
from . import proto
+from .pq import TransactionStatus, ExecStatus
from .conninfo import make_conninfo
from .waiting import wait, wait_async
cur = super().cursor(name, binary)
return cast(cursor.Cursor, cur)
+ def _start_query(self) -> None:
+ # the function is meant to be called by a cursor once the lock is taken
+ status = self.pgconn.transaction_status
+ if status == TransactionStatus.INTRANS:
+ return
+
+ self.pgconn.send_query(b"begin")
+ (pgres,) = self.wait(execute(self.pgconn))
+ if pgres.status != ExecStatus.COMMAND_OK:
+ raise e.OperationalError(
+ f"error on begin: {pq.error_message(pgres)}"
+ )
+
def commit(self) -> None:
self._exec_commit_rollback(b"commit")
def _exec_commit_rollback(self, command: bytes) -> None:
with self.lock:
status = self.pgconn.transaction_status
- if status == pq.TransactionStatus.IDLE:
+ if status == TransactionStatus.IDLE:
return
self.pgconn.send_query(command)
(pgres,) = self.wait(execute(self.pgconn))
- if pgres.status != pq.ExecStatus.COMMAND_OK:
+ if pgres.status != ExecStatus.COMMAND_OK:
raise e.OperationalError(
f"error on {command.decode('utf8')}:"
f" {pq.error_message(pgres)}"
)
gen = execute(self.pgconn)
(result,) = self.wait(gen)
- if result.status != pq.ExecStatus.TUPLES_OK:
+ if result.status != ExecStatus.TUPLES_OK:
raise e.error_from_result(result)
cur = super().cursor(name, binary)
return cast(cursor.AsyncCursor, cur)
+ async def _start_query(self) -> None:
+ # the function is meant to be called by a cursor once the lock is taken
+ status = self.pgconn.transaction_status
+ if status == TransactionStatus.INTRANS:
+ return
+
+ self.pgconn.send_query(b"begin")
+ (pgres,) = await self.wait(execute(self.pgconn))
+ if pgres.status != ExecStatus.COMMAND_OK:
+ raise e.OperationalError(
+ f"error on begin: {pq.error_message(pgres)}"
+ )
+
async def commit(self) -> None:
await self._exec_commit_rollback(b"commit")
async def _exec_commit_rollback(self, command: bytes) -> None:
async with self.lock:
status = self.pgconn.transaction_status
- if status == pq.TransactionStatus.IDLE:
+ if status == TransactionStatus.IDLE:
return
self.pgconn.send_query(command)
(pgres,) = await self.wait(execute(self.pgconn))
- if pgres.status != pq.ExecStatus.COMMAND_OK:
+ if pgres.status != ExecStatus.COMMAND_OK:
raise e.OperationalError(
f"error on {command.decode('utf8')}:"
f" {pq.error_message(pgres)}"
)
gen = execute(self.pgconn)
(result,) = await self.wait(gen)
- if result.status != pq.ExecStatus.TUPLES_OK:
+ if result.status != ExecStatus.TUPLES_OK:
raise e.error_from_result(result)
def execute(self, query: Query, vars: Optional[Params] = None) -> "Cursor":
with self.connection.lock:
self._start_query()
+ self.connection._start_query()
self._execute_send(query, vars)
gen = execute(self.connection.pgconn)
results = self.connection.wait(gen)
) -> "Cursor":
with self.connection.lock:
self._start_query()
- for i, vars in enumerate(vars_seq):
- if i == 0:
+ self.connection._start_query()
+ first = True
+ for vars in vars_seq:
+ if first:
pgq = self._send_prepare(b"", query, vars)
gen = execute(self.connection.pgconn)
(result,) = self.connection.wait(gen)
) -> "AsyncCursor":
async with self.connection.lock:
self._start_query()
+ await self.connection._start_query()
self._execute_send(query, vars)
gen = execute(self.connection.pgconn)
results = await self.connection.wait(gen)
) -> "AsyncCursor":
async with self.connection.lock:
self._start_query()
- for i, vars in enumerate(vars_seq):
- if i == 0:
+ await self.connection._start_query()
+ first = True
+ for vars in vars_seq:
+ if first:
pgq = self._send_prepare(b"", query, vars)
gen = execute(self.connection.pgconn)
(result,) = await self.connection.wait(gen)
from psycopg3 import AsyncConnection
conn = loop.run_until_complete(AsyncConnection.connect(dsn))
- return conn
+ yield conn
+ loop.run_until_complete(conn.close())
pytest.fail(
f"bad connection: {conn.error_message.decode('utf8', 'replace')}"
)
- return conn
+ yield conn
+ conn.finish()
@pytest.fixture
"""Return a `Connection` connected to the ``--test-dsn`` database."""
from psycopg3 import Connection
- return Connection.connect(dsn)
+ conn = Connection.connect(dsn)
+ yield conn
+ conn.close()
@pytest.fixture(scope="session")
"""
from psycopg3 import Connection
- return Connection.connect(dsn)
+ conn = Connection.connect(dsn)
+ yield conn
+ conn.close()
aconn.pgconn.exec_(b"create table foo (id int primary key)")
aconn.pgconn.exec_(b"begin")
assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
- res = aconn.pgconn.exec_(b"insert into foo values (1)")
+ aconn.pgconn.exec_(b"insert into foo values (1)")
loop.run_until_complete(aconn.commit())
assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE
res = aconn.pgconn.exec_(b"select id from foo where id = 1")
aconn.pgconn.exec_(b"create table foo (id int primary key)")
aconn.pgconn.exec_(b"begin")
assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
- res = aconn.pgconn.exec_(b"insert into foo values (1)")
+ aconn.pgconn.exec_(b"insert into foo values (1)")
loop.run_until_complete(aconn.rollback())
assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE
res = aconn.pgconn.exec_(b"select id from foo where id = 1")
- assert res.get_value(0, 0) is None
+ assert res.ntuples == 0
loop.run_until_complete(aconn.close())
with pytest.raises(psycopg3.OperationalError):
loop.run_until_complete(aconn.rollback())
+def test_auto_transaction(loop, aconn):
+ aconn.pgconn.exec_(b"drop table if exists foo")
+ aconn.pgconn.exec_(b"create table foo (id int primary key)")
+
+ cur = aconn.cursor()
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE
+
+ loop.run_until_complete(cur.execute("insert into foo values (1)"))
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
+
+ loop.run_until_complete(aconn.commit())
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE
+ loop.run_until_complete(cur.execute("select * from foo"))
+ assert loop.run_until_complete(cur.fetchone()) == (1,)
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
+
+
+def test_auto_transaction_fail(loop, aconn):
+ aconn.pgconn.exec_(b"drop table if exists foo")
+ aconn.pgconn.exec_(b"create table foo (id int primary key)")
+
+ cur = aconn.cursor()
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE
+
+ loop.run_until_complete(cur.execute("insert into foo values (1)"))
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
+
+ with pytest.raises(psycopg3.DatabaseError):
+ loop.run_until_complete(cur.execute("meh"))
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR
+
+ loop.run_until_complete(aconn.commit())
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE
+ loop.run_until_complete(cur.execute("select * from foo"))
+ assert loop.run_until_complete(cur.fetchone()) is None
+ assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
+
+
def test_get_encoding(aconn, loop):
cur = aconn.cursor()
loop.run_until_complete(cur.execute("show client_encoding"))
create table execmany (id serial primary key, num integer, data text)
"""
)
+ svcconn.commit()
@pytest.fixture(scope="function")
def execmany(svcconn, _execmany):
cur = svcconn.cursor()
cur.execute("truncate table execmany")
+ svcconn.commit()
def test_executemany(aconn, loop, execmany):
conn.pgconn.exec_(b"create table foo (id int primary key)")
conn.pgconn.exec_(b"begin")
assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS
- res = conn.pgconn.exec_(b"insert into foo values (1)")
+ conn.pgconn.exec_(b"insert into foo values (1)")
conn.commit()
assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE
res = conn.pgconn.exec_(b"select id from foo where id = 1")
conn.pgconn.exec_(b"create table foo (id int primary key)")
conn.pgconn.exec_(b"begin")
assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS
- res = conn.pgconn.exec_(b"insert into foo values (1)")
+ conn.pgconn.exec_(b"insert into foo values (1)")
conn.rollback()
assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE
res = conn.pgconn.exec_(b"select id from foo where id = 1")
- assert res.get_value(0, 0) is None
+ assert res.ntuples == 0
conn.close()
with pytest.raises(psycopg3.OperationalError):
conn.rollback()
+def test_auto_transaction(conn):
+ conn.pgconn.exec_(b"drop table if exists foo")
+ conn.pgconn.exec_(b"create table foo (id int primary key)")
+
+ cur = conn.cursor()
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE
+
+ cur.execute("insert into foo values (1)")
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS
+
+ conn.commit()
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE
+ assert cur.execute("select * from foo").fetchone() == (1,)
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS
+
+
+def test_auto_transaction_fail(conn):
+ conn.pgconn.exec_(b"drop table if exists foo")
+ conn.pgconn.exec_(b"create table foo (id int primary key)")
+
+ cur = conn.cursor()
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE
+
+ cur.execute("insert into foo values (1)")
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS
+
+ with pytest.raises(psycopg3.DatabaseError):
+ cur.execute("meh")
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.INERROR
+
+ conn.commit()
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE
+ assert cur.execute("select * from foo").fetchone() is None
+ assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS
+
+
def test_get_encoding(conn):
(enc,) = conn.cursor().execute("show client_encoding").fetchone()
assert enc == conn.encoding
create table execmany (id serial primary key, num integer, data text)
"""
)
+ svcconn.commit()
@pytest.fixture(scope="function")
def execmany(svcconn, _execmany):
cur = svcconn.cursor()
cur.execute("truncate table execmany")
+ svcconn.commit()
def test_executemany(conn, execmany):
create type testcomp as (foo text, bar int8, baz float8);
"""
)
+ svcconn.commit()
def test_fetch_info(conn, testcomp):