+import os
import sys
from typing import Iterator, List, NamedTuple
from tempfile import TemporaryFile
raise
+@pytest.fixture
+def setpgenv(monkeypatch):
+ """Replace the PG* env vars with the vars provided."""
+
+ def setpgenv_(env):
+ ks = [k for k in os.environ if k.startswith("PG")]
+ for k in ks:
+ monkeypatch.delenv(k)
+
+ if env:
+ for k, v in env.items():
+ monkeypatch.setenv(k, v)
+
+ return setpgenv_
+
+
@pytest.fixture
def trace(libpq):
pqver = pq.__build_version__ or pq.version()
(("dbname=foo",), {"user": None}, "dbname=foo"),
],
)
-async def test_connect_args(aconn_cls, monkeypatch, pgconn, args, kwargs, want):
+async def test_connect_args(
+ aconn_cls, monkeypatch, setpgenv, pgconn, args, kwargs, want
+):
the_conninfo: str
def fake_connect(conninfo):
return pgconn
yield
+ setpgenv({})
monkeypatch.setattr(psycopg.connection, "connect", fake_connect)
conn = await aconn_cls.connect(*args, **kwargs)
assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want)
@pytest.mark.parametrize("dsn, kwargs, exp", conninfo_params_timeout)
-async def test_get_connection_params(aconn_cls, dsn, kwargs, exp):
+async def test_get_connection_params(aconn_cls, dsn, kwargs, exp, setpgenv):
+ setpgenv({})
params = await aconn_cls._get_connection_params(dsn, **kwargs)
conninfo = make_conninfo(**params)
assert conninfo_to_dict(conninfo) == exp[0]
)
@pytest.mark.asyncio
async def test_resolve_hostaddr_async_no_resolve(
- monkeypatch, conninfo, want, env, fail_resolve
+ setpgenv, conninfo, want, env, fail_resolve
):
- if env:
- for k, v in env.items():
- monkeypatch.setenv(k, v)
+ setpgenv(env)
params = conninfo_to_dict(conninfo)
params = await resolve_hostaddr_async(params)
assert conninfo_to_dict(want) == params
],
)
@pytest.mark.asyncio
-async def test_resolve_hostaddr_async_bad(monkeypatch, conninfo, env, fake_resolve):
- if env:
- for k, v in env.items():
- monkeypatch.setenv(k, v)
+async def test_resolve_hostaddr_async_bad(setpgenv, conninfo, env, fake_resolve):
+ setpgenv(env)
params = conninfo_to_dict(conninfo)
with pytest.raises(psycopg.Error):
await resolve_hostaddr_async(params)
@pytest.mark.flakey("random weight order, might cause wrong order")
@pytest.mark.parametrize("conninfo, want, env", samples_ok)
-def test_srv(conninfo, want, env, fake_srv, monkeypatch):
- if env:
- for k, v in env.items():
- monkeypatch.setenv(k, v)
+def test_srv(conninfo, want, env, fake_srv, setpgenv):
+ setpgenv(env)
params = conninfo_to_dict(conninfo)
params = psycopg._dns.resolve_srv(params) # type: ignore[attr-defined]
assert conninfo_to_dict(want) == params
@pytest.mark.asyncio
@pytest.mark.parametrize("conninfo, want, env", samples_ok)
-async def test_srv_async(conninfo, want, env, afake_srv, monkeypatch):
- if env:
- for k, v in env.items():
- monkeypatch.setenv(k, v)
+async def test_srv_async(conninfo, want, env, afake_srv, setpgenv):
+ setpgenv(env)
params = conninfo_to_dict(conninfo)
params = await (
psycopg._dns.resolve_srv_async(params) # type: ignore[attr-defined]
@pytest.mark.parametrize("conninfo, env", samples_bad)
-def test_srv_bad(conninfo, env, fake_srv, monkeypatch):
- if env:
- for k, v in env.items():
- monkeypatch.setenv(k, v)
+def test_srv_bad(conninfo, env, fake_srv, setpgenv):
+ setpgenv(env)
params = conninfo_to_dict(conninfo)
with pytest.raises(psycopg.OperationalError):
psycopg._dns.resolve_srv(params) # type: ignore[attr-defined]
@pytest.mark.asyncio
@pytest.mark.parametrize("conninfo, env", samples_bad)
-async def test_srv_bad_async(conninfo, env, afake_srv, monkeypatch):
- if env:
- for k, v in env.items():
- monkeypatch.setenv(k, v)
+async def test_srv_bad_async(conninfo, env, afake_srv, setpgenv):
+ setpgenv(env)
params = conninfo_to_dict(conninfo)
with pytest.raises(psycopg.OperationalError):
await psycopg._dns.resolve_srv_async(params) # type: ignore[attr-defined]