From: Daniel Frankcom Date: Wed, 7 May 2025 23:20:43 +0000 (-0700) Subject: fix: show all connection errors on attempt failure X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=refs%2Fpull%2F1076%2Fhead;p=thirdparty%2Fpsycopg.git fix: show all connection errors on attempt failure --- diff --git a/psycopg/psycopg/connection.py b/psycopg/psycopg/connection.py index 8ee5128ed..f2b3a748e 100644 --- a/psycopg/psycopg/connection.py +++ b/psycopg/psycopg/connection.py @@ -95,29 +95,43 @@ class Connection(BaseConnection[Row]): timeout = timeout_from_conninfo(params) rv = None attempts = conninfo_attempts(params) + connection_errors: list[tuple[e.Error, str]] = [] for attempt in attempts: try: conninfo = make_conninfo("", **attempt) gen = cls._connect_gen(conninfo, timeout=timeout) rv = waiting.wait_conn(gen, interval=_WAIT_INTERVAL) except e.Error as ex: - if len(attempts) > 1: - logger.debug( - "connection attempt failed: host: %r port: %r, hostaddr %r: %s", - attempt.get("host"), - attempt.get("port"), - attempt.get("hostaddr"), - str(ex), - ) - last_ex = ex + attempt_details = "host: {}, port: {}, hostaddr: {}".format( + repr(attempt.get("host")), + repr(attempt.get("port")), + repr(attempt.get("hostaddr")), + ) + connection_errors.append((ex, attempt_details)) except e._NO_TRACEBACK as ex: raise ex.with_traceback(None) else: break if not rv: - assert last_ex - raise last_ex.with_traceback(None) + last_exception, _ = connection_errors[-1] + if len(connection_errors) == 1: + raise last_exception.with_traceback(None) + else: + formatted_attempts = [] + for error, attempt_details in connection_errors: + formatted_attempts.append(f"- {attempt_details}: {error}") + # Create a new exception with the same type as the last one, containing + # all attempt errors while preserving backward compatibility. + last_exception_type = type(last_exception) + message_lines = [ + f"{last_exception}", + "Multiple connection attempts failed. All failures were:", + ] + message_lines.extend(formatted_attempts) + enhanced_message = "\n".join(message_lines) + enhanced_exception = last_exception_type(enhanced_message) + raise enhanced_exception.with_traceback(None) rv._autocommit = bool(autocommit) if row_factory: diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py index ddfccae1a..19334c280 100644 --- a/psycopg/psycopg/connection_async.py +++ b/psycopg/psycopg/connection_async.py @@ -111,29 +111,45 @@ class AsyncConnection(BaseConnection[Row]): timeout = timeout_from_conninfo(params) rv = None attempts = await conninfo_attempts_async(params) + connection_errors: list[tuple[e.Error, str]] = [] for attempt in attempts: try: conninfo = make_conninfo("", **attempt) gen = cls._connect_gen(conninfo, timeout=timeout) rv = await waiting.wait_conn_async(gen, interval=_WAIT_INTERVAL) except e.Error as ex: - if len(attempts) > 1: - logger.debug( - "connection attempt failed: host: %r port: %r, hostaddr %r: %s", - attempt.get("host"), - attempt.get("port"), - attempt.get("hostaddr"), - str(ex), - ) - last_ex = ex + attempt_details = "host: {}, port: {}, hostaddr: {}".format( + repr(attempt.get("host")), + repr(attempt.get("port")), + repr(attempt.get("hostaddr")), + ) + connection_errors.append((ex, attempt_details)) except e._NO_TRACEBACK as ex: raise ex.with_traceback(None) else: break if not rv: - assert last_ex - raise last_ex.with_traceback(None) + last_exception, _ = connection_errors[-1] + if len(connection_errors) == 1: + raise last_exception.with_traceback(None) + else: + formatted_attempts = [] + for error, attempt_details in connection_errors: + formatted_attempts.append(f"- {attempt_details}: {error}") + + # Create a new exception with the same type as the last one, containing + # all attempt errors while preserving backward compatibility. + last_exception_type = type(last_exception) + + message_lines = [ + f"{last_exception}", + "Multiple connection attempts failed. All failures were:", + ] + message_lines.extend(formatted_attempts) + enhanced_message = "\n".join(message_lines) + enhanced_exception = last_exception_type(enhanced_message) + raise enhanced_exception.with_traceback(None) rv._autocommit = bool(autocommit) if row_factory: diff --git a/tests/test_connection.py b/tests/test_connection.py index 60d2083a1..c340205c7 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -24,6 +24,8 @@ from ._test_connection import testctx # noqa: F401 # fixture from ._test_connection import conninfo_params_timeout, tx_params, tx_params_isolation from ._test_connection import tx_values_map +MULTI_FAILURE_MESSAGE = "Multiple connection attempts failed. All failures were:" + def test_connect(conn_cls, dsn): conn = conn_cls.connect(dsn) @@ -37,6 +39,45 @@ def test_connect_bad(conn_cls): conn_cls.connect("dbname=nosuchdb") +@pytest.mark.slow +def test_connect_error_single_host_original_message_preserved(conn_cls, proxy): + with proxy.deaf_listen(): + with pytest.raises(psycopg.OperationalError) as e: + conn_cls.connect(proxy.client_dsn, connect_timeout=2) + + msg = str(e) + assert "connection timeout expired" in msg + assert MULTI_FAILURE_MESSAGE not in msg + + +@pytest.mark.slow +def test_connect_error_multi_hosts_each_message_preserved(conn_cls): + # IPv4 address blocks reserved for documentation. + # https://datatracker.ietf.org/doc/rfc5737/ + args = {"host": "192.0.2.1,198.51.100.1", "port": "1234,5678"} + with pytest.raises(psycopg.OperationalError) as e: + conn_cls.connect(**args, connect_timeout=2) + + msg = str(e.value) + assert MULTI_FAILURE_MESSAGE in msg + + host1, host2 = args["host"].split(",") + port1, port2 = args["port"].split(",") + + msg_lines = msg.splitlines() + + expected_host1 = f"host: '{host1}', port: '{port1}', hostaddr: '{host1}'" + expected_host2 = f"host: '{host2}', port: '{port2}', hostaddr: '{host2}'" + expected_error = "connection timeout expired" + + assert any( + (expected_host1 in line and expected_error in line for line in msg_lines) + ) + assert any( + (expected_host2 in line and expected_error in line for line in msg_lines) + ) + + def test_connect_str_subclass(conn_cls, dsn): class MyString(str): diff --git a/tests/test_connection_async.py b/tests/test_connection_async.py index c8da671da..ea7e6f77d 100644 --- a/tests/test_connection_async.py +++ b/tests/test_connection_async.py @@ -21,6 +21,8 @@ from ._test_connection import testctx # noqa: F401 # fixture from ._test_connection import conninfo_params_timeout, tx_params, tx_params_isolation from ._test_connection import tx_values_map +MULTI_FAILURE_MESSAGE = "Multiple connection attempts failed. All failures were:" + async def test_connect(aconn_cls, dsn): conn = await aconn_cls.connect(dsn) @@ -34,6 +36,44 @@ async def test_connect_bad(aconn_cls): await aconn_cls.connect("dbname=nosuchdb") +@pytest.mark.slow +async def test_connect_error_single_host_original_message_preserved(aconn_cls, proxy): + with proxy.deaf_listen(): + with pytest.raises(psycopg.OperationalError) as e: + await aconn_cls.connect(proxy.client_dsn, connect_timeout=2) + + msg = str(e) + assert "connection timeout expired" in msg + assert MULTI_FAILURE_MESSAGE not in msg + + +@pytest.mark.slow +async def test_connect_error_multi_hosts_each_message_preserved(aconn_cls): + args = { + # IPv4 address blocks reserved for documentation. + # https://datatracker.ietf.org/doc/rfc5737/ + "host": "192.0.2.1,198.51.100.1", + "port": "1234,5678", + } + with pytest.raises(psycopg.OperationalError) as e: + await aconn_cls.connect(**args, connect_timeout=2) + + msg = str(e.value) + assert MULTI_FAILURE_MESSAGE in msg + + host1, host2 = args["host"].split(",") + port1, port2 = args["port"].split(",") + + msg_lines = msg.splitlines() + + expected_host1 = f"host: '{host1}', port: '{port1}', hostaddr: '{host1}'" + expected_host2 = f"host: '{host2}', port: '{port2}', hostaddr: '{host2}'" + expected_error = "connection timeout expired" + + assert any(expected_host1 in line and expected_error in line for line in msg_lines) + assert any(expected_host2 in line and expected_error in line for line in msg_lines) + + async def test_connect_str_subclass(aconn_cls, dsn): class MyString(str): pass