From: Ilia Dmitriev Date: Wed, 28 Jun 2023 19:20:28 +0000 (-0400) Subject: Feature asyncpg dialect doesn't support multihost connection string X-Git-Tag: rel_2_0_18~20 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=6d5cdd1f59415ea8f08efe0f40126b8c384d1125;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Feature asyncpg dialect doesn't support multihost connection string Added multi-host support for the asyncpg dialect. General improvements and error checking added to the PostgreSQL URL routines for the "multihost" use case added as well. Pull request courtesy Ilia Dmitriev. Fixes: #10004 Closes: #10005 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/10005 Pull-request-sha: 94bba62774377fd1654296c7ca4fe5114f75fcf5 Change-Id: I68f5bdfe98531dffe06fa998f8b7471af1426a33 --- diff --git a/doc/build/changelog/unreleased_20/10004.rst b/doc/build/changelog/unreleased_20/10004.rst new file mode 100644 index 0000000000..cb7d9951f3 --- /dev/null +++ b/doc/build/changelog/unreleased_20/10004.rst @@ -0,0 +1,11 @@ +.. change:: + :tags: usecase, postgresql + :tickets: 10004 + + Added multi-host support for the asyncpg dialect. General improvements and + error checking added to the PostgreSQL URL routines for the "multihost" use + case added as well. Pull request courtesy Ilia Dmitriev. + + .. 
seealso:: + + :ref:`asyncpg_multihost` diff --git a/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py b/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py index b985180994..dfb25a5689 100644 --- a/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py +++ b/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py @@ -131,9 +131,7 @@ class _PGDialect_common_psycopg(PGDialect): def create_connect_args(self, url): opts = url.translate_connect_args(username="user", database="dbname") - is_multihost = False - if "host" in url.query: - is_multihost = isinstance(url.query["host"], (list, tuple)) + multihosts, multiports = self._split_multihost_from_url(url) if opts or url.query: if not opts: @@ -141,21 +139,12 @@ class _PGDialect_common_psycopg(PGDialect): if "port" in opts: opts["port"] = int(opts["port"]) opts.update(url.query) - if is_multihost: - hosts, ports = zip( - *[ - token.split(":") if ":" in token else (token, "") - for token in url.query["host"] - ] - ) - opts["host"] = ",".join(hosts) - if "port" in opts: - raise exc.ArgumentError( - "Can't mix 'multihost' formats together; use " - '"host=h1,h2,h3&port=p1,p2,p3" or ' - '"host=h1:p1&host=h2:p2&host=h3:p3" separately' - ) - opts["port"] = ",".join(ports) + + if multihosts: + opts["host"] = ",".join(multihosts) + comma_ports = ",".join(str(p) if p else "" for p in multiports) + if comma_ports: + opts["port"] = comma_ports return ([], opts) else: # no connection arguments whatsoever; psycopg2.connect() diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index 9eb17801e7..dacb9ebd5a 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -47,6 +47,29 @@ in conjunction with :func:`_sa.create_engine`:: ``json_deserializer`` when creating the engine with :func:`create_engine` or :func:`create_async_engine`. +.. 
_asyncpg_multihost: + +Multihost Connections +-------------------------- + +The asyncpg dialect features support for multiple fallback hosts in the +same way as that of the psycopg2 and psycopg dialects. The +syntax is the same, +using ``host=<host>:<port>`` combinations as additional query string arguments; +however, there is no default port, so all hosts must have a complete port number +present, otherwise an exception is raised:: + + engine = create_async_engine( + "postgresql+asyncpg://user:password@/dbname?host=HostA:5432&host=HostB:5432&host=HostC:5432" + ) + +For complete background on this syntax, see :ref:`psycopg2_multi_host`. + +.. versionadded:: 2.0.18 + +.. seealso:: + + :ref:`psycopg2_multi_host` .. _asyncpg_prepared_statement_cache: @@ -1060,10 +1083,27 @@ class PGDialect_asyncpg(PGDialect): def create_connect_args(self, url): opts = url.translate_connect_args(username="user") + multihosts, multiports = self._split_multihost_from_url(url) opts.update(url.query) + + if multihosts: + assert multiports + if not all(multihosts): + raise exc.ArgumentError( + "All hosts are required to be present" + " for asyncpg multiple host URL" + ) + elif not all(multiports): + raise exc.ArgumentError( + "All ports are required to be present" + " for asyncpg multiple host URL" + ) + opts["host"] = list(multihosts) + opts["port"] = list(multiports) + else: + util.coerce_kw_type(opts, "port", int) util.coerce_kw_type(opts, "prepared_statement_cache_size", int) - util.coerce_kw_type(opts, "port", int) return ([], opts) def do_ping(self, dbapi_connection): diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 61aa76db71..835ff5b2a4 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1408,6 +1408,7 @@ from typing import Any from typing import List from typing import Optional from typing import Tuple +from typing import Union from . import array as _array from . 
import hstore as _hstore @@ -1459,6 +1460,7 @@ from ...engine import interfaces from ...engine import ObjectKind from ...engine import ObjectScope from ...engine import reflection +from ...engine import URL from ...engine.reflection import ReflectionDefaults from ...sql import bindparam from ...sql import coercions @@ -3086,6 +3088,73 @@ class PGDialect(default.DefaultDialect): def get_deferrable(self, connection): raise NotImplementedError() + def _split_multihost_from_url( + self, url: URL + ) -> Union[ + Tuple[None, None], + Tuple[Tuple[Optional[str], ...], Tuple[Optional[int], ...]], + ]: + hosts = ports = None + + integrated_multihost = False + + if "host" in url.query: + if isinstance(url.query["host"], (list, tuple)): + integrated_multihost = True + hosts, ports = zip( + *[ + token.split(":") if ":" in token else (token, None) + for token in url.query["host"] + ] + ) + + elif isinstance(url.query["host"], str): + hosts = tuple(url.query["host"].split(",")) + + if ( + "port" not in url.query + and len(hosts) == 1 + and ":" in hosts[0] + ): + integrated_multihost = True + h, p = hosts[0].split(":") + hosts = (h,) + ports = (p,) if p else (None,) + + if "port" in url.query: + if integrated_multihost: + raise exc.ArgumentError( + "Can't mix 'multihost' formats together; use " + '"host=h1,h2,h3&port=p1,p2,p3" or ' + '"host=h1:p1&host=h2:p2&host=h3:p3" separately' + ) + if isinstance(url.query["port"], (list, tuple)): + ports = url.query["port"] + elif isinstance(url.query["port"], str): + ports = tuple(url.query["port"].split(",")) + + if ports: + try: + ports = tuple(int(x) if x else None for x in ports) + except ValueError: + raise exc.ArgumentError( + f"Received non-integer port arguments: {ports}" + ) from None + + if (hosts or ports) and url.host: + raise exc.ArgumentError( + "Can't combine fixed host and multihost URL formats" + ) + + if ports and (not hosts or len(hosts) != len(ports)): + raise exc.ArgumentError("number of hosts and ports don't match") + + 
if hosts is not None: + if ports is None: + ports = tuple(None for _ in hosts) + + return hosts, ports + def do_begin_twophase(self, connection, xid): self.do_begin(connection.connection) diff --git a/test/dialect/postgresql/test_dialect.py b/test/dialect/postgresql/test_dialect.py index 08a1cd8f6b..771ffea625 100644 --- a/test/dialect/postgresql/test_dialect.py +++ b/test/dialect/postgresql/test_dialect.py @@ -2,6 +2,7 @@ import dataclasses import datetime import logging import logging.handlers +import re from sqlalchemy import BigInteger from sqlalchemy import bindparam @@ -27,6 +28,7 @@ from sqlalchemy import Table from sqlalchemy import testing from sqlalchemy import text from sqlalchemy import TypeDecorator +from sqlalchemy.dialects.postgresql import asyncpg as asyncpg_dialect from sqlalchemy.dialects.postgresql import base as postgresql from sqlalchemy.dialects.postgresql import HSTORE from sqlalchemy.dialects.postgresql import JSONB @@ -217,76 +219,287 @@ $$ LANGUAGE plpgsql;""" eq_(cargs, []) eq_(cparams, {"host": "somehost", "any_random_thing": "yes"}) + def test_psycopg2_disconnect(self): + class Error(Exception): + pass + + dbapi = mock.Mock() + dbapi.Error = Error + + dialect = psycopg2_dialect.dialect(dbapi=dbapi) + + for error in [ + # these error messages from libpq: interfaces/libpq/fe-misc.c + # and interfaces/libpq/fe-secure.c. + "terminating connection", + "closed the connection", + "connection not open", + "could not receive data from server", + "could not send data to server", + # psycopg2 client errors, psycopg2/connection.h, + # psycopg2/cursor.h + "connection already closed", + "cursor already closed", + # not sure where this path is originally from, it may + # be obsolete. It really says "losed", not "closed". 
+ "losed the connection unexpectedly", + # these can occur in newer SSL + "connection has been closed unexpectedly", + "SSL error: decryption failed or bad record mac", + "SSL SYSCALL error: Bad file descriptor", + "SSL SYSCALL error: EOF detected", + "SSL SYSCALL error: Operation timed out", + "SSL SYSCALL error: Bad address", + ]: + eq_(dialect.is_disconnect(Error(error), None, None), True) + + eq_(dialect.is_disconnect("not an error", None, None), False) + + +class MultiHostConnectTest(fixtures.TestBase): + def working_combinations(): + psycopg_combinations = [ + ( + "postgresql+psycopg2://USER:PASS@/DB?host=hostA", + { + "dbname": "DB", + "user": "USER", + "password": "PASS", + "host": "hostA", + }, + ), + ( + "postgresql+psycopg2://USER:PASS@/DB?host=hostA:", + { + "dbname": "DB", + "user": "USER", + "password": "PASS", + "host": "hostA", + }, + ), + ( + "postgresql+psycopg2://USER:PASS@/DB?host=hostA:1234", + { + "dbname": "DB", + "user": "USER", + "password": "PASS", + "host": "hostA", + "port": "1234", + }, + ), + ( + "postgresql+psycopg2://USER:PASS@/DB" + "?host=hostA&host=hostB&host=hostC", + { + "dbname": "DB", + "user": "USER", + "password": "PASS", + "host": "hostA,hostB,hostC", + "port": ",,", + }, + ), + ( + "postgresql+psycopg2://USER:PASS@/DB" + "?host=hostA&host=hostB:222&host=hostC:333", + { + "dbname": "DB", + "user": "USER", + "password": "PASS", + "host": "hostA,hostB,hostC", + "port": ",222,333", + }, + ), + ( + "postgresql+psycopg2://USER:PASS@/DB?" 
+ "host=hostA:111&host=hostB:222&host=hostC:333", + { + "dbname": "DB", + "user": "USER", + "password": "PASS", + "host": "hostA,hostB,hostC", + "port": "111,222,333", + }, + ), + ( + "postgresql+psycopg2:///" + "?host=hostA:111&host=hostB:222&host=hostC:333", + {"host": "hostA,hostB,hostC", "port": "111,222,333"}, + ), + ( + "postgresql+psycopg2:///" + "?host=hostA:111&host=hostB:222&host=hostC:333", + {"host": "hostA,hostB,hostC", "port": "111,222,333"}, + ), + ( + "postgresql+psycopg2:///" + "?host=hostA,hostB,hostC&port=111,222,333", + {"host": "hostA,hostB,hostC", "port": "111,222,333"}, + ), + ( + "postgresql+asyncpg://USER:PASS@/DB" + "?host=hostA,hostB,&port=111,222,333", + { + "host": "hostA,hostB,", + "port": "111,222,333", + "dbname": "DB", + "user": "USER", + "password": "PASS", + }, + ), + ] + for url_string, expected_psycopg in psycopg_combinations: + expected_asyncpg = dict(expected_psycopg) + if "dbname" in expected_asyncpg: + expected_asyncpg["database"] = expected_asyncpg.pop("dbname") + if "host" in expected_asyncpg: + expected_asyncpg["host"] = expected_asyncpg["host"].split(",") + if "port" in expected_asyncpg: + expected_asyncpg["port"] = [ + int(p) if re.match(r"^\d+$", p) else None + for p in expected_psycopg["port"].split(",") + ] + yield url_string, expected_psycopg, expected_asyncpg + + @testing.combinations_list( + working_combinations(), + argnames="url_string,expected_psycopg,expected_asyncpg", + ) + @testing.combinations( + psycopg2_dialect.dialect(), + psycopg_dialect.dialect(), + asyncpg_dialect.dialect(), + argnames="dialect", + ) + def test_multi_hosts( + self, dialect, url_string, expected_psycopg, expected_asyncpg + ): + url_string = url_string.replace("psycopg2", dialect.driver) + + u = url.make_url(url_string) + + if dialect.driver == "asyncpg": + if ( + "port" in expected_asyncpg + and not all(expected_asyncpg["port"]) + or ( + "host" in expected_asyncpg + and isinstance(expected_asyncpg["host"], list) + and "port" not in 
expected_asyncpg + ) + ): + with expect_raises_message( + exc.ArgumentError, + "All ports are required to be present" + " for asyncpg multiple host URL", + ): + dialect.create_connect_args(u) + return + elif "host" in expected_asyncpg and not all( + expected_asyncpg["host"] + ): + with expect_raises_message( + exc.ArgumentError, + "All hosts are required to be present" + " for asyncpg multiple host URL", + ): + dialect.create_connect_args(u) + return + expected = expected_asyncpg + else: + expected = expected_psycopg + + cargs, cparams = dialect.create_connect_args(u) + eq_(cparams, expected) + eq_(cargs, []) + @testing.combinations( - ( - "postgresql+psycopg2://USER:PASS@/DB?host=hostA", - { - "dbname": "DB", - "user": "USER", - "password": "PASS", - "host": "hostA", - }, - ), ( "postgresql+psycopg2://USER:PASS@/DB" - "?host=hostA&host=hostB&host=hostC", - { - "dbname": "DB", - "user": "USER", - "password": "PASS", - "host": "hostA,hostB,hostC", - "port": ",,", - }, + "?host=hostA:111&host=hostB:vvv&host=hostC:333", ), ( "postgresql+psycopg2://USER:PASS@/DB" - "?host=hostA&host=hostB:portB&host=hostC:portC", - { - "dbname": "DB", - "user": "USER", - "password": "PASS", - "host": "hostA,hostB,hostC", - "port": ",portB,portC", - }, + "?host=hostA,hostB:,hostC&port=111,vvv,333", ), ( - "postgresql+psycopg2://USER:PASS@/DB?" 
- "host=hostA:portA&host=hostB:portB&host=hostC:portC", - { - "dbname": "DB", - "user": "USER", - "password": "PASS", - "host": "hostA,hostB,hostC", - "port": "portA,portB,portC", - }, + "postgresql+psycopg2://USER:PASS@/DB" + "?host=hostA:xyz&host=hostB:123", ), + ("postgresql+psycopg2://USER:PASS@/DB?host=hostA:xyz",), + ("postgresql+psycopg2://USER:PASS@/DB?host=hostA&port=xyz",), + argnames="url_string", + ) + @testing.combinations( + psycopg2_dialect.dialect(), + psycopg_dialect.dialect(), + asyncpg_dialect.dialect(), + argnames="dialect", + ) + def test_non_int_port_disallowed(self, dialect, url_string): + url_string = url_string.replace("psycopg2", dialect.driver) + + u = url.make_url(url_string) + + with expect_raises_message( + exc.ArgumentError, + r"Received non-integer port arguments: \((?:'.*?',?)+\)", + ): + dialect.create_connect_args(u) + + @testing.combinations( + ("postgresql+psycopg2://USER:PASS@hostfixed/DB?port=111",), + ("postgresql+psycopg2://USER:PASS@hostfixed/DB?host=hostA:111",), ( - "postgresql+psycopg2:///" - "?host=hostA:portA&host=hostB:portB&host=hostC:portC", - {"host": "hostA,hostB,hostC", "port": "portA,portB,portC"}, + "postgresql+psycopg2://USER:PASS@hostfixed/DB" + "?host=hostA&port=111", ), + ("postgresql+psycopg2://USER:PASS@hostfixed/DB" "?host=hostA",), + argnames="url_string", + ) + @testing.combinations( + psycopg2_dialect.dialect(), + psycopg_dialect.dialect(), + asyncpg_dialect.dialect(), + argnames="dialect", + ) + def test_dont_use_fixed_host(self, dialect, url_string): + url_string = url_string.replace("psycopg2", dialect.driver) + + u = url.make_url(url_string) + with expect_raises_message( + exc.ArgumentError, + "Can't combine fixed host and multihost URL formats", + ): + dialect.create_connect_args(u) + + @testing.combinations( ( - "postgresql+psycopg2:///" - "?host=hostA:portA&host=hostB:portB&host=hostC:portC", - {"host": "hostA,hostB,hostC", "port": "portA,portB,portC"}, + "postgresql+psycopg2://USER:PASS@/DB" + 
"?host=hostA,hostC&port=111,222,333", ), + ("postgresql+psycopg2://USER:PASS@/DB" "?host=hostA&port=111,222",), + ("postgresql+psycopg2://USER:PASS@/DB?port=111",), ( - "postgresql+psycopg2:///" - "?host=hostA,hostB,hostC&port=portA,portB,portC", - {"host": "hostA,hostB,hostC", "port": "portA,portB,portC"}, + "postgresql+asyncpg://USER:PASS@/DB" + "?host=hostA,hostB,hostC&port=111,333", ), - argnames="url_string,expected", + argnames="url_string", ) @testing.combinations( psycopg2_dialect.dialect(), psycopg_dialect.dialect(), + asyncpg_dialect.dialect(), argnames="dialect", ) - def test_psycopg_multi_hosts(self, dialect, url_string, expected): + def test_num_host_port_doesnt_match(self, dialect, url_string): + url_string = url_string.replace("psycopg2", dialect.driver) + u = url.make_url(url_string) - cargs, cparams = dialect.create_connect_args(u) - eq_(cargs, []) - eq_(cparams, expected) + + with expect_raises_message( + exc.ArgumentError, "number of hosts and ports don't match" + ): + dialect.create_connect_args(u) @testing.combinations( "postgresql+psycopg2:///?host=H&host=H&port=5432,5432", @@ -296,65 +509,42 @@ $$ LANGUAGE plpgsql;""" @testing.combinations( psycopg2_dialect.dialect(), psycopg_dialect.dialect(), + asyncpg_dialect.dialect(), argnames="dialect", ) - def test_psycopg_no_mix_hosts(self, dialect, url_string): + def test_dont_mix_multihost_formats(self, dialect, url_string): + url_string = url_string.replace("psycopg2", dialect.name) + + u = url.make_url(url_string) + with expect_raises_message( exc.ArgumentError, "Can't mix 'multihost' formats together" ): - u = url.make_url(url_string) dialect.create_connect_args(u) - def test_psycopg2_disconnect(self): - class Error(Exception): - pass - - dbapi = mock.Mock() - dbapi.Error = Error - - dialect = psycopg2_dialect.dialect(dbapi=dbapi) - - for error in [ - # these error messages from libpq: interfaces/libpq/fe-misc.c - # and interfaces/libpq/fe-secure.c. 
- "terminating connection", - "closed the connection", - "connection not open", - "could not receive data from server", - "could not send data to server", - # psycopg2 client errors, psycopg2/connection.h, - # psycopg2/cursor.h - "connection already closed", - "cursor already closed", - # not sure where this path is originally from, it may - # be obsolete. It really says "losed", not "closed". - "losed the connection unexpectedly", - # these can occur in newer SSL - "connection has been closed unexpectedly", - "SSL error: decryption failed or bad record mac", - "SSL SYSCALL error: Bad file descriptor", - "SSL SYSCALL error: EOF detected", - "SSL SYSCALL error: Operation timed out", - "SSL SYSCALL error: Bad address", - ]: - eq_(dialect.is_disconnect(Error(error), None, None), True) - - eq_(dialect.is_disconnect("not an error", None, None), False) - class BackendDialectTest(fixtures.TestBase): __backend__ = True - @testing.only_on(["+psycopg", "+psycopg2"]) + @testing.only_on(["+psycopg", "+psycopg2", "+asyncpg"]) @testing.combinations( - "host=H:P&host=H:P&host=H:P", - "host=H:P&host=H&host=H", - "host=H:P&host=H&host=H:P", - "host=H&host=H:P&host=H", - "host=H,H,H&port=P,P,P", + ("postgresql+D://U:PS@/DB?host=H:P&host=H:P&host=H:P", True), + ("postgresql+D://U:PS@/DB?host=H:P&host=H&host=H", False), + ("postgresql+D://U:PS@/DB?host=H:P&host=H&host=H:P", False), + ("postgresql+D://U:PS@/DB?host=H&host=H:P&host=H", False), + ("postgresql+D://U:PS@/DB?host=H,H,H&port=P,P,P", True), + ("postgresql+D://U:PS@H:P/DB", True), + argnames="pattern,has_all_ports", ) - def test_connect_psycopg_multiple_hosts(self, pattern): - """test the fix for #4392""" + def test_multiple_host_real_connect( + self, testing_engine, pattern, has_all_ports + ): + """test the fix for #4392. 
+ + Additionally add multiple host tests for #10004's additional + use cases + + """ tdb_url = testing.db.url @@ -363,13 +553,25 @@ class BackendDialectTest(fixtures.TestBase): host = "localhost" port = str(tdb_url.port) if tdb_url.port else "5432" - query_str = pattern.replace("H", host).replace("P", port) url_string = ( - f"{tdb_url.drivername}://{tdb_url.username}:" - f"{tdb_url.password}@/{tdb_url.database}?{query_str}" + pattern.replace("DB", tdb_url.database) + .replace("postgresql+D", tdb_url.drivername) + .replace("U", tdb_url.username) + .replace("PS", tdb_url.password) + .replace("H", host) + .replace("P", port) ) - e = create_engine(url_string) + if testing.against("+asyncpg") and not has_all_ports: + with expect_raises_message( + exc.ArgumentError, + "All ports are required to be present " + "for asyncpg multiple host URL", + ): + testing_engine(url_string) + return + + e = testing_engine(url_string) with e.connect() as conn: eq_(conn.exec_driver_sql("select 1").scalar(), 1)