From: Daniele Varrazzo Date: Thu, 19 Mar 2020 08:40:15 +0000 (+1300) Subject: Added high-level functions to manipulate conninfo X-Git-Tag: 3.0.dev0~697 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=80adf47e2e82bfc184e1f4c3a66403b303870239;p=thirdparty%2Fpsycopg.git Added high-level functions to manipulate conninfo --- diff --git a/psycopg3/conninfo.py b/psycopg3/conninfo.py new file mode 100644 index 000000000..365c3e3d9 --- /dev/null +++ b/psycopg3/conninfo.py @@ -0,0 +1,81 @@ +import re + +from . import pq +from . import exceptions as exc + + +def make_conninfo(conninfo=None, **kwargs): + """ + Merge a string and keyword params into a single conninfo string. + + Raise ProgrammingError if the input don't make a valid conninfo. + """ + if conninfo is None and not kwargs: + return "" + + # If no kwarg is specified don't mung the conninfo but check if it's correct + if not kwargs: + _parse_conninfo(conninfo) + return conninfo + + # Override the conninfo with the parameters + # Drop the None arguments + kwargs = {k: v for (k, v) in kwargs.items() if v is not None} + + if conninfo is not None: + tmp = conninfo_to_dict(conninfo) + tmp.update(kwargs) + kwargs = tmp + + conninfo = " ".join( + ["%s=%s" % (k, _param_escape(str(v))) for (k, v) in kwargs.items()] + ) + + # Verify the result is valid + _parse_conninfo(conninfo) + + return conninfo + + +def conninfo_to_dict(conninfo): + """ + Convert the *conninfo* string into a dictionary of parameters. + + Raise ProgrammingError if the string is not valid. + """ + opts = _parse_conninfo(conninfo) + return { + opt.keyword.decode("utf8"): opt.val.decode("utf8") + for opt in opts + if opt.val is not None + } + + +def _parse_conninfo(conninfo): + """ + Verify that *conninfo* is a valid connection string. + + Raise ProgrammingError if the string is not valid. + + Return the result of pq.Conninfo.parse() on success. + """ + try: + return pq.Conninfo.parse(conninfo.encode("utf8")) + except pq.PQerror as e: + raise exc.ProgrammingError(str(e)) + + +def _param_escape( + s, re_escape=re.compile(r"([\\'])"), re_space=re.compile(r"\s") +): + """ + Apply the escaping rule required by PQconnectdb + """ + if not s: + return "''" + + s = re_escape.sub(r"\\\1", s) + if re_space.search(s): + s = "'" + s + "'" + + return s diff --git a/tests/test_conninfo.py b/tests/test_conninfo.py new file mode 100644 index 000000000..5eacb31dc --- /dev/null +++ b/tests/test_conninfo.py @@ -0,0 +1,44 @@ +import pytest + +from psycopg3.conninfo import make_conninfo, conninfo_to_dict +from psycopg3 import ProgrammingError + + +@pytest.mark.parametrize( + "conninfo, kwargs, exp", + [ + ("", {}, ""), + ("dbname=foo", {}, "dbname=foo"), + ("dbname=foo", {"user": "bar"}, "dbname=foo user=bar"), + ("dbname=foo", {"dbname": "bar"}, "dbname=bar"), + ("user=bar", {"dbname": "foo bar"}, "dbname='foo bar' user=bar"), + ("", {"dbname": "foo"}, "dbname=foo"), + ("", {"dbname": "foo", "user": None}, "dbname=foo"), + ("", {"dbname": "a'b"}, r"dbname='a\'b'"), + ], +) +def test_make_conninfo(conninfo, kwargs, exp): + out = make_conninfo(conninfo, **kwargs) + assert conninfo_to_dict(out) == conninfo_to_dict(exp) + + +@pytest.mark.parametrize( + "conninfo, kwargs", + [("dbname=foo bar", {}), ("foo=bar", {}), ("dbname=foo", {"bar": "baz"})], +) +def test_make_conninfo_bad(conninfo, kwargs): + with pytest.raises(ProgrammingError): + make_conninfo(conninfo, **kwargs) + + +@pytest.mark.parametrize( + "conninfo, exp", + [ + ("", {}), + ("dbname=foo user=bar", {"dbname": "foo", "user": "bar"}), + ("dbname='foo bar'", {"dbname": "foo bar"}), + (r"dbname='a\'b'", {"dbname": "a'b"}), + ], +) +def test_conninfo_to_dict(conninfo, exp): + assert conninfo_to_dict(conninfo) == exp