]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Feature asyncpg dialect doesn't support mutlihost connection string 10005/head
authorIlia Dmitriev <ilia.dmitriev@gmail.com>
Thu, 22 Jun 2023 15:02:26 +0000 (18:02 +0300)
committerIlia Dmitriev <ilia.dmitriev@gmail.com>
Wed, 28 Jun 2023 17:25:59 +0000 (20:25 +0300)
+ moved base postgres dialect create arguments logic `PGDialect._split_multihost_from_url`
+ added asyncpg specific logic to `PGDialect_asyncpg.create_connect_args`
+ test case when one of hosts or ports is None
+ error message for case when one of hosts or ports is None
+ intelligeble error message when one of ports is not an integer

lib/sqlalchemy/dialects/postgresql/_psycopg_common.py
lib/sqlalchemy/dialects/postgresql/asyncpg.py
lib/sqlalchemy/dialects/postgresql/base.py
test/dialect/postgresql/test_dialect.py

index b985180994a52bfbd38dead4af92a5c267784648..dfb25a5689006642d257ad6dac789c7e46989e92 100644 (file)
@@ -131,9 +131,7 @@ class _PGDialect_common_psycopg(PGDialect):
     def create_connect_args(self, url):
         opts = url.translate_connect_args(username="user", database="dbname")
 
-        is_multihost = False
-        if "host" in url.query:
-            is_multihost = isinstance(url.query["host"], (list, tuple))
+        multihosts, multiports = self._split_multihost_from_url(url)
 
         if opts or url.query:
             if not opts:
@@ -141,21 +139,12 @@ class _PGDialect_common_psycopg(PGDialect):
             if "port" in opts:
                 opts["port"] = int(opts["port"])
             opts.update(url.query)
-            if is_multihost:
-                hosts, ports = zip(
-                    *[
-                        token.split(":") if ":" in token else (token, "")
-                        for token in url.query["host"]
-                    ]
-                )
-                opts["host"] = ",".join(hosts)
-                if "port" in opts:
-                    raise exc.ArgumentError(
-                        "Can't mix 'multihost' formats together; use "
-                        '"host=h1,h2,h3&port=p1,p2,p3" or '
-                        '"host=h1:p1&host=h2:p2&host=h3:p3" separately'
-                    )
-                opts["port"] = ",".join(ports)
+
+            if multihosts:
+                opts["host"] = ",".join(multihosts)
+                comma_ports = ",".join(str(p) if p else "" for p in multiports)
+                if comma_ports:
+                    opts["port"] = comma_ports
             return ([], opts)
         else:
             # no connection arguments whatsoever; psycopg2.connect()
index 9eb17801e7c8f989e5b44383fb1c97382b75f5bf..97ef48dbf019545424c2ce03759564051e34b9db 100644 (file)
@@ -47,6 +47,12 @@ in conjunction with :func:`_sa.create_engine`::
     ``json_deserializer`` when creating the engine with
     :func:`create_engine` or :func:`create_async_engine`.
 
+.. _asyncpg_multihost_connecting:
+
+Multihost Connections
+--------------------------
+
+Described in :ref:`psycopg2_multi_host`
 
 .. _asyncpg_prepared_statement_cache:
 
@@ -1060,10 +1066,25 @@ class PGDialect_asyncpg(PGDialect):
 
     def create_connect_args(self, url):
         opts = url.translate_connect_args(username="user")
+        multihosts, multiports = self._split_multihost_from_url(url)
 
         opts.update(url.query)
+
+        if multihosts:
+            if not all(multihosts):
+                raise exc.ArgumentError(
+                    "Some hosts is not specified in a connection string"
+                )
+            if not all(multiports):
+                raise exc.ArgumentError(
+                    "All ports are required to be present"
+                    " for asyncpg multiple host URL"
+                )
+            opts["host"] = list(multihosts)
+            opts["port"] = list(multiports)
+        else:
+            util.coerce_kw_type(opts, "port", int)
         util.coerce_kw_type(opts, "prepared_statement_cache_size", int)
-        util.coerce_kw_type(opts, "port", int)
         return ([], opts)
 
     def do_ping(self, dbapi_connection):
index 61aa76db71c13f78f86651b493d4a298acef482f..9b3f26d915c73c321510e0c51d333e813b353d83 100644 (file)
@@ -3086,6 +3086,54 @@ class PGDialect(default.DefaultDialect):
     def get_deferrable(self, connection):
         raise NotImplementedError()
 
