From: Daniele Varrazzo Date: Wed, 13 Dec 2023 00:33:04 +0000 (+0100) Subject: refactor: return lists from conninfo attempt functions X-Git-Tag: 3.1.15~1^2~5 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=8dd0204c7354762374bffe22356999b8c2acbdc9;p=thirdparty%2Fpsycopg.git refactor: return lists from conninfo attempt functions --- diff --git a/psycopg/psycopg/_dns.py b/psycopg/psycopg/_dns.py index 86f3468d2..a9619b56d 100644 --- a/psycopg/psycopg/_dns.py +++ b/psycopg/psycopg/_dns.py @@ -52,7 +52,7 @@ async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]: hostaddrs: list[str] = [] ports: list[str] = [] - async for attempt in conninfo.conninfo_attempts_async(params): + for attempt in await conninfo.conninfo_attempts_async(params): if attempt.get("host") is not None: hosts.append(attempt["host"]) if attempt.get("hostaddr") is not None: diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py index 6766a2245..5e628d614 100644 --- a/psycopg/psycopg/connection_async.py +++ b/psycopg/psycopg/connection_async.py @@ -120,7 +120,7 @@ class AsyncConnection(BaseConnection[Row]): params = await cls._get_connection_params(conninfo, **kwargs) timeout = int(params["connect_timeout"]) rv = None - async for attempt in conninfo_attempts_async(params): + for attempt in await conninfo_attempts_async(params): try: conninfo = make_conninfo(**attempt) rv = await cls._wait_conn(cls._connect_gen(conninfo), timeout=timeout) diff --git a/psycopg/psycopg/conninfo.py b/psycopg/psycopg/conninfo.py index ee01cd6ac..9351cc951 100644 --- a/psycopg/psycopg/conninfo.py +++ b/psycopg/psycopg/conninfo.py @@ -10,7 +10,7 @@ import os import re import socket import asyncio -from typing import Any, Iterator, AsyncIterator +from typing import Any from random import shuffle from pathlib import Path from datetime import tzinfo @@ -282,7 +282,7 @@ class ConnectionInfo: return value.decode(self.encoding) -def conninfo_attempts(params: ConnDict) -> Iterator[ConnDict]: +def conninfo_attempts(params: ConnDict) -> list[ConnDict]: """Split a set of connection params on the single attempts to perform. A connection param can perform more than one attempt more than one ``host`` @@ -298,10 +298,10 @@ def conninfo_attempts(params: ConnDict) -> Iterator[ConnDict]: attempts = _split_attempts(params) if params.get("load_balance_hosts", "disable") == "random": shuffle(attempts) - yield from attempts + return attempts -async def conninfo_attempts_async(params: ConnDict) -> AsyncIterator[ConnDict]: +async def conninfo_attempts_async(params: ConnDict) -> list[ConnDict]: """Split a set of connection params on the single attempts to perform. A connection param can perform more than one attempt more than one ``host`` @@ -331,8 +331,7 @@ async def conninfo_attempts_async(params: ConnDict) -> AsyncIterator[ConnDict]: if params.get("load_balance_hosts", "disable") == "random": shuffle(attempts) - for attempt in attempts: - yield attempt + return attempts def _split_attempts(params: ConnDict) -> list[ConnDict]: diff --git a/tests/test_conninfo.py b/tests/test_conninfo.py index db2e38b8a..48995dfdb 100644 --- a/tests/test_conninfo.py +++ b/tests/test_conninfo.py @@ -10,7 +10,6 @@ from psycopg.conninfo import make_conninfo, conninfo_to_dict, ConnectionInfo from psycopg.conninfo import conninfo_attempts, conninfo_attempts_async from psycopg._encodings import pg2pyenc -from .utils import alist from .fix_crdb import crdb_encoding snowman = "\u2603" @@ -349,7 +348,7 @@ class TestConnectionInfo: def test_conninfo_attempts(setpgenv, conninfo, want, env): setpgenv(env) params = conninfo_to_dict(conninfo) - attempts = list(conninfo_attempts(params)) + attempts = conninfo_attempts(params) want = list(map(conninfo_to_dict, want)) assert want == attempts @@ -401,7 +400,7 @@ async def test_conninfo_attempts_async_no_resolve( ): setpgenv(env) params = conninfo_to_dict(conninfo) - attempts = await alist(conninfo_attempts_async(params)) + attempts = await conninfo_attempts_async(params) want = list(map(conninfo_to_dict, want)) assert want == attempts @@ -460,7 +459,7 @@ async def test_conninfo_attempts_async_no_resolve( @pytest.mark.anyio async def test_conninfo_attempts_async(conninfo, want, env, fake_resolve): params = conninfo_to_dict(conninfo) - attempts = await alist(conninfo_attempts_async(params)) + attempts = await conninfo_attempts_async(params) want = list(map(conninfo_to_dict, want)) assert want == attempts @@ -479,7 +478,7 @@ async def test_conninfo_attempts_async_bad(setpgenv, conninfo, env, fake_resolve setpgenv(env) params = conninfo_to_dict(conninfo) with pytest.raises(psycopg.Error): - await alist(conninfo_attempts_async(params)) + await conninfo_attempts_async(params) @pytest.mark.parametrize( @@ -495,7 +494,7 @@ def test_conninfo_attempts_bad(setpgenv, conninfo, env): setpgenv(env) params = conninfo_to_dict(conninfo) with pytest.raises(psycopg.Error): - list(conninfo_attempts(params)) + conninfo_attempts(params) def test_conninfo_random(): @@ -518,16 +517,16 @@ def test_conninfo_random(): @pytest.mark.anyio async def test_conninfo_random_async(fake_resolve): args = {"host": "alot.com"} - hostaddrs = [att["hostaddr"] async for att in conninfo_attempts_async(args)] + hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)] assert len(hostaddrs) == 20 assert hostaddrs == sorted(hostaddrs) args["load_balance_hosts"] = "disable" - hostaddrs = [att["hostaddr"] async for att in conninfo_attempts_async(args)] + hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)] assert hostaddrs == sorted(hostaddrs) args["load_balance_hosts"] = "random" - hostaddrs = [att["hostaddr"] async for att in conninfo_attempts_async(args)] + hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)] assert hostaddrs != sorted(hostaddrs)