from collections import deque
from . import errors as e
+from .pq import TransactionStatus
from .connection import Connection
WORKER_TIMEOUT = 60.0
return conn
def putconn(self, conn: Connection) -> None:
- # TODO: this should happen in a maintenance thread
- # TODO: add check for broken connections
-
if conn._pool is not self:
if conn._pool:
- raise ValueError(f"the connection belongs to {conn._pool}")
+ msg = f"it comes from pool {conn._pool.name!r}"
else:
- raise ValueError("the connection doesn't belong to a pool")
+ msg = "it doesn't come from any pool"
+ raise ValueError(
+ f"can't return connection to pool {self.name!r}, {msg}: {conn}"
+ )
+ # Use a worker to perform eventual maintenance work in a separate thread
+ self.add_task(ReturnConnection(self, conn))
+
+ def _return_connection(self, conn: Connection) -> None:
# Remove the pool reference from the connection before returning it
# to the state, to avoid to create a reference loop.
conn._pool = None
+ self._reset_transaction_status(conn)
+ if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN:
+ # Connection no more in working state: create a new one.
+ logger.warning("discarding closed connection: %s", conn)
+ self.add_task(AddConnection(self))
+ return
+
# Critical section: if there is a client waiting give it the connection
# otherwise put it back into the pool.
with self._lock:
# No client waiting for a connection: put it back into the queue
self._pool.append(conn)
+ def _reset_transaction_status(self, conn: Connection) -> None:
+ """
+ Bring a connection to IDLE state or close it.
+ """
+ status = conn.pgconn.transaction_status
+ if status == TransactionStatus.IDLE:
+ return
+
+ if status in (TransactionStatus.INTRANS, TransactionStatus.INERROR):
+ # Connection returned with an active transaction
+ logger.warning("rolling back returned connection: %s", conn)
+ try:
+ conn.rollback()
+ except Exception as e:
+ logger.warning(
+ "rollback failed: %s: %s. Discarding connection %s",
+ e.__class__.__name__,
+ e,
+ conn,
+ )
+ conn.close()
+
+ elif status == TransactionStatus.ACTIVE:
+ # Connection returned during an operation. Bad... just close it.
+ logger.warning("closing returned connection: %s", conn)
+ conn.close()
+
def add_task(self, task: "MaintenanceTask") -> None:
"""Add a task to the queue of tasts to perform."""
self._wqueue.put(task)
conn = self.pool._connect()
conn._pool = self.pool # make it acceptable
self.pool.putconn(conn)
+
+
+class ReturnConnection(MaintenanceTask):
+ def __init__(self, pool: ConnectionPool, conn: Connection):
+ super().__init__(pool)
+ self.conn = conn
+
+ def __call__(self) -> None:
+ super().__call__()
+ self.pool._return_connection(self.conn)
+import logging
from time import time
from threading import Thread
import pytest
+import psycopg3
+from psycopg3.pq import TransactionStatus
from psycopg3 import pool
assert len(errors) == 1
for e in errors:
assert 0.1 < e[1] < 0.15
+
+
+def test_broken_reconnect(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg3.pool")
+ p = pool.ConnectionPool(dsn, minconn=1)
+ with pytest.raises(psycopg3.OperationalError):
+ with p.connection() as conn:
+ with conn.execute("select pg_backend_pid()") as cur:
+ (pid1,) = cur.fetchone()
+ conn.close()
+
+ with p.connection() as conn2:
+ with conn2.execute("select pg_backend_pid()") as cur:
+ (pid2,) = cur.fetchone()
+
+ assert pid1 != pid2
+
+
+def test_intrans_rollback(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg3.pool")
+ p = pool.ConnectionPool(dsn, minconn=1)
+ conn = p.getconn()
+ pid = conn.pgconn.backend_pid
+ conn.execute("create table test_intrans_rollback ()")
+ assert conn.pgconn.transaction_status == TransactionStatus.INTRANS
+ p.putconn(conn)
+
+ with p.connection() as conn2:
+ assert conn2.pgconn.backend_pid == pid
+ assert conn2.pgconn.transaction_status == TransactionStatus.IDLE
+ assert not conn.execute(
+ "select 1 from pg_class where relname = 'test_intrans_rollback'"
+ ).fetchone()
+
+ assert len(caplog.records) == 1
+ assert "INTRANS" in caplog.records[0].message
+
+
+def test_inerror_rollback(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg3.pool")
+ p = pool.ConnectionPool(dsn, minconn=1)
+ conn = p.getconn()
+ pid = conn.pgconn.backend_pid
+ with pytest.raises(psycopg3.ProgrammingError):
+ conn.execute("wat")
+ assert conn.pgconn.transaction_status == TransactionStatus.INERROR
+ p.putconn(conn)
+
+ with p.connection() as conn2:
+ assert conn2.pgconn.backend_pid == pid
+ assert conn2.pgconn.transaction_status == TransactionStatus.IDLE
+
+ assert len(caplog.records) == 1
+ assert "INERROR" in caplog.records[0].message
+
+
+def test_active_close(dsn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg3.pool")
+ p = pool.ConnectionPool(dsn, minconn=1)
+ conn = p.getconn()
+ pid = conn.pgconn.backend_pid
+ cur = conn.cursor()
+ with cur.copy("copy (select * from generate_series(1, 10)) to stdout"):
+ pass
+ assert conn.pgconn.transaction_status == TransactionStatus.ACTIVE
+ p.putconn(conn)
+
+ with p.connection() as conn2:
+ assert conn2.pgconn.backend_pid != pid
+ assert conn2.pgconn.transaction_status == TransactionStatus.IDLE
+
+ assert len(caplog.records) == 2
+ assert "ACTIVE" in caplog.records[0].message
+ assert "BAD" in caplog.records[1].message
+
+
+def test_fail_rollback_close(dsn, caplog, monkeypatch):
+ caplog.set_level(logging.WARNING, logger="psycopg3.pool")
+ p = pool.ConnectionPool(dsn, minconn=1)
+ conn = p.getconn()
+
+ # Make the rollback fail
+ orig_rollback = conn.rollback
+
+ def bad_rollback():
+ conn.pgconn.finish()
+ orig_rollback()
+
+ monkeypatch.setattr(conn, "rollback", bad_rollback)
+
+ pid = conn.pgconn.backend_pid
+ with pytest.raises(psycopg3.ProgrammingError):
+ conn.execute("wat")
+ assert conn.pgconn.transaction_status == TransactionStatus.INERROR
+ p.putconn(conn)
+
+ with p.connection() as conn2:
+ assert conn2.pgconn.backend_pid != pid
+ assert conn2.pgconn.transaction_status == TransactionStatus.IDLE
+
+ assert len(caplog.records) == 3
+ assert "INERROR" in caplog.records[0].message
+ assert "OperationalError" in caplog.records[1].message
+ assert "BAD" in caplog.records[2].message
+
+
+def test_putconn_no_pool(dsn):
+ p = pool.ConnectionPool(dsn, minconn=1)
+ conn = psycopg3.connect(dsn)
+ with pytest.raises(ValueError):
+ p.putconn(conn)
+
+
+def test_putconn_wrong_pool(dsn):
+ p1 = pool.ConnectionPool(dsn, minconn=1)
+ p2 = pool.ConnectionPool(dsn, minconn=1)
+ conn = p1.getconn()
+ with pytest.raises(ValueError):
+ p2.putconn(conn)