]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Reset the connection status returning it to the pool
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 13 Feb 2021 22:07:54 +0000 (23:07 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 12 Mar 2021 04:07:25 +0000 (05:07 +0100)
psycopg3/psycopg3/pool.py
tests/test_pool.py

index d3e081ba124ea34c1e3d2ac2cdf38f6e871f2c88..358203ec5e4d0ff385053753e0ba506af2ee429c 100644 (file)
@@ -13,6 +13,7 @@ from contextlib import contextmanager
 from collections import deque
 
 from . import errors as e
+from .pq import TransactionStatus
 from .connection import Connection
 
 WORKER_TIMEOUT = 60.0
@@ -120,19 +121,30 @@ class ConnectionPool:
         return conn
 
     def putconn(self, conn: Connection) -> None:
-        # TODO: this should happen in a maintenance thread
-        # TODO: add check for broken connections
-
         if conn._pool is not self:
             if conn._pool:
-                raise ValueError(f"the connection belongs to {conn._pool}")
+                msg = f"it comes from pool {conn._pool.name!r}"
             else:
-                raise ValueError("the connection doesn't belong to a pool")
+                msg = "it doesn't come from any pool"
+            raise ValueError(
+                f"can't return connection to pool {self.name!r}, {msg}: {conn}"
+            )
 
+        # Use a worker to perform eventual maintenance work in a separate thread
+        self.add_task(ReturnConnection(self, conn))
+
+    def _return_connection(self, conn: Connection) -> None:
         # Remove the pool reference from the connection before returning it
         # to the state, to avoid to create a reference loop.
         conn._pool = None
 
+        self._reset_transaction_status(conn)
+        if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN:
+            # Connection no more in working state: create a new one.
+            logger.warning("discarding closed connection: %s", conn)
+            self.add_task(AddConnection(self))
+            return
+
         # Critical section: if there is a client waiting give it the connection
         # otherwise put it back into the pool.
         with self._lock:
@@ -144,6 +156,33 @@ class ConnectionPool:
                 # No client waiting for a connection: put it back into the queue
                 self._pool.append(conn)
 
+    def _reset_transaction_status(self, conn: Connection) -> None:
+        """
+        Bring a connection to IDLE state or close it.
+        """
+        status = conn.pgconn.transaction_status
+        if status == TransactionStatus.IDLE:
+            return
+
+        if status in (TransactionStatus.INTRANS, TransactionStatus.INERROR):
+            # Connection returned with an active transaction
+            logger.warning("rolling back returned connection: %s", conn)
+            try:
+                conn.rollback()
+            except Exception as e:
+                logger.warning(
+                    "rollback failed: %s: %s. Discarding connection %s",
+                    e.__class__.__name__,
+                    e,
+                    conn,
+                )
+                conn.close()
+
+        elif status == TransactionStatus.ACTIVE:
+            # Connection returned during an operation. Bad... just close it.
+            logger.warning("closing returned connection: %s", conn)
+            conn.close()
+
     def add_task(self, task: "MaintenanceTask") -> None:
         """Add a task to the queue of tasts to perform."""
         self._wqueue.put(task)
@@ -252,3 +291,13 @@ class AddConnection(MaintenanceTask):
         conn = self.pool._connect()
         conn._pool = self.pool  # make it acceptable
         self.pool.putconn(conn)
+
+
+class ReturnConnection(MaintenanceTask):
+    def __init__(self, pool: ConnectionPool, conn: Connection):
+        super().__init__(pool)
+        self.conn = conn
+
+    def __call__(self) -> None:
+        super().__call__()
+        self.pool._return_connection(self.conn)
index febe018c5e589bcd05989df08b75b4847220a1cf..b327cc1ddd8d112fa8718ffa815c76aba069d0b4 100644 (file)
@@ -1,8 +1,11 @@
+import logging
 from time import time
 from threading import Thread
 
 import pytest
 
+import psycopg3
+from psycopg3.pq import TransactionStatus
 from psycopg3 import pool
 
 
@@ -133,3 +136,122 @@ def test_queue_timeout_override(dsn):
     assert len(errors) == 1
     for e in errors:
         assert 0.1 < e[1] < 0.15
+
+
+def test_broken_reconnect(dsn, caplog):
+    caplog.set_level(logging.WARNING, logger="psycopg3.pool")
+    p = pool.ConnectionPool(dsn, minconn=1)
+    with pytest.raises(psycopg3.OperationalError):
+        with p.connection() as conn:
+            with conn.execute("select pg_backend_pid()") as cur:
+                (pid1,) = cur.fetchone()
+            conn.close()
+
+    with p.connection() as conn2:
+        with conn2.execute("select pg_backend_pid()") as cur:
+            (pid2,) = cur.fetchone()
+
+    assert pid1 != pid2
+
+
+def test_intrans_rollback(dsn, caplog):
+    caplog.set_level(logging.WARNING, logger="psycopg3.pool")
+    p = pool.ConnectionPool(dsn, minconn=1)
+    conn = p.getconn()
+    pid = conn.pgconn.backend_pid
+    conn.execute("create table test_intrans_rollback ()")
+    assert conn.pgconn.transaction_status == TransactionStatus.INTRANS
+    p.putconn(conn)
+
+    with p.connection() as conn2:
+        assert conn2.pgconn.backend_pid == pid
+        assert conn2.pgconn.transaction_status == TransactionStatus.IDLE
+        assert not conn.execute(
+            "select 1 from pg_class where relname = 'test_intrans_rollback'"
+        ).fetchone()
+
+    assert len(caplog.records) == 1
+    assert "INTRANS" in caplog.records[0].message
+
+
+def test_inerror_rollback(dsn, caplog):
+    caplog.set_level(logging.WARNING, logger="psycopg3.pool")
+    p = pool.ConnectionPool(dsn, minconn=1)
+    conn = p.getconn()
+    pid = conn.pgconn.backend_pid
+    with pytest.raises(psycopg3.ProgrammingError):
+        conn.execute("wat")
+    assert conn.pgconn.transaction_status == TransactionStatus.INERROR
+    p.putconn(conn)
+
+    with p.connection() as conn2:
+        assert conn2.pgconn.backend_pid == pid
+        assert conn2.pgconn.transaction_status == TransactionStatus.IDLE
+
+    assert len(caplog.records) == 1
+    assert "INERROR" in caplog.records[0].message
+
+
+def test_active_close(dsn, caplog):
+    caplog.set_level(logging.WARNING, logger="psycopg3.pool")
+    p = pool.ConnectionPool(dsn, minconn=1)
+    conn = p.getconn()
+    pid = conn.pgconn.backend_pid
+    cur = conn.cursor()
+    with cur.copy("copy (select * from generate_series(1, 10)) to stdout"):
+        pass
+    assert conn.pgconn.transaction_status == TransactionStatus.ACTIVE
+    p.putconn(conn)
+
+    with p.connection() as conn2:
+        assert conn2.pgconn.backend_pid != pid
+        assert conn2.pgconn.transaction_status == TransactionStatus.IDLE
+
+    assert len(caplog.records) == 2
+    assert "ACTIVE" in caplog.records[0].message
+    assert "BAD" in caplog.records[1].message
+
+
+def test_fail_rollback_close(dsn, caplog, monkeypatch):
+    caplog.set_level(logging.WARNING, logger="psycopg3.pool")
+    p = pool.ConnectionPool(dsn, minconn=1)
+    conn = p.getconn()
+
+    # Make the rollback fail
+    orig_rollback = conn.rollback
+
+    def bad_rollback():
+        conn.pgconn.finish()
+        orig_rollback()
+
+    monkeypatch.setattr(conn, "rollback", bad_rollback)
+
+    pid = conn.pgconn.backend_pid
+    with pytest.raises(psycopg3.ProgrammingError):
+        conn.execute("wat")
+    assert conn.pgconn.transaction_status == TransactionStatus.INERROR
+    p.putconn(conn)
+
+    with p.connection() as conn2:
+        assert conn2.pgconn.backend_pid != pid
+        assert conn2.pgconn.transaction_status == TransactionStatus.IDLE
+
+    assert len(caplog.records) == 3
+    assert "INERROR" in caplog.records[0].message
+    assert "OperationalError" in caplog.records[1].message
+    assert "BAD" in caplog.records[2].message
+
+
+def test_putconn_no_pool(dsn):
+    p = pool.ConnectionPool(dsn, minconn=1)
+    conn = psycopg3.connect(dsn)
+    with pytest.raises(ValueError):
+        p.putconn(conn)
+
+
+def test_putconn_wrong_pool(dsn):
+    p1 = pool.ConnectionPool(dsn, minconn=1)
+    p2 = pool.ConnectionPool(dsn, minconn=1)
+    conn = p1.getconn()
+    with pytest.raises(ValueError):
+        p2.putconn(conn)