]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor: generate conninfo attempts from async counterpart
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 6 Jan 2024 17:06:54 +0000 (18:06 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 6 Jan 2024 19:47:40 +0000 (20:47 +0100)
psycopg/psycopg/_conninfo_attempts.py
psycopg/psycopg/_conninfo_attempts_async.py
tests/test_conninfo_attempts.py
tools/async_to_sync.py

index 5262ab78575341dd67d3038e31376c53f0cd497e..4fc0f792a33ced351583efa27218a53f7c83603c 100644 (file)
@@ -1,3 +1,6 @@
+# WARNING: this file is auto-generated by 'async_to_sync.py'
+# from the original file '_conninfo_attempts_async.py'
+# DO NOT CHANGE! Change the original file instead.
 """
 Separate connection attempts from a connection string.
 """
@@ -14,6 +17,7 @@ from . import errors as e
 from ._conninfo_utils import ConnDict, get_param, is_ip_address, get_param_def
 from ._conninfo_utils import split_attempts
 
+
 logger = logging.getLogger("psycopg")
 
 
@@ -52,7 +56,7 @@ def conninfo_attempts(params: ConnDict) -> list[ConnDict]:
 
 def _resolve_hostnames(params: ConnDict) -> list[ConnDict]:
     """
-    Perform DNS lookup of the hosts and return a list of connection attempts.
+    Perform async DNS lookup of the hosts and return a list of connection attempts.
 
     If a ``host`` param is present but not ``hostname``, resolve the host
     addresses asynchronously.
@@ -87,4 +91,5 @@ def _resolve_hostnames(params: ConnDict) -> list[ConnDict]:
     ans = socket.getaddrinfo(
         host, port, proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM
     )
+
     return [{**params, "hostaddr": item[4][0]} for item in ans]
index 037f213ff78ce5966e0f11485f7b7e672ee5a5dd..6aca4ee3adbf3a6a402eca6f8d414098db18f537 100644 (file)
@@ -7,7 +7,6 @@ Separate connection attempts from a connection string.
 from __future__ import annotations
 
 import socket
-import asyncio
 import logging
 from random import shuffle
 
@@ -15,6 +14,9 @@ from . import errors as e
 from ._conninfo_utils import ConnDict, get_param, is_ip_address, get_param_def
 from ._conninfo_utils import split_attempts
 
+if True:  # ASYNC:
+    import asyncio
+
 logger = logging.getLogger("psycopg")
 
 
@@ -35,7 +37,7 @@ async def conninfo_attempts_async(params: ConnDict) -> list[ConnDict]:
     attempts = []
     for attempt in split_attempts(params):
         try:
-            attempts.extend(await _resolve_hostnames_async(attempt))
+            attempts.extend(await _resolve_hostnames(attempt))
         except OSError as ex:
             logger.debug("failed to resolve host %r: %s", attempt.get("host"), str(ex))
             last_exc = ex
@@ -51,7 +53,7 @@ async def conninfo_attempts_async(params: ConnDict) -> list[ConnDict]:
     return attempts
 
 
-async def _resolve_hostnames_async(params: ConnDict) -> list[ConnDict]:
+async def _resolve_hostnames(params: ConnDict) -> list[ConnDict]:
     """
     Perform async DNS lookup of the hosts and return a list of connection attempts.
 
@@ -85,8 +87,14 @@ async def _resolve_hostnames_async(params: ConnDict) -> list[ConnDict]:
         port_def = get_param_def("port")
         port = port_def and port_def.compiled or "5432"
 
-    loop = asyncio.get_running_loop()
-    ans = await loop.getaddrinfo(
-        host, port, proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM
-    )
+    if True:  # ASYNC:
+        loop = asyncio.get_running_loop()
+        ans = await loop.getaddrinfo(
+            host, port, proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM
+        )
+    else:
+        ans = socket.getaddrinfo(
+            host, port, proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM
+        )
+
     return [{**params, "hostaddr": item[4][0]} for item in ans]
index f7bd141d16ab8ab7ba71261dc86db2f4b990e9cd..c2855760ac88ec7f1603aa9b91ba6255909a2fa0 100644 (file)
@@ -1,8 +1,13 @@
+# WARNING: this file is auto-generated by 'async_to_sync.py'
+# from the original file 'test_conninfo_attempts_async.py'
+# DO NOT CHANGE! Change the original file instead.
 import pytest
 
 import psycopg
 from psycopg.conninfo import conninfo_to_dict, conninfo_attempts
 
+pytestmark = pytest.mark.anyio
+
 
 @pytest.mark.parametrize(
     "conninfo, want, env",
@@ -99,21 +104,13 @@ def test_conninfo_attempts_no_resolve(setpgenv, conninfo, want, env, fail_resolv
             ],
             None,
         ),
-        (
-            "host=foo.com,nosuchhost.com",
-            ["host=foo.com hostaddr=1.1.1.1"],
-            None,
-        ),
+        ("host=foo.com,nosuchhost.com", ["host=foo.com hostaddr=1.1.1.1"], None),
         (
             "host=foo.com, port=5432,5433",
             ["host=foo.com hostaddr=1.1.1.1 port=5432", "host='' port=5433"],
             None,
         ),
-        (
-            "host=nosuchhost.com,foo.com",
-            ["host=foo.com hostaddr=1.1.1.1"],
-            None,
-        ),
+        ("host=nosuchhost.com,foo.com", ["host=foo.com hostaddr=1.1.1.1"], None),
         (
             "host=foo.com,qux.com",
             ["host=foo.com hostaddr=1.1.1.1", "host=qux.com hostaddr=2.2.2.2"],
index ece9d9bad2c901afd2ae0c10d74f706b6d310049..d264c78bd20e5339eaa83516ff9f205f5ff1cb97 100755 (executable)
@@ -29,6 +29,7 @@ import ast_comments as ast
 PYVER = "3.11"
 
 ALL_INPUTS = """
+    psycopg/psycopg/_conninfo_attempts_async.py
     psycopg/psycopg/_copy_async.py
     psycopg/psycopg/connection_async.py
     psycopg/psycopg/cursor_async.py
@@ -40,6 +41,7 @@ ALL_INPUTS = """
     tests/pool/test_pool_null_async.py
     tests/pool/test_sched_async.py
     tests/test_connection_async.py
+    tests/test_conninfo_attempts_async.py
     tests/test_copy_async.py
     tests/test_cursor_async.py
     tests/test_cursor_client_async.py