From 94bba62774377fd1654296c7ca4fe5114f75fcf5 Mon Sep 17 00:00:00 2001 From: Ilia Dmitriev Date: Thu, 22 Jun 2023 18:02:26 +0300 Subject: [PATCH] Feature asyncpg dialect doesn't support multihost connection string + moved base postgres dialect create arguments logic to `PGDialect._split_multihost_from_url` + added asyncpg specific logic to `PGDialect_asyncpg.create_connect_args` + test case when one of hosts or ports is None + error message for case when one of hosts or ports is None + intelligible error message when one of ports is not an integer --- .../dialects/postgresql/_psycopg_common.py | 25 +-- lib/sqlalchemy/dialects/postgresql/asyncpg.py | 23 ++- lib/sqlalchemy/dialects/postgresql/base.py | 48 ++++++ test/dialect/postgresql/test_dialect.py | 153 ++++++++++++++++-- 4 files changed, 217 insertions(+), 32 deletions(-) diff --git a/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py b/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py index b985180994..dfb25a5689 100644 --- a/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py +++ b/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py @@ -131,9 +131,7 @@ class _PGDialect_common_psycopg(PGDialect): def create_connect_args(self, url): opts = url.translate_connect_args(username="user", database="dbname") - is_multihost = False - if "host" in url.query: - is_multihost = isinstance(url.query["host"], (list, tuple)) + multihosts, multiports = self._split_multihost_from_url(url) if opts or url.query: if not opts: @@ -141,21 +139,12 @@ class _PGDialect_common_psycopg(PGDialect): if "port" in opts: opts["port"] = int(opts["port"]) opts.update(url.query) - if is_multihost: - hosts, ports = zip( - *[ - token.split(":") if ":" in token else (token, "") - for token in url.query["host"] - ] - ) - opts["host"] = ",".join(hosts) - if "port" in opts: - raise exc.ArgumentError( - "Can't mix 'multihost' formats together; use " - '"host=h1,h2,h3&port=p1,p2,p3" or ' - '"host=h1:p1&host=h2:p2&host=h3:p3" separately' - ) - 
opts["port"] = ",".join(ports) + + if multihosts: + opts["host"] = ",".join(multihosts) + comma_ports = ",".join(str(p) if p else "" for p in multiports) + if comma_ports: + opts["port"] = comma_ports return ([], opts) else: # no connection arguments whatsoever; psycopg2.connect() diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index 9eb17801e7..97ef48dbf0 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -47,6 +47,12 @@ in conjunction with :func:`_sa.create_engine`:: ``json_deserializer`` when creating the engine with :func:`create_engine` or :func:`create_async_engine`. +.. _asyncpg_multihost_connecting: + +Multihost Connections +-------------------------- + +Described in :ref:`psycopg2_multi_host` .. _asyncpg_prepared_statement_cache: @@ -1060,10 +1066,25 @@ class PGDialect_asyncpg(PGDialect): def create_connect_args(self, url): opts = url.translate_connect_args(username="user") + multihosts, multiports = self._split_multihost_from_url(url) opts.update(url.query) + + if multihosts: + if not all(multihosts): + raise exc.ArgumentError( + "Some hosts is not specified in a connection string" + ) + if not all(multiports): + raise exc.ArgumentError( + "All ports are required to be present" + " for asyncpg multiple host URL" + ) + opts["host"] = list(multihosts) + opts["port"] = list(multiports) + else: + util.coerce_kw_type(opts, "port", int) util.coerce_kw_type(opts, "prepared_statement_cache_size", int) - util.coerce_kw_type(opts, "port", int) return ([], opts) def do_ping(self, dbapi_connection): diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 61aa76db71..9b3f26d915 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -3086,6 +3086,54 @@ class PGDialect(default.DefaultDialect): def get_deferrable(self, connection): raise 
NotImplementedError() + def _split_multihost_from_url(self, url): + hosts = ports = None + + integrated_multihost = False + + if "host" in url.query: + if isinstance(url.query["host"], (list, tuple)): + integrated_multihost = True + hosts, ports = zip( + *[ + token.split(":") if ":" in token else (token, None) + for token in url.query["host"] + ] + ) + + elif isinstance(url.query["host"], str): + hosts = tuple(url.query["host"].split(",")) + + if "port" in url.query: + if integrated_multihost: + raise exc.ArgumentError( + "Can't mix 'multihost' formats together; use " + '"host=h1,h2,h3&port=p1,p2,p3" or ' + '"host=h1:p1&host=h2:p2&host=h3:p3" separately' + ) + if isinstance(url.query["port"], (list, tuple)): + ports = url.query["port"] + elif isinstance(url.query["port"], str): + ports = tuple(url.query["port"].split(",")) + + if ports: + try: + ports = tuple(int(x) if x else None for x in ports) + except ValueError: + raise exc.ArgumentError( + f"Some of specified ports is not a " + f"valid integer: `{ports}`" + ) from None + + if ports and (not hosts or len(hosts) != len(ports)): + raise exc.ArgumentError("number of hosts and ports don't match") + + if hosts is not None: + if ports is None: + ports = tuple(None for _ in hosts) + + return hosts, ports + def do_begin_twophase(self, connection, xid): self.do_begin(connection.connection) diff --git a/test/dialect/postgresql/test_dialect.py b/test/dialect/postgresql/test_dialect.py index 08a1cd8f6b..6794a273e5 100644 --- a/test/dialect/postgresql/test_dialect.py +++ b/test/dialect/postgresql/test_dialect.py @@ -27,6 +27,7 @@ from sqlalchemy import Table from sqlalchemy import testing from sqlalchemy import text from sqlalchemy import TypeDecorator +from sqlalchemy.dialects.postgresql import asyncpg as asyncpg_dialect from sqlalchemy.dialects.postgresql import base as postgresql from sqlalchemy.dialects.postgresql import HSTORE from sqlalchemy.dialects.postgresql import JSONB @@ -240,53 +241,156 @@ $$ LANGUAGE 
plpgsql;""" ), ( "postgresql+psycopg2://USER:PASS@/DB" - "?host=hostA&host=hostB:portB&host=hostC:portC", + "?host=hostA&host=hostB:222&host=hostC:333", { "dbname": "DB", "user": "USER", "password": "PASS", "host": "hostA,hostB,hostC", - "port": ",portB,portC", + "port": ",222,333", }, ), ( "postgresql+psycopg2://USER:PASS@/DB?" - "host=hostA:portA&host=hostB:portB&host=hostC:portC", + "host=hostA:111&host=hostB:222&host=hostC:333", { "dbname": "DB", "user": "USER", "password": "PASS", "host": "hostA,hostB,hostC", - "port": "portA,portB,portC", + "port": "111,222,333", }, ), ( "postgresql+psycopg2:///" - "?host=hostA:portA&host=hostB:portB&host=hostC:portC", - {"host": "hostA,hostB,hostC", "port": "portA,portB,portC"}, + "?host=hostA:111&host=hostB:222&host=hostC:333", + {"host": "hostA,hostB,hostC", "port": "111,222,333"}, ), ( "postgresql+psycopg2:///" - "?host=hostA:portA&host=hostB:portB&host=hostC:portC", - {"host": "hostA,hostB,hostC", "port": "portA,portB,portC"}, + "?host=hostA:111&host=hostB:222&host=hostC:333", + {"host": "hostA,hostB,hostC", "port": "111,222,333"}, ), ( "postgresql+psycopg2:///" - "?host=hostA,hostB,hostC&port=portA,portB,portC", - {"host": "hostA,hostB,hostC", "port": "portA,portB,portC"}, + "?host=hostA,hostB,hostC&port=111,222,333", + {"host": "hostA,hostB,hostC", "port": "111,222,333"}, ), - argnames="url_string,expected", + argnames="url_string,expected_psycopg", ) @testing.combinations( psycopg2_dialect.dialect(), psycopg_dialect.dialect(), argnames="dialect", ) - def test_psycopg_multi_hosts(self, dialect, url_string, expected): + def test_multi_hosts(self, dialect, url_string, expected_psycopg): + url_string = url_string.replace("psycopg2", dialect.name) + + u = url.make_url(url_string) + + if dialect.driver in ("psycopg", "psycopg2"): + cargs, cparams = dialect.create_connect_args(u) + eq_(cparams, expected_psycopg) + eq_(cargs, []) + else: + assert False + + @testing.combinations( + ( + "postgresql+asyncpg://USER:PASS@/DB?" 
+ "host=hostA:111&host=hostB:222&host=hostC:333", + { + "database": "DB", + "user": "USER", + "password": "PASS", + "host": ["hostA", "hostB", "hostC"], + "port": [111, 222, 333], + }, + ), + ( + "postgresql+asyncpg:///" + "?host=hostA:111&host=hostB:222&host=hostC:333", + { + "host": ["hostA", "hostB", "hostC"], + "port": [111, 222, 333], + }, + ), + ( + "postgresql+asyncpg:///" + "?host=hostA:111&host=hostB:222&host=hostC:333", + { + "host": ["hostA", "hostB", "hostC"], + "port": [111, 222, 333], + }, + ), + ( + "postgresql+asyncpg:///" + "?host=hostA,hostB,hostC&port=111,222,333", + { + "host": ["hostA", "hostB", "hostC"], + "port": [111, 222, 333], + }, + ), + argnames="url_string,expected", + ) + def test_asyncpg_multi_hosts(self, url_string, expected): + dialect = asyncpg_dialect.dialect() u = url.make_url(url_string) cargs, cparams = dialect.create_connect_args(u) - eq_(cargs, []) eq_(cparams, expected) + eq_(cargs, []) + + @testing.combinations( + ( + "postgresql+asyncpg://USER:PASS@/DB?host=hostA", + "All ports are required to be present" + " for asyncpg multiple host URL", + ), + ( + "postgresql+asyncpg://USER:PASS@/DB" + "?host=hostA&host=hostB&host=hostC", + "All ports are required to be present" + " for asyncpg multiple host URL", + ), + ( + "postgresql+asyncpg://USER:PASS@/DB" + "?host=hostA&host=hostB:222&host=hostC:333", + "All ports are required to be present" + " for asyncpg multiple host URL", + ), + ( + "postgresql+asyncpg://USER:PASS@/DB" + "?host=hostA,hostB,hostC&port=111,,333", + "All ports are required to be present" + " for asyncpg multiple host URL", + ), + ( + "postgresql+asyncpg://USER:PASS@/DB" + "?host=hostA,hostB,&port=111,222,333", + "Some hosts is not specified in a connection string", + ), + ( + "postgresql+asyncpg://USER:PASS@/DB" + "?host=hostA:111&host=:222&host=hostC:333", + "Some hosts is not specified in a connection string", + ), + ( + "postgresql+asyncpg://USER:PASS@/DB" + "?host=hostA:111&host=hostB:vvv&host=hostC:333", + 
"Some of specified ports is not a valid integer", + ), + ( + "postgresql+asyncpg://USER:PASS@/DB" + "?host=hostA,hostB,hostC&port=111,222", + "number of hosts and ports don't match", + ), + argnames="url_string,expected_message", + ) + def test_asyncpg_multi_hosts_errors(self, url_string, expected_message): + dialect = asyncpg_dialect.dialect() + with expect_raises_message(exc.ArgumentError, expected_message): + u = url.make_url(url_string) + dialect.create_connect_args(u) @testing.combinations( "postgresql+psycopg2:///?host=H&host=H&port=5432,5432", @@ -373,6 +477,29 @@ class BackendDialectTest(fixtures.TestBase): with e.connect() as conn: eq_(conn.exec_driver_sql("select 1").scalar(), 1) + @testing.only_on(["+asyncpg"]) + @testing.combinations( + "host=H:P&host=H:P&host=H:P", + "host=H,H,H&port=P,P,P", + ) + def test_connect_asyncpg_multiple_hosts(self, pattern): + tdb_url = testing.db.url + + host = tdb_url.host + if host == "127.0.0.1": + host = "localhost" + port = str(tdb_url.port) if tdb_url.port else "5432" + + query_str = pattern.replace("H", host).replace("P", port) + url_string = ( + f"{tdb_url.drivername}://{tdb_url.username}:" + f"{tdb_url.password}@/{tdb_url.database}?{query_str}" + ) + + e = create_engine(url_string) + with e.connect() as conn: + eq_(conn.exec_driver_sql("select 1").scalar(), 1) + class PGCodeTest(fixtures.TestBase): __only_on__ = "postgresql" -- 2.47.3