]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix: show all connection errors on attempt failure 1076/head
authorDaniel Frankcom <frankcom@amazon.com>
Wed, 7 May 2025 23:20:43 +0000 (16:20 -0700)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 8 May 2025 05:10:21 +0000 (07:10 +0200)
psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py
tests/test_connection.py
tests/test_connection_async.py

index 8ee5128ed8ecec0a6ea511e225f187c863aa5747..f2b3a748ed47d52bebe9c18346c0717c744ed876 100644 (file)
@@ -95,29 +95,43 @@ class Connection(BaseConnection[Row]):
         timeout = timeout_from_conninfo(params)
         rv = None
         attempts = conninfo_attempts(params)
+        connection_errors: list[tuple[e.Error, str]] = []
         for attempt in attempts:
             try:
                 conninfo = make_conninfo("", **attempt)
                 gen = cls._connect_gen(conninfo, timeout=timeout)
                 rv = waiting.wait_conn(gen, interval=_WAIT_INTERVAL)
             except e.Error as ex:
-                if len(attempts) > 1:
-                    logger.debug(
-                        "connection attempt failed: host: %r port: %r, hostaddr %r: %s",
-                        attempt.get("host"),
-                        attempt.get("port"),
-                        attempt.get("hostaddr"),
-                        str(ex),
-                    )
-                last_ex = ex
+                attempt_details = "host: {}, port: {}, hostaddr: {}".format(
+                    repr(attempt.get("host")),
+                    repr(attempt.get("port")),
+                    repr(attempt.get("hostaddr")),
+                )
+                connection_errors.append((ex, attempt_details))
             except e._NO_TRACEBACK as ex:
                 raise ex.with_traceback(None)
             else:
                 break
 
         if not rv:
-            assert last_ex
-            raise last_ex.with_traceback(None)
+            last_exception, _ = connection_errors[-1]
+            if len(connection_errors) == 1:
+                raise last_exception.with_traceback(None)
+            else:
+                formatted_attempts = []
+                for error, attempt_details in connection_errors:
+                    formatted_attempts.append(f"- {attempt_details}: {error}")
+                # Create a new exception with the same type as the last one, containing
+                # all attempt errors while preserving backward compatibility.
+                last_exception_type = type(last_exception)
+                message_lines = [
+                    f"{last_exception}",
+                    "Multiple connection attempts failed. All failures were:",
+                ]
+                message_lines.extend(formatted_attempts)
+                enhanced_message = "\n".join(message_lines)
+                enhanced_exception = last_exception_type(enhanced_message)
+                raise enhanced_exception.with_traceback(None)
 
         rv._autocommit = bool(autocommit)
         if row_factory:
index ddfccae1aff3aba7b4337c1d9d6c9df578e0f916..19334c28013807a15f0e07316ab83b3f148736c8 100644 (file)
@@ -111,29 +111,45 @@ class AsyncConnection(BaseConnection[Row]):
         timeout = timeout_from_conninfo(params)
         rv = None
         attempts = await conninfo_attempts_async(params)
+        connection_errors: list[tuple[e.Error, str]] = []
         for attempt in attempts:
             try:
                 conninfo = make_conninfo("", **attempt)
                 gen = cls._connect_gen(conninfo, timeout=timeout)
                 rv = await waiting.wait_conn_async(gen, interval=_WAIT_INTERVAL)
             except e.Error as ex:
-                if len(attempts) > 1:
-                    logger.debug(
-                        "connection attempt failed: host: %r port: %r, hostaddr %r: %s",
-                        attempt.get("host"),
-                        attempt.get("port"),
-                        attempt.get("hostaddr"),
-                        str(ex),
-                    )
-                last_ex = ex
+                attempt_details = "host: {}, port: {}, hostaddr: {}".format(
+                    repr(attempt.get("host")),
+                    repr(attempt.get("port")),
+                    repr(attempt.get("hostaddr")),
+                )
+                connection_errors.append((ex, attempt_details))
             except e._NO_TRACEBACK as ex:
                 raise ex.with_traceback(None)
             else:
                 break
 
         if not rv:
-            assert last_ex
-            raise last_ex.with_traceback(None)
+            last_exception, _ = connection_errors[-1]
+            if len(connection_errors) == 1:
+                raise last_exception.with_traceback(None)
+            else:
+                formatted_attempts = []
+                for error, attempt_details in connection_errors:
+                    formatted_attempts.append(f"- {attempt_details}: {error}")
+
+                # Create a new exception with the same type as the last one, containing
+                # all attempt errors while preserving backward compatibility.
+                last_exception_type = type(last_exception)
+
+                message_lines = [
+                    f"{last_exception}",
+                    "Multiple connection attempts failed. All failures were:",
+                ]
+                message_lines.extend(formatted_attempts)
+                enhanced_message = "\n".join(message_lines)
+                enhanced_exception = last_exception_type(enhanced_message)
+                raise enhanced_exception.with_traceback(None)
 
         rv._autocommit = bool(autocommit)
         if row_factory:
