timeout = timeout_from_conninfo(params)
rv = None
attempts = conninfo_attempts(params)
+ connection_errors: list[tuple[e.Error, str]] = []
for attempt in attempts:
try:
conninfo = make_conninfo("", **attempt)
gen = cls._connect_gen(conninfo, timeout=timeout)
rv = waiting.wait_conn(gen, interval=_WAIT_INTERVAL)
except e.Error as ex:
- if len(attempts) > 1:
- logger.debug(
- "connection attempt failed: host: %r port: %r, hostaddr %r: %s",
- attempt.get("host"),
- attempt.get("port"),
- attempt.get("hostaddr"),
- str(ex),
- )
- last_ex = ex
+ attempt_details = "host: {}, port: {}, hostaddr: {}".format(
+ repr(attempt.get("host")),
+ repr(attempt.get("port")),
+ repr(attempt.get("hostaddr")),
+ )
+ connection_errors.append((ex, attempt_details))
except e._NO_TRACEBACK as ex:
raise ex.with_traceback(None)
else:
break
if not rv:
- assert last_ex
- raise last_ex.with_traceback(None)
+ last_exception, _ = connection_errors[-1]
+ if len(connection_errors) == 1:
+ raise last_exception.with_traceback(None)
+ else:
+ formatted_attempts = []
+ for error, attempt_details in connection_errors:
+ formatted_attempts.append(f"- {attempt_details}: {error}")
+ # Create a new exception with the same type as the last one, containing
+ # all attempt errors while preserving backward compatibility.
+ last_exception_type = type(last_exception)
+ message_lines = [
+ f"{last_exception}",
+ "Multiple connection attempts failed. All failures were:",
+ ]
+ message_lines.extend(formatted_attempts)
+ enhanced_message = "\n".join(message_lines)
+ enhanced_exception = last_exception_type(enhanced_message)
+ raise enhanced_exception.with_traceback(None)
rv._autocommit = bool(autocommit)
if row_factory:
timeout = timeout_from_conninfo(params)
rv = None
attempts = await conninfo_attempts_async(params)
+ connection_errors: list[tuple[e.Error, str]] = []
for attempt in attempts:
try:
conninfo = make_conninfo("", **attempt)
gen = cls._connect_gen(conninfo, timeout=timeout)
rv = await waiting.wait_conn_async(gen, interval=_WAIT_INTERVAL)
except e.Error as ex:
- if len(attempts) > 1:
- logger.debug(
- "connection attempt failed: host: %r port: %r, hostaddr %r: %s",
- attempt.get("host"),
- attempt.get("port"),
- attempt.get("hostaddr"),
- str(ex),
- )
- last_ex = ex
+ attempt_details = "host: {}, port: {}, hostaddr: {}".format(
+ repr(attempt.get("host")),
+ repr(attempt.get("port")),
+ repr(attempt.get("hostaddr")),
+ )
+ connection_errors.append((ex, attempt_details))
except e._NO_TRACEBACK as ex:
raise ex.with_traceback(None)
else:
break
if not rv:
- assert last_ex
- raise last_ex.with_traceback(None)
+ last_exception, _ = connection_errors[-1]
+ if len(connection_errors) == 1:
+ raise last_exception.with_traceback(None)
+ else:
+ formatted_attempts = []
+ for error, attempt_details in connection_errors:
+ formatted_attempts.append(f"- {attempt_details}: {error}")
+
+ # Create a new exception with the same type as the last one, containing
+ # all attempt errors while preserving backward compatibility.
+ last_exception_type = type(last_exception)
+
+ message_lines = [
+ f"{last_exception}",
+ "Multiple connection attempts failed. All failures were:",
+ ]
+ message_lines.extend(formatted_attempts)
+ enhanced_message = "\n".join(message_lines)
+ enhanced_exception = last_exception_type(enhanced_message)
+ raise enhanced_exception.with_traceback(None)
rv._autocommit = bool(autocommit)
if row_factory:
from ._test_connection import conninfo_params_timeout, tx_params, tx_params_isolation
from ._test_connection import tx_values_map
+MULTI_FAILURE_MESSAGE = "Multiple connection attempts failed. All failures were:"
+
def test_connect(conn_cls, dsn):
conn = conn_cls.connect(dsn)
conn_cls.connect("dbname=nosuchdb")
+@pytest.mark.slow
+def test_connect_error_single_host_original_message_preserved(conn_cls, proxy):
+ with proxy.deaf_listen():
+ with pytest.raises(psycopg.OperationalError) as e:
+ conn_cls.connect(proxy.client_dsn, connect_timeout=2)
+
+ msg = str(e)
+ assert "connection timeout expired" in msg
+ assert MULTI_FAILURE_MESSAGE not in msg
+
+
+@pytest.mark.slow
+def test_connect_error_multi_hosts_each_message_preserved(conn_cls):
+ # IPv4 address blocks reserved for documentation.
+ # https://datatracker.ietf.org/doc/rfc5737/
+ args = {"host": "192.0.2.1,198.51.100.1", "port": "1234,5678"}
+ with pytest.raises(psycopg.OperationalError) as e:
+ conn_cls.connect(**args, connect_timeout=2)
+
+ msg = str(e.value)
+ assert MULTI_FAILURE_MESSAGE in msg
+
+ host1, host2 = args["host"].split(",")
+ port1, port2 = args["port"].split(",")
+
+ msg_lines = msg.splitlines()
+
+ expected_host1 = f"host: '{host1}', port: '{port1}', hostaddr: '{host1}'"
+ expected_host2 = f"host: '{host2}', port: '{port2}', hostaddr: '{host2}'"
+ expected_error = "connection timeout expired"
+
+ assert any(
+ (expected_host1 in line and expected_error in line for line in msg_lines)
+ )
+ assert any(
+ (expected_host2 in line and expected_error in line for line in msg_lines)
+ )
+
+
def test_connect_str_subclass(conn_cls, dsn):
class MyString(str):
from ._test_connection import conninfo_params_timeout, tx_params, tx_params_isolation
from ._test_connection import tx_values_map
+MULTI_FAILURE_MESSAGE = "Multiple connection attempts failed. All failures were:"
+
async def test_connect(aconn_cls, dsn):
conn = await aconn_cls.connect(dsn)
await aconn_cls.connect("dbname=nosuchdb")
+@pytest.mark.slow
+async def test_connect_error_single_host_original_message_preserved(aconn_cls, proxy):
+ with proxy.deaf_listen():
+ with pytest.raises(psycopg.OperationalError) as e:
+ await aconn_cls.connect(proxy.client_dsn, connect_timeout=2)
+
+ msg = str(e)
+ assert "connection timeout expired" in msg
+ assert MULTI_FAILURE_MESSAGE not in msg
+
+
+@pytest.mark.slow
+async def test_connect_error_multi_hosts_each_message_preserved(aconn_cls):
+ args = {
+ # IPv4 address blocks reserved for documentation.
+ # https://datatracker.ietf.org/doc/rfc5737/
+ "host": "192.0.2.1,198.51.100.1",
+ "port": "1234,5678",
+ }
+ with pytest.raises(psycopg.OperationalError) as e:
+ await aconn_cls.connect(**args, connect_timeout=2)
+
+ msg = str(e.value)
+ assert MULTI_FAILURE_MESSAGE in msg
+
+ host1, host2 = args["host"].split(",")
+ port1, port2 = args["port"].split(",")
+
+ msg_lines = msg.splitlines()
+
+ expected_host1 = f"host: '{host1}', port: '{port1}', hostaddr: '{host1}'"
+ expected_host2 = f"host: '{host2}', port: '{port2}', hostaddr: '{host2}'"
+ expected_error = "connection timeout expired"
+
+ assert any(expected_host1 in line and expected_error in line for line in msg_lines)
+ assert any(expected_host2 in line and expected_error in line for line in msg_lines)
+
+
async def test_connect_str_subclass(aconn_cls, dsn):
class MyString(str):
pass