From: Daniele Varrazzo Date: Sun, 13 Mar 2022 02:35:41 +0000 (+0100) Subject: feat(sql): add explicit type cast to Literal output X-Git-Tag: 3.1~109^2~10 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=97dd0b37c803eefacd9e5084f16452415d26f3d3;p=thirdparty%2Fpsycopg.git feat(sql): add explicit type cast to Literal output --- diff --git a/docs/api/sql.rst b/docs/api/sql.rst index cb1ba2043..63906f873 100644 --- a/docs/api/sql.rst +++ b/docs/api/sql.rst @@ -123,6 +123,10 @@ The `!sql` objects are in the following inheritance hierarchy: .. autoclass:: Literal + .. versionchanged:: 3.1 + Add a type cast to the representation if useful in ambiguous context + (e.g. ``'2000-01-01'::date``) + .. autoclass:: Placeholder .. autoclass:: Composed diff --git a/docs/news.rst b/docs/news.rst index 1943f3715..83762a50f 100644 --- a/docs/news.rst +++ b/docs/news.rst @@ -23,6 +23,7 @@ Psycopg 3.1 (unreleased) - Add `pq.PGconn.trace()` and related trace functions (:ticket:`#167`). - Add ``prepare_threshold`` parameter to `Connection` init (:ticket:`#200`). - Add `Error.pgconn` and `Error.pgresult` attributes (:ticket:`#242`). +- Add explicit type cast to values converted by `sql.Literal` (:ticket:`#205`). - Drop support for Python 3.6. diff --git a/psycopg/psycopg/sql.py b/psycopg/psycopg/sql.py index 39ecfc9e8..eb5be8aef 100644 --- a/psycopg/psycopg/sql.py +++ b/psycopg/psycopg/sql.py @@ -381,18 +381,26 @@ class Literal(Composable): Example:: - >>> s1 = sql.Literal("foo") - >>> s2 = sql.Literal("ba'r") - >>> s3 = sql.Literal(42) + >>> s1 = sql.Literal("fo'o") + >>> s2 = sql.Literal(42) + >>> s3 = sql.Literal(date(2000, 1, 1)) >>> print(sql.SQL(', ').join([s1, s2, s3]).as_string(conn)) - 'foo', 'ba''r', 42 + 'fo''o', 42, '2000-01-01'::date """ def as_bytes(self, context: Optional[AdaptContext]) -> bytes: tx = Transformer.from_context(context) dumper = tx.get_dumper(self._obj, PyFormat.TEXT) - return dumper.quote(self._obj) + rv = dumper.quote(self._obj) + # If the result is quoted and the oid not unknown, + # add an explicit type cast. + if rv[-1] == 39 and dumper.oid: + ti = tx.adapters.types.get(dumper.oid) + if ti: + # TODO: ugly encoding just to be decoded by as_string() + rv = b"%s::%s" % (rv, ti.name.encode(tx.encoding)) + return rv class Placeholder(Composable): diff --git a/tests/test_adapt.py b/tests/test_adapt.py index a64ac0095..9656a8f25 100644 --- a/tests/test_adapt.py +++ b/tests/test_adapt.py @@ -142,7 +142,7 @@ def test_dumper_protocol(conn): assert cur.fetchone()[0] == "hellohello" cur = conn.execute("select %s", [["hi", "ha"]]) assert cur.fetchone()[0] == ["hihi", "haha"] - assert sql.Literal("hello").as_string(conn) == "'qelloqello'" + assert sql.Literal("hello").as_string(conn) == "'qelloqello'::text" def test_loader_protocol(conn): diff --git a/tests/test_sql.py b/tests/test_sql.py index 6550c3f8e..98629b350 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -48,7 +48,7 @@ def test_quote_stable_despite_deranged_libpq(conn): # Verify the libpq behaviour of PQescapeString using the last setting seen. # Check that we are not affected by it. good_str = " E'\\\\'" - good_bytes = " E'\\\\000'" + good_bytes = " E'\\\\000'::bytea" conn.execute("set standard_conforming_strings to on") assert pq.Escaping().escape_string(b"\\") == b"\\" assert sql.quote("\\") == good_str @@ -109,7 +109,7 @@ class TestSqlFormat: def test_compose_literal(self, conn): s = sql.SQL("select {0};").format(sql.Literal(dt.date(2016, 12, 31))) s1 = s.as_string(conn) - assert s1 == "select '2016-12-31';" + assert s1 == "select '2016-12-31'::date;" def test_compose_empty(self, conn): s = sql.SQL("select foo;").format() @@ -161,7 +161,7 @@ class TestSqlFormat: def test_auto_literal(self, conn): s = sql.SQL("select {}, {}, {}").format("he'lo", 10, dt.date(2020, 1, 1)) - assert s.as_string(conn) == "select 'he''lo', 10, '2020-01-01'" + assert s.as_string(conn) == "select 'he''lo', 10, '2020-01-01'::date" def test_execute(self, conn): cur = conn.cursor() @@ -311,13 +311,13 @@ class TestLiteral: assert sql.Literal(None).as_string(conn) == "NULL" assert no_e(sql.Literal("foo").as_string(conn)) == "'foo'" assert sql.Literal(42).as_string(conn) == "42" - assert sql.Literal(dt.date(2017, 1, 1)).as_string(conn) == "'2017-01-01'" + assert sql.Literal(dt.date(2017, 1, 1)).as_string(conn) == "'2017-01-01'::date" def test_as_bytes(self, conn): assert sql.Literal(None).as_bytes(conn) == b"NULL" assert no_e(sql.Literal("foo").as_bytes(conn)) == b"'foo'" assert sql.Literal(42).as_bytes(conn) == b"42" - assert sql.Literal(dt.date(2017, 1, 1)).as_bytes(conn) == b"'2017-01-01'" + assert sql.Literal(dt.date(2017, 1, 1)).as_bytes(conn) == b"'2017-01-01'::date" conn.execute("set client_encoding to utf8") assert sql.Literal(eur).as_bytes(conn) == f"'{eur}'".encode() @@ -432,7 +432,7 @@ class TestComposed: obj = sql.Composed(["fo'o", dt.date(2020, 1, 1)]) obj = obj.join(", ") assert isinstance(obj, sql.Composed) - assert no_e(obj.as_string(conn)) == "'fo''o', '2020-01-01'" + assert no_e(obj.as_string(conn)) == "'fo''o', '2020-01-01'::date" def test_sum(self, conn): obj = sql.Composed([sql.SQL("foo ")])