+    def _split_multihost_from_url(self, url):
+        hosts = ports = None
+
+        integrated_multihost = False
+
+        if "host" in url.query:
+            if isinstance(url.query["host"], (list, tuple)):
+                integrated_multihost = True
+                hosts, ports = zip(
+                    *[
+                        token.split(":") if ":" in token else (token, None)
+                        for token in url.query["host"]
+                    ]
+                )
+
+            elif isinstance(url.query["host"], str):
+                hosts = tuple(url.query["host"].split(","))
+
+        if "port" in url.query:
+            if integrated_multihost:
+                raise exc.ArgumentError(
+                    "Can't mix 'multihost' formats together; use "
+                    '"host=h1,h2,h3&port=p1,p2,p3" or '
+                    '"host=h1:p1&host=h2:p2&host=h3:p3" separately'
+                )
+            if isinstance(url.query["port"], (list, tuple)):
+                ports = url.query["port"]
+            elif isinstance(url.query["port"], str):
+                ports = tuple(url.query["port"].split(","))
+
+        if ports:
+            try:
+                ports = tuple(int(x) if x else None for x in ports)
+            except ValueError:
+                raise exc.ArgumentError(
+                    f"Some of specified ports is not a "
+                    f"valid integer: `{ports}`"
+                ) from None
+
+        if ports and (not hosts or len(hosts) != len(ports)):
+            raise exc.ArgumentError("number of hosts and ports don't match")
+
+        if hosts is not None:
+            if ports is None:
+                ports = tuple(None for _ in hosts)
+
+        return hosts, ports
+
     def do_begin_twophase(self, connection, xid):
         self.do_begin(connection.connection)
 
index 08a1cd8f6bdd09dc58475408695a0025e5ace153..6794a273e590cb4bc138d262b739af245b835d51 100644 (file)
@@ -27,6 +27,7 @@ from sqlalchemy import Table
 from sqlalchemy import testing
 from sqlalchemy import text
 from sqlalchemy import TypeDecorator
+from sqlalchemy.dialects.postgresql import asyncpg as asyncpg_dialect
 from sqlalchemy.dialects.postgresql import base as postgresql
 from sqlalchemy.dialects.postgresql import HSTORE
 from sqlalchemy.dialects.postgresql import JSONB
