configure: Optional[
Callable[[AsyncConnection], Awaitable[None]]
] = None,
+ reset: Optional[Callable[[AsyncConnection], Awaitable[None]]] = None,
**kwargs: Any,
):
# https://bugs.python.org/issue42600
)
self._configure = configure
+ self._reset = reset
self._lock = asyncio.Lock()
self._waiting: Deque["AsyncClient"] = deque()
"""
status = conn.pgconn.transaction_status
if status == TransactionStatus.IDLE:
- return
+ pass
- if status in (TransactionStatus.INTRANS, TransactionStatus.INERROR):
+ elif status in (TransactionStatus.INTRANS, TransactionStatus.INERROR):
# Connection returned with an active transaction
logger.warning("rolling back returned connection: %s", conn)
try:
logger.warning("closing returned connection: %s", conn)
await conn.close()
+ if not conn.closed and self._reset:
+ try:
+ await self._reset(conn)
+ status = conn.pgconn.transaction_status
+ if status != TransactionStatus.IDLE:
+ nstatus = TransactionStatus(status).name
+ raise e.ProgrammingError(
+ f"connection left in status {nstatus} by reset function"
+ f" {self._reset}: discarded"
+ )
+ except Exception as ex:
+ logger.warning(f"error resetting connection: {ex}")
+ await conn.close()
+
async def _shrink_pool(self) -> None:
to_close: Optional[AsyncConnection] = None
self,
conninfo: str = "",
configure: Optional[Callable[[Connection], None]] = None,
+ reset: Optional[Callable[[Connection], None]] = None,
**kwargs: Any,
):
self._configure = configure
+ self._reset = reset
self._lock = threading.RLock()
self._waiting: Deque["WaitingClient"] = deque()
logger.warning("closing returned connection: %s", conn)
conn.close()
+ if not conn.closed and self._reset:
+ try:
+ self._reset(conn)
+ status = conn.pgconn.transaction_status
+ if status != TransactionStatus.IDLE:
+ nstatus = TransactionStatus(status).name
+ raise e.ProgrammingError(
+ f"connection left in status {nstatus} by reset function"
+ f" {self._reset}: discarded"
+ )
+ except Exception as ex:
+ logger.warning(f"error resetting connection: {ex}")
+ conn.close()
+
def _shrink_pool(self) -> None:
to_close: Optional[Connection] = None
assert "WAT" in caplog.records[0].message
+def test_reset(dsn):
+ resets = 0
+
+ def setup(conn):
+ with conn.transaction():
+ conn.execute("set timezone to '+1:00'")
+
+ def reset(conn):
+ nonlocal resets
+ resets += 1
+ with conn.transaction():
+ conn.execute("set timezone to utc")
+
+ with pool.ConnectionPool(minconn=1, reset=reset) as p:
+ with p.connection() as conn:
+ assert resets == 0
+ conn.execute("set timezone to '+2:00'")
+
+ p.wait()
+ assert resets == 1
+
+ with p.connection() as conn:
+ with conn.execute("show timezone") as cur:
+ assert cur.fetchone() == ("UTC",)
+
+ p.wait()
+ assert resets == 2
+
+
+def test_reset_badstate(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg3.pool")
+
+ def reset(conn):
+ conn.execute("reset all")
+
+ with pool.ConnectionPool(minconn=1, reset=reset) as p:
+ with p.connection() as conn:
+ conn.execute("select 1")
+ pid1 = conn.pgconn.backend_pid
+
+ with p.connection() as conn:
+ conn.execute("select 1")
+ pid2 = conn.pgconn.backend_pid
+
+ assert pid1 != pid2
+ assert caplog.records
+ assert "INTRANS" in caplog.records[0].message
+
+
+def test_reset_broken(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg3.pool")
+
+ def reset(conn):
+ with conn.transaction():
+ conn.execute("WAT")
+
+ with pool.ConnectionPool(minconn=1, reset=reset) as p:
+ with p.connection() as conn:
+ conn.execute("select 1")
+ pid1 = conn.pgconn.backend_pid
+
+ with p.connection() as conn:
+ conn.execute("select 1")
+ pid2 = conn.pgconn.backend_pid
+
+ assert pid1 != pid2
+ assert caplog.records
+ assert "WAT" in caplog.records[0].message
+
+
@pytest.mark.slow
def test_queue(dsn, retries):
def worker(n):
assert "WAT" in caplog.records[0].message
+async def test_reset(dsn):
+ resets = 0
+
+ async def setup(conn):
+ async with conn.transaction():
+ await conn.execute("set timezone to '+1:00'")
+
+ async def reset(conn):
+ nonlocal resets
+ resets += 1
+ async with conn.transaction():
+ await conn.execute("set timezone to utc")
+
+ async with pool.AsyncConnectionPool(minconn=1, reset=reset) as p:
+ async with p.connection() as conn:
+ assert resets == 0
+ await conn.execute("set timezone to '+2:00'")
+
+ await p.wait()
+ assert resets == 1
+
+ async with p.connection() as conn:
+ cur = await conn.execute("show timezone")
+ assert (await cur.fetchone()) == ("UTC",)
+
+ await p.wait()
+ assert resets == 2
+
+
+async def test_reset_badstate(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg3.pool")
+
+ async def reset(conn):
+ await conn.execute("reset all")
+
+ async with pool.AsyncConnectionPool(minconn=1, reset=reset) as p:
+ async with p.connection() as conn:
+ await conn.execute("select 1")
+ pid1 = conn.pgconn.backend_pid
+
+ async with p.connection() as conn:
+ await conn.execute("select 1")
+ pid2 = conn.pgconn.backend_pid
+
+ assert pid1 != pid2
+ assert caplog.records
+ assert "INTRANS" in caplog.records[0].message
+
+
+async def test_reset_broken(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg3.pool")
+
+ async def reset(conn):
+ async with conn.transaction():
+ await conn.execute("WAT")
+
+ async with pool.AsyncConnectionPool(minconn=1, reset=reset) as p:
+ async with p.connection() as conn:
+ await conn.execute("select 1")
+ pid1 = conn.pgconn.backend_pid
+
+ async with p.connection() as conn:
+ await conn.execute("select 1")
+ pid2 = conn.pgconn.backend_pid
+
+ assert pid1 != pid2
+ assert caplog.records
+ assert "WAT" in caplog.records[0].message
+
+
@pytest.mark.slow
async def test_queue(dsn, retries):
async def worker(n):