connection string, if available (:ticket:`#194`).
- Fix possible warnings in objects deletion on interpreter shutdown
(:ticket:`#198`).
+- Don't leave connections in ACTIVE state in case of error during COPY ... TO
+ STDOUT (:ticket:`#203`).
Psycopg 3.0.7
return row
- def _end_copy_gen(self, exc: Optional[BaseException]) -> PQGen[None]:
+ def _end_copy_in_gen(self, exc: Optional[BaseException]) -> PQGen[None]:
bmsg: Optional[bytes]
if exc:
msg = f"error from Python: {type(exc).__qualname__} - {exc}"
self.cursor._rowcount = nrows if nrows is not None else -1
self._finished = True
+ def _end_copy_out_gen(self, exc: Optional[BaseException]) -> PQGen[None]:
+ if not exc:
+ return
+
+ if (
+ self.connection.pgconn.transaction_status
+ != pq.TransactionStatus.ACTIVE
+ ):
+ # The server has already finished to send copy data. The connection
+ # is already in a good state.
+ return
+
+ # Throw a cancel to the server, then consume the rest of the copy data
+ # (which might or might not have been already transferred entirely to
+ # the client, so we won't necessary see the exception associated with
+ # canceling).
+ self.connection.cancel()
+ try:
+ while (yield from self._read_gen()):
+ pass
+ except e.QueryCanceled:
+ pass
+
class Copy(BaseCopy["Connection[Any]"]):
"""Manage a :sql:`COPY` operation."""
by exit. It is available if, despite what is documented, you end up
using the `Copy` object outside a block.
"""
- # no-op in COPY TO
- if self._pgresult.status == ExecStatus.COPY_OUT:
- return
-
- self._write_end()
- self.connection.wait(self._end_copy_gen(exc))
+ if self._pgresult.status == ExecStatus.COPY_IN:
+ self._write_end()
+ self.connection.wait(self._end_copy_in_gen(exc))
+ else:
+ self.connection.wait(self._end_copy_out_gen(exc))
# Concurrent copy support
The function is designed to be run in a separate thread.
"""
- while 1:
+ while True:
data = self._queue.get(block=True, timeout=24 * 60 * 60)
if not data:
break
await self._write(data)
async def finish(self, exc: Optional[BaseException]) -> None:
- # no-op in COPY TO
- if self._pgresult.status == ExecStatus.COPY_OUT:
- return
-
- await self._write_end()
- await self.connection.wait(self._end_copy_gen(exc))
+ if self._pgresult.status == ExecStatus.COPY_IN:
+ await self._write_end()
+ await self.connection.wait(self._end_copy_in_gen(exc))
+ else:
+ await self.connection.wait(self._end_copy_out_gen(exc))
# Concurrent copy support
The function is designed to be run in a separate thread.
"""
- while 1:
+ while True:
data = await self._queue.get()
if not data:
break
assert "in rollback" in rec.message
-def test_context_active_rollback_no_clobber(conn, dsn, caplog):
+def test_context_active_rollback_no_clobber(dsn, caplog):
caplog.set_level(logging.WARNING, logger="psycopg")
with pytest.raises(ZeroDivisionError):
- with psycopg.connect(dsn) as conn2:
- with conn2.cursor() as cur:
- with cur.copy(
- "copy (select generate_series(1, 10)) to stdout"
- ) as copy:
- for row in copy.rows():
- 1 / 0
+ with psycopg.connect(dsn) as conn:
+ conn.pgconn.exec_(
+ b"copy (select generate_series(1, 10)) to stdout"
+ )
+ status = conn.info.transaction_status
+ assert status == conn.TransactionStatus.ACTIVE
+ 1 / 0
assert len(caplog.records) == 1
rec = caplog.records[0]
assert "in rollback" in rec.message
-async def test_context_active_rollback_no_clobber(conn, dsn, caplog):
+async def test_context_active_rollback_no_clobber(dsn, caplog):
caplog.set_level(logging.WARNING, logger="psycopg")
with pytest.raises(ZeroDivisionError):
- async with await psycopg.AsyncConnection.connect(dsn) as conn2:
- async with conn2.cursor() as cur:
- async with cur.copy(
- "copy (select generate_series(1, 10)) to stdout"
- ) as copy:
- async for row in copy.rows():
- 1 / 0
+ async with await psycopg.AsyncConnection.connect(dsn) as conn:
+ conn.pgconn.exec_(
+ b"copy (select generate_series(1, 10)) to stdout"
+ )
+ status = conn.info.transaction_status
+ assert status == conn.TransactionStatus.ACTIVE
+ 1 / 0
assert len(caplog.records) == 1
rec = caplog.records[0]
assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+def test_copy_out_error_with_copy_finished(conn):
+ cur = conn.cursor()
+ with pytest.raises(ZeroDivisionError):
+ with cur.copy("copy (select generate_series(1, 2)) to stdout") as copy:
+ copy.read_row()
+ 1 / 0
+
+ assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
+
+
+def test_copy_out_error_with_copy_not_finished(conn):
+ cur = conn.cursor()
+ with pytest.raises(ZeroDivisionError):
+ with cur.copy(
+ "copy (select generate_series(1, 1000000)) to stdout"
+ ) as copy:
+ copy.read_row()
+ 1 / 0
+
+ assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
+def test_copy_out_server_error(conn):
+ cur = conn.cursor()
+ with pytest.raises(e.DivisionByZero):
+ with cur.copy(
+ "copy (select 1/n from generate_series(-10, 10) x(n)) to stdout"
+ ) as copy:
+ for block in copy:
+ pass
+
+ assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+
+
@pytest.mark.parametrize("format", Format)
def test_copy_in_records(conn, format):
cur = conn.cursor()
assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR
+async def test_copy_out_error_with_copy_finished(aconn):
+ cur = aconn.cursor()
+ with pytest.raises(ZeroDivisionError):
+ async with cur.copy(
+ "copy (select generate_series(1, 2)) to stdout"
+ ) as copy:
+ await copy.read_row()
+ 1 / 0
+
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS
+
+
+async def test_copy_out_error_with_copy_not_finished(aconn):
+ cur = aconn.cursor()
+ with pytest.raises(ZeroDivisionError):
+ async with cur.copy(
+ "copy (select generate_series(1, 1000000)) to stdout"
+ ) as copy:
+ await copy.read_row()
+ 1 / 0
+
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR
+
+
+async def test_copy_out_server_error(aconn):
+ cur = aconn.cursor()
+ with pytest.raises(e.DivisionByZero):
+ async with cur.copy(
+ "copy (select 1/n from generate_series(-10, 10) x(n)) to stdout"
+ ) as copy:
+ async for block in copy:
+ pass
+
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR
+
+
@pytest.mark.parametrize("format", Format)
async def test_copy_in_records(aconn, format):
cur = aconn.cursor()
assert "in rollback" in rec.message
-def test_context_active_rollback_no_clobber(conn, dsn, caplog):
+def test_context_active_rollback_no_clobber(dsn, caplog):
caplog.set_level(logging.WARNING, logger="psycopg")
- with pytest.raises(ZeroDivisionError):
- conn2 = Connection.connect(dsn)
- with conn2.transaction():
- with conn2.cursor() as cur:
- with cur.copy(
- "copy (select generate_series(1, 10)) to stdout"
- ) as copy:
- for row in copy.rows():
- 1 / 0
+ conn = Connection.connect(dsn)
+ try:
+ with pytest.raises(ZeroDivisionError):
+ with conn.transaction():
+ conn.pgconn.exec_(
+ b"copy (select generate_series(1, 10)) to stdout"
+ )
+ status = conn.info.transaction_status
+ assert status == conn.TransactionStatus.ACTIVE
+ 1 / 0
- assert len(caplog.records) == 1
- rec = caplog.records[0]
- assert rec.levelno == logging.WARNING
- assert "in rollback" in rec.message
- conn2.close()
+ assert len(caplog.records) == 1
+ rec = caplog.records[0]
+ assert rec.levelno == logging.WARNING
+ assert "in rollback" in rec.message
+ finally:
+ conn.close()
def test_interaction_dbapi_transaction(conn):
assert "in rollback" in rec.message
-async def test_context_active_rollback_no_clobber(aconn, dsn, caplog):
+async def test_context_active_rollback_no_clobber(dsn, caplog):
caplog.set_level(logging.WARNING, logger="psycopg")
- with pytest.raises(ZeroDivisionError):
- conn2 = await AsyncConnection.connect(dsn)
- async with conn2.transaction():
- async with conn2.cursor() as cur:
- async with cur.copy(
- "copy (select generate_series(1, 10)) to stdout"
- ) as copy:
- async for row in copy.rows():
- 1 / 0
+ conn = await AsyncConnection.connect(dsn)
+ try:
+ with pytest.raises(ZeroDivisionError):
+ async with conn.transaction():
+ conn.pgconn.exec_(
+ b"copy (select generate_series(1, 10)) to stdout"
+ )
+ status = conn.info.transaction_status
+ assert status == conn.TransactionStatus.ACTIVE
+ 1 / 0
- assert len(caplog.records) == 1
- rec = caplog.records[0]
- assert rec.levelno == logging.WARNING
- assert "in rollback" in rec.message
- await conn2.close()
+ assert len(caplog.records) == 1
+ rec = caplog.records[0]
+ assert rec.levelno == logging.WARNING
+ assert "in rollback" in rec.message
+ finally:
+ await conn.close()
async def test_interaction_dbapi_transaction(aconn):