self.pgconn = conn.pgconn
self.command_queue = Deque[PipelineCommand]()
self.result_queue = Deque[PendingResult]()
+ self.level = 0
def __repr__(self) -> str:
cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
return BasePipeline._is_supported
def _enter(self) -> None:
- self.pgconn.enter_pipeline_mode()
+ if self.level == 0:
+ self.pgconn.enter_pipeline_mode()
+ self.level += 1
def _exit(self) -> None:
- if self.pgconn.status != ConnStatus.BAD:
+ self.level -= 1
+ if self.level == 0 and self.pgconn.status != ConnStatus.BAD:
self.pgconn.exit_pipeline_mode()
def _sync_gen(self) -> PQGen[None]:
except Exception as exc2:
# Don't clobber an exception raised in the block with this one
if exc_val:
- logger.warning("error ignored exiting %r: %s", self, exc2)
+ logger.warning("error ignored syncing %r: %s", self, exc2)
else:
raise
finally:
- self._exit()
+ try:
+ self._exit()
+ except Exception as exc2:
+ # Notice that this error might be pretty irrecoverable. It
+ # happens on COPY, for insance: even if sync succeeds, exiting
+ # fails with "cannot exit pipeline mode with uncollected results"
+ if exc_val:
+ logger.warning("error ignored exiting %r: %s", self, exc2)
+ else:
+ raise
class AsyncPipeline(BasePipeline):
except Exception as exc2:
# Don't clobber an exception raised in the block with this one
if exc_val:
- logger.warning("error ignored exiting %r: %s", self, exc2)
+ logger.warning("error ignored syncing %r: %s", self, exc2)
else:
raise
finally:
- self._exit()
+ try:
+ self._exit()
+ except Exception as exc2:
+ if exc_val:
+ logger.warning("error ignored exiting %r: %s", self, exc2)
+ else:
+ raise
def pipeline(self) -> Iterator[Pipeline]:
"""Context manager to switch the connection into pipeline mode."""
with self.lock:
- if self._pipeline is None:
- # We must enter pipeline mode: create a new one
+ pipeline = self._pipeline
+ if pipeline is None:
# WARNING: reference loop, broken ahead.
pipeline = self._pipeline = Pipeline(self)
- else:
- # we are already in pipeline mode: bail out as soon as we
- # leave the lock block.
- pipeline = None
-
- if not pipeline:
- # No-op re-entered inner pipeline block.
- try:
- yield self._pipeline
- finally:
- self._pipeline.sync()
- return
try:
with pipeline:
yield pipeline
finally:
- assert pipeline.status == pq.PipelineStatus.OFF, pipeline.status
- self._pipeline = None
+ if pipeline.level == 0:
+ with self.lock:
+ assert pipeline is self._pipeline
+ self._pipeline = None
def wait(self, gen: PQGen[RV], timeout: Optional[float] = 0.1) -> RV:
"""
from . import errors as e
from . import waiting
-from .pq import Format, PipelineStatus, TransactionStatus
+from .pq import Format, TransactionStatus
from .abc import AdaptContext, Params, PQGen, PQGenConn, Query, RV
from ._tpc import Xid
from .rows import Row, AsyncRowFactory, tuple_row, TupleRow, args_row
cursor_factory: Type[AsyncCursor[Row]]
server_cursor_factory: Type[AsyncServerCursor[Row]]
row_factory: AsyncRowFactory[Row]
-
- _pipeline: "Optional[AsyncPipeline]"
+ _pipeline: Optional[AsyncPipeline]
def __init__(
self,
async def pipeline(self) -> AsyncIterator[AsyncPipeline]:
"""Context manager to switch the connection into pipeline mode."""
async with self.lock:
- if self._pipeline is None:
- # We must enter pipeline mode: create a new one
+ pipeline = self._pipeline
+ if pipeline is None:
# WARNING: reference loop, broken ahead.
pipeline = self._pipeline = AsyncPipeline(self)
- else:
- # we are already in pipeline mode: bail out as soon as we
- # leave the lock block.
- pipeline = None
-
- if not pipeline:
- # No-op re-entered inner pipeline block.
- try:
- yield self._pipeline
- finally:
- await self._pipeline.sync()
- return
try:
async with pipeline:
yield pipeline
finally:
- assert pipeline.status == PipelineStatus.OFF, pipeline.status
- self._pipeline = None
+ if pipeline.level == 0:
+ async with self.lock:
+ assert pipeline is self._pipeline
+ self._pipeline = None
async def wait(self, gen: PQGen[RV]) -> RV:
try:
assert len(caplog.records) == 1
+def test_pipeline_exit_error_noclobber_nested(conn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg")
+ with pytest.raises(ZeroDivisionError):
+ with conn.pipeline():
+ with conn.pipeline():
+ conn.close()
+ 1 / 0
+
+ assert len(caplog.records) == 2
+
+
def test_pipeline_exit_sync_trace(conn, trace):
t = trace.trace(conn)
with conn.pipeline():
assert len(caplog.records) == 1
+async def test_pipeline_exit_error_noclobber_nested(aconn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg")
+ with pytest.raises(ZeroDivisionError):
+ async with aconn.pipeline():
+ async with aconn.pipeline():
+ await aconn.close()
+ 1 / 0
+
+ assert len(caplog.records) == 2
+
+
async def test_pipeline_exit_sync_trace(aconn, trace):
t = trace.trace(aconn)
async with aconn.pipeline():