index 60d2083a13cd636e272c99b31fd75b06737685ba..c340205c7b01bacacc04d3e06ba4f248cc66ea07 100644 (file)
@@ -24,6 +24,8 @@ from ._test_connection import testctx  # noqa: F401  # fixture
 from ._test_connection import conninfo_params_timeout, tx_params, tx_params_isolation
 from ._test_connection import tx_values_map
 
+MULTI_FAILURE_MESSAGE = "Multiple connection attempts failed. All failures were:"
+
 
 def test_connect(conn_cls, dsn):
     conn = conn_cls.connect(dsn)
@@ -37,6 +39,45 @@ def test_connect_bad(conn_cls):
         conn_cls.connect("dbname=nosuchdb")
 
 
+@pytest.mark.slow
+def test_connect_error_single_host_original_message_preserved(conn_cls, proxy):
+    with proxy.deaf_listen():
+        with pytest.raises(psycopg.OperationalError) as e:
+            conn_cls.connect(proxy.client_dsn, connect_timeout=2)
+
+    msg = str(e)
+    assert "connection timeout expired" in msg
+    assert MULTI_FAILURE_MESSAGE not in msg
+
+
+@pytest.mark.slow
+def test_connect_error_multi_hosts_each_message_preserved(conn_cls):
+    # IPv4 address blocks reserved for documentation.
+    # https://datatracker.ietf.org/doc/rfc5737/
+    args = {"host": "192.0.2.1,198.51.100.1", "port": "1234,5678"}
+    with pytest.raises(psycopg.OperationalError) as e:
+        conn_cls.connect(**args, connect_timeout=2)
+
+    msg = str(e.value)
+    assert MULTI_FAILURE_MESSAGE in msg
+
+    host1, host2 = args["host"].split(",")
+    port1, port2 = args["port"].split(",")
+
+    msg_lines = msg.splitlines()
+
+    expected_host1 = f"host: '{host1}', port: '{port1}', hostaddr: '{host1}'"
+    expected_host2 = f"host: '{host2}', port: '{port2}', hostaddr: '{host2}'"
+    expected_error = "connection timeout expired"
+
+    assert any(
+        (expected_host1 in line and expected_error in line for line in msg_lines)
+    )
+    assert any(
+        (expected_host2 in line and expected_error in line for line in msg_lines)
+    )
+
+
 def test_connect_str_subclass(conn_cls, dsn):
 
     class MyString(str):
index c8da671da328e20d9a407507b4f7874c1df08ee3..ea7e6f77dfbbc2b4be3f2b54e67ab79d55c3cf72 100644 (file)
@@ -21,6 +21,8 @@ from ._test_connection import testctx  # noqa: F401  # fixture
 from ._test_connection import conninfo_params_timeout, tx_params, tx_params_isolation
 from ._test_connection import tx_values_map
 
+MULTI_FAILURE_MESSAGE = "Multiple connection attempts failed. All failures were:"
+
 
 async def test_connect(aconn_cls, dsn):
     conn = await aconn_cls.connect(dsn)
@@ -34,6 +36,44 @@ async def test_connect_bad(aconn_cls):
         await aconn_cls.connect("dbname=nosuchdb")
 
 
+@pytest.mark.slow
+async def test_connect_error_single_host_original_message_preserved(aconn_cls, proxy):
+    with proxy.deaf_listen():
+        with pytest.raises(psycopg.OperationalError) as e:
+            await aconn_cls.connect(proxy.client_dsn, connect_timeout=2)
+
+    msg = str(e)
+    assert "connection timeout expired" in msg
+    assert MULTI_FAILURE_MESSAGE not in msg
+
+
+@pytest.mark.slow
+async def test_connect_error_multi_hosts_each_message_preserved(aconn_cls):
+    args = {
+        # IPv4 address blocks reserved for documentation.
+        # https://datatracker.ietf.org/doc/rfc5737/
+        "host": "192.0.2.1,198.51.100.1",
+        "port": "1234,5678",
+    }
+    with pytest.raises(psycopg.OperationalError) as e:
+        await aconn_cls.connect(**args, connect_timeout=2)
+
+    msg = str(e.value)
+    assert MULTI_FAILURE_MESSAGE in msg
+
+    host1, host2 = args["host"].split(",")
+    port1, port2 = args["port"].split(",")
+
+    msg_lines = msg.splitlines()
+
+    expected_host1 = f"host: '{host1}', port: '{port1}', hostaddr: '{host1}'"
+    expected_host2 = f"host: '{host2}', port: '{port2}', hostaddr: '{host2}'"
+    expected_error = "connection timeout expired"
+
+    assert any(expected_host1 in line and expected_error in line for line in msg_lines)
+    assert any(expected_host2 in line and expected_error in line for line in msg_lines)
+
+
 async def test_connect_str_subclass(aconn_cls, dsn):
     class MyString(str):
         pass