]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
feat(sql): add explicit type cast to Literal output
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 13 Mar 2022 02:35:41 +0000 (03:35 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 10 May 2022 17:13:26 +0000 (19:13 +0200)
docs/api/sql.rst
docs/news.rst
psycopg/psycopg/sql.py
tests/test_adapt.py
tests/test_sql.py

index cb1ba20439f0f4a69a966fd6c468b544803189d8..63906f8739e8b6628083e990383602be6d57f807 100644 (file)
@@ -123,6 +123,10 @@ The `!sql` objects are in the following inheritance hierarchy:
 
 .. autoclass:: Literal
 
+    .. versionchanged:: 3.1
+        Add a type cast to the representation if useful in ambiguous context
+        (e.g. ``'2000-01-01'::date``)
+
 .. autoclass:: Placeholder
 
 .. autoclass:: Composed
index 1943f371537c86e67dc0b1a6104de59a93c4d9c2..83762a50ffa3014a1f18060e49a82736d8a3ad13 100644 (file)
@@ -23,6 +23,7 @@ Psycopg 3.1 (unreleased)
 - Add `pq.PGconn.trace()` and related trace functions (:ticket:`#167`).
 - Add ``prepare_threshold`` parameter to `Connection` init (:ticket:`#200`).
 - Add `Error.pgconn` and `Error.pgresult` attributes (:ticket:`#242`).
+- Add explicit type cast to values converted by `sql.Literal` (:ticket:`#205`).
 - Drop support for Python 3.6.
 
 
index 39ecfc9e821fd92aff2127bf640c899d55ecb36b..eb5be8aefc9564d85b95a701be439001d5c3d434 100644 (file)
@@ -381,18 +381,26 @@ class Literal(Composable):
 
     Example::
 
-        >>> s1 = sql.Literal("foo")
-        >>> s2 = sql.Literal("ba'r")
-        >>> s3 = sql.Literal(42)
+        >>> s1 = sql.Literal("fo'o")
+        >>> s2 = sql.Literal(42)
+        >>> s3 = sql.Literal(date(2000, 1, 1))
         >>> print(sql.SQL(', ').join([s1, s2, s3]).as_string(conn))
-        'foo', 'ba''r', 42
+        'fo''o', 42, '2000-01-01'::date
 
     """
 
     def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
         tx = Transformer.from_context(context)
         dumper = tx.get_dumper(self._obj, PyFormat.TEXT)
-        return dumper.quote(self._obj)
+        rv = dumper.quote(self._obj)
+        # If the result is quoted and the oid not unknown,
+        # add an explicit type cast.
+        if rv[-1] == 39 and dumper.oid:
+            ti = tx.adapters.types.get(dumper.oid)
+            if ti:
+                # TODO: ugly encoding just to be decoded by as_string()
+                rv = b"%s::%s" % (rv, ti.name.encode(tx.encoding))
+        return rv
 
 
 class Placeholder(Composable):
index a64ac00959976d930215eda86b5467dd9923e471..9656a8f2568551ecf009d9d7c09f94fbaa0c0992 100644 (file)
@@ -142,7 +142,7 @@ def test_dumper_protocol(conn):
     assert cur.fetchone()[0] == "hellohello"
     cur = conn.execute("select %s", [["hi", "ha"]])
     assert cur.fetchone()[0] == ["hihi", "haha"]
-    assert sql.Literal("hello").as_string(conn) == "'qelloqello'"
+    assert sql.Literal("hello").as_string(conn) == "'qelloqello'::text"
 
 
 def test_loader_protocol(conn):
index 6550c3f8e158ba24c663099ef4cf25235e937168..98629b3501e285f83830a75d694db041c28e9efd 100644 (file)
@@ -48,7 +48,7 @@ def test_quote_stable_despite_deranged_libpq(conn):
     # Verify the libpq behaviour of PQescapeString using the last setting seen.
     # Check that we are not affected by it.
     good_str = " E'\\\\'"
-    good_bytes = " E'\\\\000'"
+    good_bytes = " E'\\\\000'::bytea"
     conn.execute("set standard_conforming_strings to on")
     assert pq.Escaping().escape_string(b"\\") == b"\\"
     assert sql.quote("\\") == good_str
@@ -109,7 +109,7 @@ class TestSqlFormat:
     def test_compose_literal(self, conn):
         s = sql.SQL("select {0};").format(sql.Literal(dt.date(2016, 12, 31)))
         s1 = s.as_string(conn)
-        assert s1 == "select '2016-12-31';"
+        assert s1 == "select '2016-12-31'::date;"
 
     def test_compose_empty(self, conn):
         s = sql.SQL("select foo;").format()
@@ -161,7 +161,7 @@ class TestSqlFormat:
 
     def test_auto_literal(self, conn):
         s = sql.SQL("select {}, {}, {}").format("he'lo", 10, dt.date(2020, 1, 1))
-        assert s.as_string(conn) == "select 'he''lo', 10, '2020-01-01'"
+        assert s.as_string(conn) == "select 'he''lo', 10, '2020-01-01'::date"
 
     def test_execute(self, conn):
         cur = conn.cursor()
@@ -311,13 +311,13 @@ class TestLiteral:
         assert sql.Literal(None).as_string(conn) == "NULL"
         assert no_e(sql.Literal("foo").as_string(conn)) == "'foo'"
         assert sql.Literal(42).as_string(conn) == "42"
-        assert sql.Literal(dt.date(2017, 1, 1)).as_string(conn) == "'2017-01-01'"
+        assert sql.Literal(dt.date(2017, 1, 1)).as_string(conn) == "'2017-01-01'::date"
 
     def test_as_bytes(self, conn):
         assert sql.Literal(None).as_bytes(conn) == b"NULL"
         assert no_e(sql.Literal("foo").as_bytes(conn)) == b"'foo'"
         assert sql.Literal(42).as_bytes(conn) == b"42"
-        assert sql.Literal(dt.date(2017, 1, 1)).as_bytes(conn) == b"'2017-01-01'"
+        assert sql.Literal(dt.date(2017, 1, 1)).as_bytes(conn) == b"'2017-01-01'::date"
 
         conn.execute("set client_encoding to utf8")
         assert sql.Literal(eur).as_bytes(conn) == f"'{eur}'".encode()
@@ -432,7 +432,7 @@ class TestComposed:
         obj = sql.Composed(["fo'o", dt.date(2020, 1, 1)])
         obj = obj.join(", ")
         assert isinstance(obj, sql.Composed)
-        assert no_e(obj.as_string(conn)) == "'fo''o', '2020-01-01'"
+        assert no_e(obj.as_string(conn)) == "'fo''o', '2020-01-01'::date"
 
     def test_sum(self, conn):
         obj = sql.Composed([sql.SQL("foo ")])