]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
ensure unix path syntax works for asyncpg as well
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 6 Jul 2023 22:01:25 +0000 (18:01 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 6 Jul 2023 22:01:25 +0000 (18:01 -0400)
update for a2c06a2a0acf769060f11bb34c1b55cecae5f5fe

updates test suite to include direct expected data / errors in
the test data

Fixes: #10069
Change-Id: I1e689101b90b7469608b74ed37abd7c2122151a4

lib/sqlalchemy/dialects/postgresql/asyncpg.py
test/dialect/postgresql/test_dialect.py

index e3713c6d6e9f9b87e0a6a912be02eda65065427a..d4350cc2892dc478f19bba368bb9c3974ada7c4f 100644 (file)
@@ -1089,7 +1089,11 @@ class PGDialect_asyncpg(PGDialect):
 
         if multihosts:
             assert multiports
-            if not all(multihosts):
+            if len(multihosts) == 1:
+                opts["host"] = multihosts[0]
+                if multiports[0] is not None:
+                    opts["port"] = multiports[0]
+            elif not all(multihosts):
                 raise exc.ArgumentError(
                     "All hosts are required to be present"
                     " for asyncpg multiple host URL"
@@ -1099,8 +1103,9 @@ class PGDialect_asyncpg(PGDialect):
                     "All ports are required to be present"
                     " for asyncpg multiple host URL"
                 )
-            opts["host"] = list(multihosts)
-            opts["port"] = list(multiports)
+            else:
+                opts["host"] = list(multihosts)
+                opts["port"] = list(multiports)
         else:
             util.coerce_kw_type(opts, "port", int)
         util.coerce_kw_type(opts, "prepared_statement_cache_size", int)
index 31335a84c0c2fa07f2d0e9555f7582580f8c33c2..a55fc0a6bbdc8b1681ad50c555a0387b69e19077 100644 (file)
@@ -2,7 +2,6 @@ import dataclasses
 import datetime
 import logging
 import logging.handlers
-import re
 
 from sqlalchemy import BigInteger
 from sqlalchemy import bindparam
@@ -317,6 +316,7 @@ class MultiHostConnectTest(fixtures.TestBase):
                     "password": "PASS",
                     "host": "192.168.1.50",
                     "port": "5678",
+                    "asyncpg_port": 5678,
                 },
             ),
             (
@@ -345,6 +345,7 @@ class MultiHostConnectTest(fixtures.TestBase):
                     "password": "PASS",
                     "host": "HOSTNAME",
                     "port": "1234",
+                    "asyncpg_port": 1234,
                 },
             ),
             (
@@ -384,6 +385,7 @@ class MultiHostConnectTest(fixtures.TestBase):
                     "password": "PASS",
                     "host": "hostA",
                     "port": "1234",
+                    "asyncpg_port": 1234,
                 },
             ),
             (
@@ -395,6 +397,8 @@ class MultiHostConnectTest(fixtures.TestBase):
                     "password": "PASS",
                     "host": "hostA,hostB,hostC",
                     "port": ",,",
+                    "asyncpg_error": "All ports are required to be present"
+                    " for asyncpg multiple host URL",
                 },
             ),
             (
@@ -406,6 +410,8 @@ class MultiHostConnectTest(fixtures.TestBase):
                     "password": "PASS",
                     "host": "hostA,hostB,hostC",
                     "port": ",222,333",
+                    "asyncpg_error": "All ports are required to be present"
+                    " for asyncpg multiple host URL",
                 },
             ),
             (
@@ -417,22 +423,39 @@ class MultiHostConnectTest(fixtures.TestBase):
                     "password": "PASS",
                     "host": "hostA,hostB,hostC",
                     "port": "111,222,333",
+                    "asyncpg_host": ["hostA", "hostB", "hostC"],
+                    "asyncpg_port": [111, 222, 333],
                 },
             ),
             (
                 "postgresql+psycopg2:///"
                 "?host=hostA:111&host=hostB:222&host=hostC:333",
-                {"host": "hostA,hostB,hostC", "port": "111,222,333"},
+                {
+                    "host": "hostA,hostB,hostC",
+                    "port": "111,222,333",
+                    "asyncpg_host": ["hostA", "hostB", "hostC"],
+                    "asyncpg_port": [111, 222, 333],
+                },
             ),
             (
                 "postgresql+psycopg2:///"
                 "?host=hostA:111&host=hostB:222&host=hostC:333",
-                {"host": "hostA,hostB,hostC", "port": "111,222,333"},
+                {
+                    "host": "hostA,hostB,hostC",
+                    "port": "111,222,333",
+                    "asyncpg_host": ["hostA", "hostB", "hostC"],
+                    "asyncpg_port": [111, 222, 333],
+                },
             ),
             (
                 "postgresql+psycopg2:///"
                 "?host=hostA,hostB,hostC&port=111,222,333",
-                {"host": "hostA,hostB,hostC", "port": "111,222,333"},
+                {
+                    "host": "hostA,hostB,hostC",
+                    "port": "111,222,333",
+                    "asyncpg_host": ["hostA", "hostB", "hostC"],
+                    "asyncpg_port": [111, 222, 333],
+                },
             ),
             (
                 "postgresql+asyncpg://USER:PASS@/DB"
@@ -443,20 +466,29 @@ class MultiHostConnectTest(fixtures.TestBase):
                     "dbname": "DB",
                     "user": "USER",
                     "password": "PASS",
+                    "asyncpg_error": "All hosts are required to be present"
+                    " for asyncpg multiple host URL",
                 },
             ),
         ]
         for url_string, expected_psycopg in psycopg_combinations:
