]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor: return lists from conninfo attempt functions
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 13 Dec 2023 00:33:04 +0000 (01:33 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 13 Dec 2023 00:33:04 +0000 (01:33 +0100)
psycopg/psycopg/_dns.py
psycopg/psycopg/connection_async.py
psycopg/psycopg/conninfo.py
tests/test_conninfo.py

index 86f3468d2c36e2fdbe7f9c67c1ef557d94b1974b..a9619b56d09396a4418a0b9404f7714ae977e5dc 100644 (file)
@@ -52,7 +52,7 @@ async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]:
     hostaddrs: list[str] = []
     ports: list[str] = []
 
-    async for attempt in conninfo.conninfo_attempts_async(params):
+    for attempt in await conninfo.conninfo_attempts_async(params):
         if attempt.get("host") is not None:
             hosts.append(attempt["host"])
         if attempt.get("hostaddr") is not None:
index 6766a224532fd00bf5d8ece7255760419fc56157..5e628d614f37f98bce67c56b315543ade700f8fa 100644 (file)
@@ -120,7 +120,7 @@ class AsyncConnection(BaseConnection[Row]):
         params = await cls._get_connection_params(conninfo, **kwargs)
         timeout = int(params["connect_timeout"])
         rv = None
-        async for attempt in conninfo_attempts_async(params):
+        for attempt in await conninfo_attempts_async(params):
             try:
                 conninfo = make_conninfo(**attempt)
                 rv = await cls._wait_conn(cls._connect_gen(conninfo), timeout=timeout)
index ee01cd6aca978c7889d884d1ee5149c7dc927444..9351cc951051b67c2bdd59b0f4d0c99174d0e417 100644 (file)
@@ -10,7 +10,7 @@ import os
 import re
 import socket
 import asyncio
-from typing import Any, Iterator, AsyncIterator
+from typing import Any
 from random import shuffle
 from pathlib import Path
 from datetime import tzinfo
@@ -282,7 +282,7 @@ class ConnectionInfo:
         return value.decode(self.encoding)
 
 
-def conninfo_attempts(params: ConnDict) -> Iterator[ConnDict]:
+def conninfo_attempts(params: ConnDict) -> list[ConnDict]:
     """Split a set of connection params on the single attempts to perform.
 
     A connection param can perform more than one attempt more than one ``host``
@@ -298,10 +298,10 @@ def conninfo_attempts(params: ConnDict) -> Iterator[ConnDict]:
     attempts = _split_attempts(params)
     if params.get("load_balance_hosts", "disable") == "random":
         shuffle(attempts)
-    yield from attempts
+    return attempts
 
 
-async def conninfo_attempts_async(params: ConnDict) -> AsyncIterator[ConnDict]:
+async def conninfo_attempts_async(params: ConnDict) -> list[ConnDict]:
     """Split a set of connection params on the single attempts to perform.
 
     A connection param can perform more than one attempt more than one ``host``
@@ -331,8 +331,7 @@ async def conninfo_attempts_async(params: ConnDict) -> AsyncIterator[ConnDict]:
     if params.get("load_balance_hosts", "disable") == "random":
         shuffle(attempts)
 
-    for attempt in attempts:
-        yield attempt
+    return attempts
 
 
 def _split_attempts(params: ConnDict) -> list[ConnDict]:
index db2e38b8a892bd0d384af72d2120da35dc2d1dd3..48995dfdbf73f0785305a592300f8b5d0d55153e 100644 (file)
@@ -10,7 +10,6 @@ from psycopg.conninfo import make_conninfo, conninfo_to_dict, ConnectionInfo
 from psycopg.conninfo import conninfo_attempts, conninfo_attempts_async
 from psycopg._encodings import pg2pyenc
 
-from .utils import alist
 from .fix_crdb import crdb_encoding
 
 snowman = "\u2603"
@@ -349,7 +348,7 @@ class TestConnectionInfo:
 def test_conninfo_attempts(setpgenv, conninfo, want, env):
     setpgenv(env)
     params = conninfo_to_dict(conninfo)
-    attempts = list(conninfo_attempts(params))
+    attempts = conninfo_attempts(params)
     want = list(map(conninfo_to_dict, want))
     assert want == attempts
 
@@ -401,7 +400,7 @@ async def test_conninfo_attempts_async_no_resolve(
 ):
     setpgenv(env)
     params = conninfo_to_dict(conninfo)
-    attempts = await alist(conninfo_attempts_async(params))
+    attempts = await conninfo_attempts_async(params)
     want = list(map(conninfo_to_dict, want))
     assert want == attempts
 
@@ -460,7 +459,7 @@ async def test_conninfo_attempts_async_no_resolve(
 @pytest.mark.anyio
 async def test_conninfo_attempts_async(conninfo, want, env, fake_resolve):
     params = conninfo_to_dict(conninfo)
-    attempts = await alist(conninfo_attempts_async(params))
+    attempts = await conninfo_attempts_async(params)
     want = list(map(conninfo_to_dict, want))
     assert want == attempts
 
@@ -479,7 +478,7 @@ async def test_conninfo_attempts_async_bad(setpgenv, conninfo, env, fake_resolve
     setpgenv(env)
     params = conninfo_to_dict(conninfo)
     with pytest.raises(psycopg.Error):
-        await alist(conninfo_attempts_async(params))
+        await conninfo_attempts_async(params)
 
 
 @pytest.mark.parametrize(
@@ -495,7 +494,7 @@ def test_conninfo_attempts_bad(setpgenv, conninfo, env):
     setpgenv(env)
     params = conninfo_to_dict(conninfo)
     with pytest.raises(psycopg.Error):
-        list(conninfo_attempts(params))
+        conninfo_attempts(params)
 
 
 def test_conninfo_random():
@@ -518,16 +517,16 @@ def test_conninfo_random():
 @pytest.mark.anyio
 async def test_conninfo_random_async(fake_resolve):
     args = {"host": "alot.com"}
-    hostaddrs = [att["hostaddr"] async for att in conninfo_attempts_async(args)]
+    hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)]
     assert len(hostaddrs) == 20
     assert hostaddrs == sorted(hostaddrs)
 
     args["load_balance_hosts"] = "disable"
-    hostaddrs = [att["hostaddr"] async for att in conninfo_attempts_async(args)]
+    hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)]
     assert hostaddrs == sorted(hostaddrs)
 
     args["load_balance_hosts"] = "random"
-    hostaddrs = [att["hostaddr"] async for att in conninfo_attempts_async(args)]
+    hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)]
     assert hostaddrs != sorted(hostaddrs)