From: Mike Bayer Date: Thu, 6 Jul 2023 22:01:25 +0000 (-0400) Subject: ensure unix path syntax works for asyncpg as well X-Git-Tag: rel_2_0_19~16 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=400aa8a676ba1a0a1536ae52a20caa93726525dd;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git ensure unix path syntax works for asyncpg as well update for a2c06a2a0acf769060f11bb34c1b55cecae5f5fe updates test suite to include direct expected data / errors in the test data Fixes: #10069 Change-Id: I1e689101b90b7469608b74ed37abd7c2122151a4 --- diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index e3713c6d6e..d4350cc289 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -1089,7 +1089,11 @@ class PGDialect_asyncpg(PGDialect): if multihosts: assert multiports - if not all(multihosts): + if len(multihosts) == 1: + opts["host"] = multihosts[0] + if multiports[0] is not None: + opts["port"] = multiports[0] + elif not all(multihosts): raise exc.ArgumentError( "All hosts are required to be present" " for asyncpg multiple host URL" @@ -1099,8 +1103,9 @@ class PGDialect_asyncpg(PGDialect): "All ports are required to be present" " for asyncpg multiple host URL" ) - opts["host"] = list(multihosts) - opts["port"] = list(multiports) + else: + opts["host"] = list(multihosts) + opts["port"] = list(multiports) else: util.coerce_kw_type(opts, "port", int) util.coerce_kw_type(opts, "prepared_statement_cache_size", int) diff --git a/test/dialect/postgresql/test_dialect.py b/test/dialect/postgresql/test_dialect.py index 31335a84c0..a55fc0a6bb 100644 --- a/test/dialect/postgresql/test_dialect.py +++ b/test/dialect/postgresql/test_dialect.py @@ -2,7 +2,6 @@ import dataclasses import datetime import logging import logging.handlers -import re from sqlalchemy import BigInteger from sqlalchemy import bindparam @@ -317,6 +316,7 @@ class MultiHostConnectTest(fixtures.TestBase): "password": "PASS", "host": "192.168.1.50", "port": "5678", + "asyncpg_port": 5678, }, ), ( @@ -345,6 +345,7 @@ class MultiHostConnectTest(fixtures.TestBase): "password": "PASS", "host": "HOSTNAME", "port": "1234", + "asyncpg_port": 1234, }, ), ( @@ -384,6 +385,7 @@ class MultiHostConnectTest(fixtures.TestBase): "password": "PASS", "host": "hostA", "port": "1234", + "asyncpg_port": 1234, }, ), ( @@ -395,6 +397,8 @@ class MultiHostConnectTest(fixtures.TestBase): "password": "PASS", "host": "hostA,hostB,hostC", "port": ",,", + "asyncpg_error": "All ports are required to be present" + " for asyncpg multiple host URL", }, ), ( @@ -406,6 +410,8 @@ class MultiHostConnectTest(fixtures.TestBase): "password": "PASS", "host": "hostA,hostB,hostC", "port": ",222,333", + "asyncpg_error": "All ports are required to be present" + " for asyncpg multiple host URL", }, ), ( @@ -417,22 +423,39 @@ class MultiHostConnectTest(fixtures.TestBase): "password": "PASS", "host": "hostA,hostB,hostC", "port": "111,222,333", + "asyncpg_host": ["hostA", "hostB", "hostC"], + "asyncpg_port": [111, 222, 333], }, ), ( "postgresql+psycopg2:///" "?host=hostA:111&host=hostB:222&host=hostC:333", - {"host": "hostA,hostB,hostC", "port": "111,222,333"}, + { + "host": "hostA,hostB,hostC", + "port": "111,222,333", + "asyncpg_host": ["hostA", "hostB", "hostC"], + "asyncpg_port": [111, 222, 333], + }, ), ( "postgresql+psycopg2:///" "?host=hostA:111&host=hostB:222&host=hostC:333", - {"host": "hostA,hostB,hostC", "port": "111,222,333"}, + { + "host": "hostA,hostB,hostC", + "port": "111,222,333", + "asyncpg_host": ["hostA", "hostB", "hostC"], + "asyncpg_port": [111, 222, 333], + }, ), ( "postgresql+psycopg2:///" "?host=hostA,hostB,hostC&port=111,222,333", - {"host": "hostA,hostB,hostC", "port": "111,222,333"}, + { + "host": "hostA,hostB,hostC", + "port": "111,222,333", + "asyncpg_host": ["hostA", "hostB", "hostC"], + "asyncpg_port": [111, 222, 333], + }, ), ( "postgresql+asyncpg://USER:PASS@/DB" @@ -443,20 +466,29 @@ class MultiHostConnectTest(fixtures.TestBase): "dbname": "DB", "user": "USER", "password": "PASS", + "asyncpg_error": "All hosts are required to be present" + " for asyncpg multiple host URL", }, ), ] for url_string, expected_psycopg in psycopg_combinations: + asyncpg_error = expected_psycopg.pop("asyncpg_error", False) + asyncpg_host = expected_psycopg.pop("asyncpg_host", False) + asyncpg_port = expected_psycopg.pop("asyncpg_port", False) + expected_asyncpg = dict(expected_psycopg) + if "dbname" in expected_asyncpg: expected_asyncpg["database"] = expected_asyncpg.pop("dbname") - if "host" in expected_asyncpg: - expected_asyncpg["host"] = expected_asyncpg["host"].split(",") - if "port" in expected_asyncpg: - expected_asyncpg["port"] = [ - int(p) if re.match(r"^\d+$", p) else None - for p in expected_psycopg["port"].split(",") - ] + + if asyncpg_error: + expected_asyncpg["error"] = asyncpg_error + if asyncpg_host is not False: + expected_asyncpg["host"] = asyncpg_host + + if asyncpg_port is not False: + expected_asyncpg["port"] = asyncpg_port + yield url_string, expected_psycopg, expected_asyncpg @testing.combinations_list( @@ -477,32 +509,13 @@ class MultiHostConnectTest(fixtures.TestBase): u = url.make_url(url_string) if dialect.driver == "asyncpg": - if ( - "port" in expected_asyncpg - and not all(expected_asyncpg["port"]) - or ( - "host" in expected_asyncpg - and isinstance(expected_asyncpg["host"], list) - and "port" not in expected_asyncpg - ) - ): + if "error" in expected_asyncpg: with expect_raises_message( - exc.ArgumentError, - "All ports are required to be present" - " for asyncpg multiple host URL", - ): - dialect.create_connect_args(u) - return - elif "host" in expected_asyncpg and not all( - expected_asyncpg["host"] - ): - with expect_raises_message( - exc.ArgumentError, - "All hosts are required to be present" - " for asyncpg multiple host URL", + exc.ArgumentError, expected_asyncpg["error"] ): dialect.create_connect_args(u) return + expected = expected_asyncpg else: expected = expected_psycopg