--- /dev/null
+import re
+
+from . import pq
+from . import exceptions as exc
+
+
+def make_conninfo(conninfo=None, **kwargs):
+ """
+ Merge a string and keyword params into a single conninfo string.
+
+ Raise ProgrammingError if the input don't make a valid conninfo.
+ """
+ if conninfo is None and not kwargs:
+ return ""
+
+ # If no kwarg is specified don't mung the conninfo but check if it's correct
+ if not kwargs:
+ _parse_conninfo(conninfo)
+ return conninfo
+
+ # Override the conninfo with the parameters
+ # Drop the None arguments
+ kwargs = {k: v for (k, v) in kwargs.items() if v is not None}
+
+ if conninfo is not None:
+ tmp = conninfo_to_dict(conninfo)
+ tmp.update(kwargs)
+ kwargs = tmp
+
+ conninfo = " ".join(
+ ["%s=%s" % (k, _param_escape(str(v))) for (k, v) in kwargs.items()]
+ )
+
+ # Verify the result is valid
+ _parse_conninfo(conninfo)
+
+ return conninfo
+
+
+def conninfo_to_dict(conninfo):
+ """
+ Convert the *conninfo* string into a dictionary of parameters.
+
+ Raise ProgrammingError if the string is not valid.
+ """
+ opts = _parse_conninfo(conninfo)
+ return {
+ opt.keyword.decode("utf8"): opt.val.decode("utf8")
+ for opt in opts
+ if opt.val is not None
+ }
+
+
+def _parse_conninfo(conninfo):
+ """
+ Verify that *conninfo* is a valid connection string.
+
+ Raise ProgrammingError if the string is not valid.
+
+ Return the result of pq.Conninfo.parse() on success.
+ """
+ try:
+ return pq.Conninfo.parse(conninfo.encode("utf8"))
+ except pq.PQerror as e:
+ raise exc.ProgrammingError(str(e))
+
+
+def _param_escape(
+ s, re_escape=re.compile(r"([\\'])"), re_space=re.compile(r"\s")
+):
+ """
+ Apply the escaping rule required by PQconnectdb
+ """
+ if not s:
+ return "''"
+
+ s = re_escape.sub(r"\\\1", s)
+ if re_space.search(s):
+ s = "'" + s + "'"
+
+ return s
--- /dev/null
+import pytest
+
+from psycopg3.conninfo import make_conninfo, conninfo_to_dict
+from psycopg3 import ProgrammingError
+
+
+@pytest.mark.parametrize(
+ "conninfo, kwargs, exp",
+ [
+ ("", {}, ""),
+ ("dbname=foo", {}, "dbname=foo"),
+ ("dbname=foo", {"user": "bar"}, "dbname=foo user=bar"),
+ ("dbname=foo", {"dbname": "bar"}, "dbname=bar"),
+ ("user=bar", {"dbname": "foo bar"}, "dbname='foo bar' user=bar"),
+ ("", {"dbname": "foo"}, "dbname=foo"),
+ ("", {"dbname": "foo", "user": None}, "dbname=foo"),
+ ("", {"dbname": "a'b"}, r"dbname='a\'b'"),
+ ],
+)
+def test_make_conninfo(conninfo, kwargs, exp):
+ out = make_conninfo(conninfo, **kwargs)
+ assert conninfo_to_dict(out) == conninfo_to_dict(exp)
+
+
+@pytest.mark.parametrize(
+ "conninfo, kwargs",
+ [("dbname=foo bar", {}), ("foo=bar", {}), ("dbname=foo", {"bar": "baz"})],
+)
+def test_make_conninfo_bad(conninfo, kwargs):
+ with pytest.raises(ProgrammingError):
+ make_conninfo(conninfo, **kwargs)
+
+
+@pytest.mark.parametrize(
+ "conninfo, exp",
+ [
+ ("", {}),
+ ("dbname=foo user=bar", {"dbname": "foo", "user": "bar"}),
+ ("dbname='foo bar'", {"dbname": "foo bar"}),
+ (r"dbname='a\'b'", {"dbname": "a'b"}),
+ ],
+)
+def test_conninfo_to_dict(conninfo, exp):
+ assert conninfo_to_dict(conninfo) == exp