def get_deferrable(self, connection):
raise NotImplementedError()
+ def _split_multihost_from_url(self, url):
+ hosts = ports = None
+
+ integrated_multihost = False
+
+ if "host" in url.query:
+ if isinstance(url.query["host"], (list, tuple)):
+ integrated_multihost = True
+ hosts, ports = zip(
+ *[
+ token.split(":") if ":" in token else (token, None)
+ for token in url.query["host"]
+ ]
+ )
+
+ elif isinstance(url.query["host"], str):
+ hosts = tuple(url.query["host"].split(","))
+
+ if "port" in url.query:
+ if integrated_multihost:
+ raise exc.ArgumentError(
+ "Can't mix 'multihost' formats together; use "
+ '"host=h1,h2,h3&port=p1,p2,p3" or '
+ '"host=h1:p1&host=h2:p2&host=h3:p3" separately'
+ )
+ if isinstance(url.query["port"], (list, tuple)):
+ ports = url.query["port"]
+ elif isinstance(url.query["port"], str):
+ ports = tuple(url.query["port"].split(","))
+
+ if ports:
+ try:
+ ports = tuple(int(x) if x else None for x in ports)
+ except ValueError:
+ raise exc.ArgumentError(
+ f"Some of specified ports is not a "
+ f"valid integer: `{ports}`"
+ ) from None
+
+ if ports and (not hosts or len(hosts) != len(ports)):
+ raise exc.ArgumentError("number of hosts and ports don't match")
+
+ if hosts is not None:
+ if ports is None:
+ ports = tuple(None for _ in hosts)
+
+ return hosts, ports
+
def do_begin_twophase(self, connection, xid):
self.do_begin(connection.connection)
from sqlalchemy import testing
from sqlalchemy import text
from sqlalchemy import TypeDecorator
+from sqlalchemy.dialects.postgresql import asyncpg as asyncpg_dialect
from sqlalchemy.dialects.postgresql import base as postgresql
from sqlalchemy.dialects.postgresql import HSTORE
from sqlalchemy.dialects.postgresql import JSONB
),
(
"postgresql+psycopg2://USER:PASS@/DB"
- "?host=hostA&host=hostB:portB&host=hostC:portC",
+ "?host=hostA&host=hostB:222&host=hostC:333",
{
"dbname": "DB",
"user": "USER",
"password": "PASS",
"host": "hostA,hostB,hostC",
- "port": ",portB,portC",
+ "port": ",222,333",
},
),
(
"postgresql+psycopg2://USER:PASS@/DB?"
- "host=hostA:portA&host=hostB:portB&host=hostC:portC",
+ "host=hostA:111&host=hostB:222&host=hostC:333",
{
"dbname": "DB",
"user": "USER",
"password": "PASS",
"host": "hostA,hostB,hostC",
- "port": "portA,portB,portC",
+ "port": "111,222,333",
},
),
(
"postgresql+psycopg2:///"
- "?host=hostA:portA&host=hostB:portB&host=hostC:portC",
- {"host": "hostA,hostB,hostC", "port": "portA,portB,portC"},
+ "?host=hostA:111&host=hostB:222&host=hostC:333",
+ {"host": "hostA,hostB,hostC", "port": "111,222,333"},
),
(
"postgresql+psycopg2:///"
- "?host=hostA:portA&host=hostB:portB&host=hostC:portC",
- {"host": "hostA,hostB,hostC", "port": "portA,portB,portC"},
+ "?host=hostA:111&host=hostB:222&host=hostC:333",
+ {"host": "hostA,hostB,hostC", "port": "111,222,333"},
),
(
"postgresql+psycopg2:///"
- "?host=hostA,hostB,hostC&port=portA,portB,portC",
- {"host": "hostA,hostB,hostC", "port": "portA,portB,portC"},
+ "?host=hostA,hostB,hostC&port=111,222,333",
+ {"host": "hostA,hostB,hostC", "port": "111,222,333"},
),
- argnames="url_string,expected",
+ argnames="url_string,expected_psycopg",
)
@testing.combinations(
psycopg2_dialect.dialect(),
psycopg_dialect.dialect(),
argnames="dialect",
)
- def test_psycopg_multi_hosts(self, dialect, url_string, expected):
+ def test_multi_hosts(self, dialect, url_string, expected_psycopg):
+ url_string = url_string.replace("psycopg2", dialect.name)
+
+ u = url.make_url(url_string)
+
+ if dialect.driver in ("psycopg", "psycopg2"):
+ cargs, cparams = dialect.create_connect_args(u)
+ eq_(cparams, expected_psycopg)
+ eq_(cargs, [])
+ else:
+ assert False
+
+ @testing.combinations(
+ (
+ "postgresql+asyncpg://USER:PASS@/DB?"
+ "host=hostA:111&host=hostB:222&host=hostC:333",
+ {
+ "database": "DB",
+ "user": "USER",
+ "password": "PASS",
+ "host": ["hostA", "hostB", "hostC"],
+ "port": [111, 222, 333],
+ },
+ ),
+ (
+ "postgresql+asyncpg:///"
+ "?host=hostA:111&host=hostB:222&host=hostC:333",
+ {
+ "host": ["hostA", "hostB", "hostC"],
+ "port": [111, 222, 333],
+ },
+ ),
+ (
+ "postgresql+asyncpg:///"
+ "?host=hostA:111&host=hostB:222&host=hostC:333",
+ {
+ "host": ["hostA", "hostB", "hostC"],
+ "port": [111, 222, 333],
+ },
+ ),
+ (
+ "postgresql+asyncpg:///"
+ "?host=hostA,hostB,hostC&port=111,222,333",
+ {
+ "host": ["hostA", "hostB", "hostC"],
+ "port": [111, 222, 333],
+ },
+ ),
+ argnames="url_string,expected",
+ )
+ def test_asyncpg_multi_hosts(self, url_string, expected):
+ dialect = asyncpg_dialect.dialect()
u = url.make_url(url_string)
cargs, cparams = dialect.create_connect_args(u)
- eq_(cargs, [])
eq_(cparams, expected)
+ eq_(cargs, [])
+
+ @testing.combinations(
+ (
+ "postgresql+asyncpg://USER:PASS@/DB?host=hostA",
+ "All ports are required to be present"
+ " for asyncpg multiple host URL",
+ ),
+ (
+ "postgresql+asyncpg://USER:PASS@/DB"
+ "?host=hostA&host=hostB&host=hostC",
+ "All ports are required to be present"
+ " for asyncpg multiple host URL",
+ ),
+ (
+ "postgresql+asyncpg://USER:PASS@/DB"
+ "?host=hostA&host=hostB:222&host=hostC:333",
+ "All ports are required to be present"
+ " for asyncpg multiple host URL",
+ ),
+ (
+ "postgresql+asyncpg://USER:PASS@/DB"
+ "?host=hostA,hostB,hostC&port=111,,333",
+ "All ports are required to be present"
+ " for asyncpg multiple host URL",
+ ),
+ (
+ "postgresql+asyncpg://USER:PASS@/DB"
+ "?host=hostA,hostB,&port=111,222,333",
+ "Some hosts is not specified in a connection string",
+ ),
+ (
+ "postgresql+asyncpg://USER:PASS@/DB"
+ "?host=hostA:111&host=:222&host=hostC:333",
+ "Some hosts is not specified in a connection string",
+ ),
+ (
+ "postgresql+asyncpg://USER:PASS@/DB"
+ "?host=hostA:111&host=hostB:vvv&host=hostC:333",
+ "Some of specified ports is not a valid integer",
+ ),
+ (
+ "postgresql+asyncpg://USER:PASS@/DB"
+ "?host=hostA,hostB,hostC&port=111,222",
+ "number of hosts and ports don't match",
+ ),
+ argnames="url_string,expected_message",
+ )
+ def test_asyncpg_multi_hosts_errors(self, url_string, expected_message):
+ dialect = asyncpg_dialect.dialect()
+ with expect_raises_message(exc.ArgumentError, expected_message):
+ u = url.make_url(url_string)
+ dialect.create_connect_args(u)
@testing.combinations(
"postgresql+psycopg2:///?host=H&host=H&port=5432,5432",
with e.connect() as conn:
eq_(conn.exec_driver_sql("select 1").scalar(), 1)
+ @testing.only_on(["+asyncpg"])
+ @testing.combinations(
+ "host=H:P&host=H:P&host=H:P",
+ "host=H,H,H&port=P,P,P",
+ )
+ def test_connect_asyncpg_multiple_hosts(self, pattern):
+ tdb_url = testing.db.url
+
+ host = tdb_url.host
+ if host == "127.0.0.1":
+ host = "localhost"
+ port = str(tdb_url.port) if tdb_url.port else "5432"
+
+ query_str = pattern.replace("H", host).replace("P", port)
+ url_string = (
+ f"{tdb_url.drivername}://{tdb_url.username}:"
+ f"{tdb_url.password}@/{tdb_url.database}?{query_str}"
+ )
+
+ e = create_engine(url_string)
+ with e.connect() as conn:
+ eq_(conn.exec_driver_sql("select 1").scalar(), 1)
+
class PGCodeTest(fixtures.TestBase):
__only_on__ = "postgresql"