+            asyncpg_error = expected_psycopg.pop("asyncpg_error", False)
+            asyncpg_host = expected_psycopg.pop("asyncpg_host", False)
+            asyncpg_port = expected_psycopg.pop("asyncpg_port", False)
+
             expected_asyncpg = dict(expected_psycopg)
+
             if "dbname" in expected_asyncpg:
                 expected_asyncpg["database"] = expected_asyncpg.pop("dbname")
-            if "host" in expected_asyncpg:
-                expected_asyncpg["host"] = expected_asyncpg["host"].split(",")
-            if "port" in expected_asyncpg:
-                expected_asyncpg["port"] = [
-                    int(p) if re.match(r"^\d+$", p) else None
-                    for p in expected_psycopg["port"].split(",")
-                ]
+
+            if asyncpg_error:
+                expected_asyncpg["error"] = asyncpg_error
+            if asyncpg_host is not False:
+                expected_asyncpg["host"] = asyncpg_host
+
+            if asyncpg_port is not False:
+                expected_asyncpg["port"] = asyncpg_port
+
             yield url_string, expected_psycopg, expected_asyncpg
 
     @testing.combinations_list(
@@ -477,32 +509,13 @@ class MultiHostConnectTest(fixtures.TestBase):
         u = url.make_url(url_string)
 
         if dialect.driver == "asyncpg":
-            if (
-                "port" in expected_asyncpg
-                and not all(expected_asyncpg["port"])
-                or (
-                    "host" in expected_asyncpg
-                    and isinstance(expected_asyncpg["host"], list)
-                    and "port" not in expected_asyncpg
-                )
-            ):
+            if "error" in expected_asyncpg:
                 with expect_raises_message(
-                    exc.ArgumentError,
-                    "All ports are required to be present"
-                    " for asyncpg multiple host URL",
-                ):
-                    dialect.create_connect_args(u)
-                return
-            elif "host" in expected_asyncpg and not all(
-                expected_asyncpg["host"]
-            ):
-                with expect_raises_message(
-                    exc.ArgumentError,
-                    "All hosts are required to be present"
-                    " for asyncpg multiple host URL",
+                    exc.ArgumentError, expected_asyncpg["error"]
                 ):
                     dialect.create_connect_args(u)
                 return
+
             expected = expected_asyncpg
         else:
             expected = expected_psycopg