# Copyright (C) 2021 The Psycopg Team
+import os
from typing import Any, Dict
from functools import lru_cache
from ipaddress import ip_address
.. warning::
This function doesn't handle the ``/etc/hosts`` file.
"""
- if params.get("hostaddr") or not params.get("host"):
+ host_arg: str = params.get("host", os.environ.get("PGHOST", ""))
+ hostaddr_arg = params.get("hostaddr", os.environ.get("PGHOSTADDR", ""))
+
+ if hostaddr_arg or not host_arg:
return params
+ port_arg: str = str(params.get("port", os.environ.get("PGPORT", "")))
+
if pq.version() < 100000:
# hostaddr not supported
return params
- host = params["host"]
-
- if host.startswith("/") or host[1:2] == ":":
+ if host_arg.startswith("/") or host_arg[1:2] == ":":
# Local path
return params
- hosts_in = host.split(",")
- ports_in = str(params["port"]).split(",") if params.get("port") else []
- if len(ports_in) <= 1:
+ hosts_in = host_arg.split(",")
+ ports_in = port_arg.split(",")
+ if len(ports_in) == 1:
# If only one port is specified, the libpq will apply it to all
# the hosts, so don't mangle it.
del ports_in[:]
- else:
+ elif len(ports_in) > 1:
if len(ports_in) != len(hosts_in):
# ProgrammingError would have been more appropriate, but this is
# what the raise if the libpq fails connect in the same case.
@pytest.mark.parametrize(
- "conninfo, want",
+ "conninfo, want, env",
[
- ("", ""),
- ("host='' user=bar", "host='' user=bar"),
+ ("", "", None),
+ ("host='' user=bar", "host='' user=bar", None),
(
"host=127.0.0.1 user=bar",
"host=127.0.0.1 user=bar hostaddr=127.0.0.1",
+ None,
),
(
"host=1.1.1.1,2.2.2.2 user=bar",
"host=1.1.1.1,2.2.2.2 user=bar hostaddr=1.1.1.1,2.2.2.2",
+ None,
),
(
"host=1.1.1.1,2.2.2.2 port=5432",
"host=1.1.1.1,2.2.2.2 port=5432 hostaddr=1.1.1.1,2.2.2.2",
+ None,
+ ),
+ (
+ "port=5432",
+ "host=1.1.1.1,2.2.2.2 port=5432 hostaddr=1.1.1.1,2.2.2.2",
+ {"PGHOST": "1.1.1.1,2.2.2.2"},
+ ),
+ (
+ "host=foo.com port=5432",
+ "host=foo.com port=5432",
+ {"PGHOSTADDR": "1.2.3.4"},
),
],
)
@pytest.mark.asyncio
-async def test_resolve_hostaddr_async_no_resolve(conninfo, want, fail_resolve):
+async def test_resolve_hostaddr_async_no_resolve(
+ monkeypatch, conninfo, want, env, fail_resolve
+):
+ if env:
+ for k, v in env.items():
+ monkeypatch.setenv(k, v)
params = conninfo_to_dict(conninfo)
params = await psycopg._dns.resolve_hostaddr_async(params)
assert conninfo_to_dict(want) == params
@pytest.mark.parametrize(
- "conninfo, want",
+ "conninfo, want, env",
[
(
"host=foo.com,qux.com",
"host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2",
+ None,
),
(
"host=foo.com,qux.com port=5433",
"host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5433",
+ None,
),
(
"host=foo.com,qux.com port=5432,5433",
"host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5432,5433",
+ None,
),
(
"host=foo.com,nosuchhost.com",
"host=foo.com hostaddr=1.1.1.1",
+ None,
),
(
"host=nosuchhost.com,foo.com",
"host=foo.com hostaddr=1.1.1.1",
+ None,
+ ),
+ (
+ "host=foo.com,qux.com",
+ "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2",
+ {},
),
],
)
@pytest.mark.asyncio
-async def test_resolve_hostaddr_async(conninfo, want, fake_resolve):
+async def test_resolve_hostaddr_async(conninfo, want, env, fake_resolve):
params = conninfo_to_dict(conninfo)
params = await psycopg._dns.resolve_hostaddr_async(params)
assert conninfo_to_dict(want) == params
@pytest.mark.parametrize(
- "conninfo",
+ "conninfo, env",
[
- "host=bad1.com,bad2.com",
- "host=foo.com port=1,2",
- "host=1.1.1.1,2.2.2.2 port=5432,5433,5434",
+ ("host=bad1.com,bad2.com", None),
+ ("host=foo.com port=1,2", None),
+ ("host=1.1.1.1,2.2.2.2 port=5432,5433,5434", None),
+ ("host=1.1.1.1,2.2.2.2", {"PGPORT": "1,2,3"}),
],
)
@pytest.mark.asyncio
-async def test_resolve_hostaddr_async_bad(conninfo, fake_resolve):
+async def test_resolve_hostaddr_async_bad(
+ monkeypatch, conninfo, env, fake_resolve
+):
+ if env:
+ for k, v in env.items():
+ monkeypatch.setenv(k, v)
params = conninfo_to_dict(conninfo)
with pytest.raises((TypeError, psycopg.Error)):
await psycopg._dns.resolve_hostaddr_async(params)