@@ -240,53 +241,156 @@ $$ LANGUAGE plpgsql;"""
         ),
         (
             "postgresql+psycopg2://USER:PASS@/DB"
-            "?host=hostA&host=hostB:portB&host=hostC:portC",
+            "?host=hostA&host=hostB:222&host=hostC:333",
             {
                 "dbname": "DB",
                 "user": "USER",
                 "password": "PASS",
                 "host": "hostA,hostB,hostC",
-                "port": ",portB,portC",
+                "port": ",222,333",
             },
         ),
         (
             "postgresql+psycopg2://USER:PASS@/DB?"
-            "host=hostA:portA&host=hostB:portB&host=hostC:portC",
+            "host=hostA:111&host=hostB:222&host=hostC:333",
             {
                 "dbname": "DB",
                 "user": "USER",
                 "password": "PASS",
                 "host": "hostA,hostB,hostC",
-                "port": "portA,portB,portC",
+                "port": "111,222,333",
             },
         ),
         (
             "postgresql+psycopg2:///"
-            "?host=hostA:portA&host=hostB:portB&host=hostC:portC",
-            {"host": "hostA,hostB,hostC", "port": "portA,portB,portC"},
+            "?host=hostA:111&host=hostB:222&host=hostC:333",
+            {"host": "hostA,hostB,hostC", "port": "111,222,333"},
         ),
         (
             "postgresql+psycopg2:///"
-            "?host=hostA:portA&host=hostB:portB&host=hostC:portC",
-            {"host": "hostA,hostB,hostC", "port": "portA,portB,portC"},
+            "?host=hostA:111&host=hostB:222&host=hostC:333",
+            {"host": "hostA,hostB,hostC", "port": "111,222,333"},
         ),
         (
             "postgresql+psycopg2:///"
-            "?host=hostA,hostB,hostC&port=portA,portB,portC",
-            {"host": "hostA,hostB,hostC", "port": "portA,portB,portC"},
+            "?host=hostA,hostB,hostC&port=111,222,333",
+            {"host": "hostA,hostB,hostC", "port": "111,222,333"},
         ),
-        argnames="url_string,expected",
+        argnames="url_string,expected_psycopg",
     )
     @testing.combinations(
         psycopg2_dialect.dialect(),
         psycopg_dialect.dialect(),
         argnames="dialect",
     )
-    def test_psycopg_multi_hosts(self, dialect, url_string, expected):
+    def test_multi_hosts(self, dialect, url_string, expected_psycopg):
+        url_string = url_string.replace("psycopg2", dialect.name)
+
+        u = url.make_url(url_string)
+
+        if dialect.driver in ("psycopg", "psycopg2"):
+            cargs, cparams = dialect.create_connect_args(u)
+            eq_(cparams, expected_psycopg)
+            eq_(cargs, [])
+        else:
+            assert False
+
+    @testing.combinations(
+        (
+            "postgresql+asyncpg://USER:PASS@/DB?"
+            "host=hostA:111&host=hostB:222&host=hostC:333",
+            {
+                "database": "DB",
+                "user": "USER",
+                "password": "PASS",
+                "host": ["hostA", "hostB", "hostC"],
+                "port": [111, 222, 333],
+            },
+        ),
+        (
+            "postgresql+asyncpg:///"
+            "?host=hostA:111&host=hostB:222&host=hostC:333",
+            {
+                "host": ["hostA", "hostB", "hostC"],
+                "port": [111, 222, 333],
+            },
+        ),
+        (
+            "postgresql+asyncpg:///"
+            "?host=hostA:111&host=hostB:222&host=hostC:333",
+            {
+                "host": ["hostA", "hostB", "hostC"],
+                "port": [111, 222, 333],
+            },
+        ),
+        (
+            "postgresql+asyncpg:///"
+            "?host=hostA,hostB,hostC&port=111,222,333",
+            {
+                "host": ["hostA", "hostB", "hostC"],
+                "port": [111, 222, 333],
+            },
+        ),
+        argnames="url_string,expected",
+    )
+    def test_asyncpg_multi_hosts(self, url_string, expected):
+        dialect = asyncpg_dialect.dialect()
         u = url.make_url(url_string)
         cargs, cparams = dialect.create_connect_args(u)
-        eq_(cargs, [])
         eq_(cparams, expected)
+        eq_(cargs, [])
+
+    @testing.combinations(
+        (
+            "postgresql+asyncpg://USER:PASS@/DB?host=hostA",
+            "All ports are required to be present"
+            " for asyncpg multiple host URL",
+        ),
+        (
+            "postgresql+asyncpg://USER:PASS@/DB"
+            "?host=hostA&host=hostB&host=hostC",
+            "All ports are required to be present"
+            " for asyncpg multiple host URL",
+        ),
+        (
+            "postgresql+asyncpg://USER:PASS@/DB"
+            "?host=hostA&host=hostB:222&host=hostC:333",
+            "All ports are required to be present"
+            " for asyncpg multiple host URL",
+        ),
+        (
+            "postgresql+asyncpg://USER:PASS@/DB"
+            "?host=hostA,hostB,hostC&port=111,,333",
+            "All ports are required to be present"
+            " for asyncpg multiple host URL",
+        ),
+        (
+            "postgresql+asyncpg://USER:PASS@/DB"
+            "?host=hostA,hostB,&port=111,222,333",
+            "Some hosts is not specified in a connection string",
+        ),
+        (
+            "postgresql+asyncpg://USER:PASS@/DB"
+            "?host=hostA:111&host=:222&host=hostC:333",
+            "Some hosts is not specified in a connection string",
+        ),
+        (
+            "postgresql+asyncpg://USER:PASS@/DB"
+            "?host=hostA:111&host=hostB:vvv&host=hostC:333",
+            "Some of specified ports is not a valid integer",
+        ),
+        (
+            "postgresql+asyncpg://USER:PASS@/DB"
+            "?host=hostA,hostB,hostC&port=111,222",
+            "number of hosts and ports don't match",
+        ),
+        argnames="url_string,expected_message",
+    )
+    def test_asyncpg_multi_hosts_errors(self, url_string, expected_message):
+        dialect = asyncpg_dialect.dialect()
+        with expect_raises_message(exc.ArgumentError, expected_message):
+            u = url.make_url(url_string)
+            dialect.create_connect_args(u)
 
     @testing.combinations(
         "postgresql+psycopg2:///?host=H&host=H&port=5432,5432",
@@ -373,6 +477,29 @@ class BackendDialectTest(fixtures.TestBase):
         with e.connect() as conn:
             eq_(conn.exec_driver_sql("select 1").scalar(), 1)
 
+    @testing.only_on(["+asyncpg"])
+    @testing.combinations(
+        "host=H:P&host=H:P&host=H:P",
+        "host=H,H,H&port=P,P,P",
+    )
+    def test_connect_asyncpg_multiple_hosts(self, pattern):
+        tdb_url = testing.db.url
+
+        host = tdb_url.host
+        if host == "127.0.0.1":
+            host = "localhost"
+        port = str(tdb_url.port) if tdb_url.port else "5432"
+
+        query_str = pattern.replace("H", host).replace("P", port)
+        url_string = (
+            f"{tdb_url.drivername}://{tdb_url.username}:"
+            f"{tdb_url.password}@/{tdb_url.database}?{query_str}"
+        )
+
+        e = create_engine(url_string)
+        with e.connect() as conn:
+            eq_(conn.exec_driver_sql("select 1").scalar(), 1)
+
 
 class PGCodeTest(fixtures.TestBase):
     __only_on__ = "postgresql"