from time import monotonic
from types import TracebackType
from typing import Any, Generic, cast
-from asyncio import CancelledError
from weakref import ref
from contextlib import contextmanager
from collections import deque
from ._acompat import Condition, Event, Lock, Queue, Worker, current_thread_name
from ._acompat import gather, sleep, spawn
+CLIENT_EXCEPTIONS = Exception
+
logger = logging.getLogger("psycopg.pool")
conn = self._getconn_unchecked(deadline - monotonic())
try:
self._check_connection(conn)
- except (Exception, CancelledError):
+ except CLIENT_EXCEPTIONS:
self._putconn(conn, from_getconn=True)
else:
logger.info("connection given by %r", self.name)
if not conn:
try:
conn = pos.wait(timeout=timeout)
- except BaseException:
+ except CLIENT_EXCEPTIONS:
self._stats[self._REQUESTS_ERRORS] += 1
raise
finally:
return
try:
self._check(conn)
- except BaseException as e:
+ except CLIENT_EXCEPTIONS as e:
logger.info("connection failed check: %s", e)
raise
# Check for broken connections
try:
self.check_connection(conn)
- except (Exception, CancelledError):
+ except CLIENT_EXCEPTIONS:
self._stats[self._CONNECTIONS_LOST] += 1
logger.warning("discarding broken connection: %s", conn)
self.run_task(AddConnection(self))
# Run the task. Make sure don't die in the attempt.
try:
task.run()
- except (Exception, CancelledError) as ex:
+ except CLIENT_EXCEPTIONS as ex:
logger.warning(
"task run %s failed: %s: %s", task, ex.__class__.__name__, ex
)
t0 = monotonic()
try:
conn = self.connection_class.connect(conninfo, **kwargs)
- except (Exception, CancelledError):
+ except CLIENT_EXCEPTIONS:
self._stats[self._CONNECTIONS_ERRORS] += 1
raise
else:
try:
conn = self._connect()
- except (Exception, CancelledError) as ex:
+ except CLIENT_EXCEPTIONS as ex:
logger.warning("error connecting in %r: %s", self.name, ex)
if attempt.time_to_give_up(now):
logger.warning(
logger.warning("rolling back returned connection: %s", conn)
try:
conn.rollback()
- except Exception as ex:
+ except CLIENT_EXCEPTIONS as ex:
logger.warning(
"rollback failed: %s: %s. Discarding connection %s",
ex.__class__.__name__,
raise e.ProgrammingError(
f"connection left in status {sname} by reset function {self._reset}: discarded"
)
- except (Exception, CancelledError) as ex:
+ except CLIENT_EXCEPTIONS as ex:
logger.warning("error resetting connection: %s", ex)
self._close_connection(conn)
self.error = PoolTimeout(
f"couldn't get a connection after {timeout:.2f} sec"
)
- except BaseException as ex:
+ except CLIENT_EXCEPTIONS as ex:
self.error = ex
if self.conn:
from time import monotonic
from types import TracebackType
from typing import Any, Generic, cast
-from asyncio import CancelledError
from weakref import ref
from contextlib import asynccontextmanager
from collections import deque
if True: # ASYNC
import asyncio
+ # The exceptions that we need to capture in order to keep the pool
+ # consistent and avoid losing connections on errors in callers code.
+ CLIENT_EXCEPTIONS = (Exception, asyncio.CancelledError)
+else:
+ CLIENT_EXCEPTIONS = Exception
+
+
logger = logging.getLogger("psycopg.pool")
conn = await self._getconn_unchecked(deadline - monotonic())
try:
await self._check_connection(conn)
- except (Exception, CancelledError):
+ except CLIENT_EXCEPTIONS:
await self._putconn(conn, from_getconn=True)
else:
logger.info("connection given by %r", self.name)
if not conn:
try:
conn = await pos.wait(timeout=timeout)
- except BaseException:
+ except CLIENT_EXCEPTIONS:
self._stats[self._REQUESTS_ERRORS] += 1
raise
finally:
return
try:
await self._check(conn)
- except BaseException as e:
+ except CLIENT_EXCEPTIONS as e:
logger.info("connection failed check: %s", e)
raise
# Check for broken connections
try:
await self.check_connection(conn)
- except (Exception, CancelledError):
+ except CLIENT_EXCEPTIONS:
self._stats[self._CONNECTIONS_LOST] += 1
logger.warning("discarding broken connection: %s", conn)
self.run_task(AddConnection(self))
# Run the task. Make sure don't die in the attempt.
try:
await task.run()
- except (Exception, CancelledError) as ex:
+ except CLIENT_EXCEPTIONS as ex:
logger.warning(
"task run %s failed: %s: %s", task, ex.__class__.__name__, ex
)
t0 = monotonic()
try:
conn = await self.connection_class.connect(conninfo, **kwargs)
- except (Exception, CancelledError):
+ except CLIENT_EXCEPTIONS:
self._stats[self._CONNECTIONS_ERRORS] += 1
raise
else:
try:
conn = await self._connect()
- except (Exception, CancelledError) as ex:
+ except CLIENT_EXCEPTIONS as ex:
logger.warning("error connecting in %r: %s", self.name, ex)
if attempt.time_to_give_up(now):
logger.warning(
logger.warning("rolling back returned connection: %s", conn)
try:
await conn.rollback()
- except Exception as ex:
+ except CLIENT_EXCEPTIONS as ex:
logger.warning(
"rollback failed: %s: %s. Discarding connection %s",
ex.__class__.__name__,
f"connection left in status {sname} by reset function"
f" {self._reset}: discarded"
)
- except (Exception, CancelledError) as ex:
+ except CLIENT_EXCEPTIONS as ex:
logger.warning("error resetting connection: %s", ex)
await self._close_connection(conn)
self.error = PoolTimeout(
f"couldn't get a connection after {timeout:.2f} sec"
)
- except BaseException as ex:
+ except CLIENT_EXCEPTIONS as ex:
self.error = ex
if self.conn:
logger = logging.getLogger(__name__)
+CLIENT_EXCEPTIONS = Exception
+
class Scheduler:
break
try:
task.action()
- except Exception as e:
+ except CLIENT_EXCEPTIONS as e:
logger.warning(
"scheduled task run %s failed: %s: %s",
task.action,
logger = logging.getLogger(__name__)
+if True: # ASYNC
+ from asyncio import CancelledError
+
+ # The exceptions that we need to capture in order to keep the pool
+ # consistent and avoid losing connections on errors in callers code.
+ CLIENT_EXCEPTIONS = (Exception, CancelledError)
+else:
+ CLIENT_EXCEPTIONS = Exception
+
class AsyncScheduler:
def __init__(self) -> None:
break
try:
await task.action()
- except Exception as e:
+ except CLIENT_EXCEPTIONS as e:
logger.warning(
"scheduled task run %s failed: %s: %s",
task.action,
conn.execute("select 1")
+@skip_sync
+def test_cancel_on_rollback(pool_cls, dsn, monkeypatch):
+ do_cancel = False
+
+ with pool_cls(dsn, min_size=min_size(pool_cls, 1), timeout=1.0) as p:
+ with p.connection() as conn:
+
+ def rollback(self):
+ if do_cancel:
+ raise CancelledError()
+ else:
+ type(self).rollback(self)
+
+ monkeypatch.setattr(type(conn), "rollback", rollback)
+ conn.execute("select 1")
+
+ do_cancel = True
+ with pytest.raises((psycopg.errors.SyntaxError, CancelledError)):
+ with p.connection() as conn:
+ conn.execute("selexx 2")
+
+ do_cancel = False
+ with p.connection() as conn:
+ cur = conn.execute("select 3")
+ assert cur.fetchone() == (3,)
+
+
def min_size(pool_cls, num=1):
"""Return the minimum min_size supported by the pool class."""
if pool_cls is pool.ConnectionPool:
await conn.execute("select 1")
+@skip_sync
+async def test_cancel_on_rollback(pool_cls, dsn, monkeypatch):
+ do_cancel = False
+
+ async with pool_cls(dsn, min_size=min_size(pool_cls, 1), timeout=1.0) as p:
+ async with p.connection() as conn:
+
+ async def rollback(self):
+ if do_cancel:
+ raise CancelledError()
+ else:
+ await type(self).rollback(self)
+
+ monkeypatch.setattr(type(conn), "rollback", rollback)
+ await conn.execute("select 1")
+
+ do_cancel = True
+ with pytest.raises((psycopg.errors.SyntaxError, CancelledError)):
+ async with p.connection() as conn:
+ await conn.execute("selexx 2")
+
+ do_cancel = False
+ async with p.connection() as conn:
+ cur = await conn.execute("select 3")
+ assert (await cur.fetchone()) == (3,)
+
+
def min_size(pool_cls, num=1):
"""Return the minimum min_size supported by the pool class."""
if pool_cls is pool.AsyncConnectionPool: