From: Daniele Varrazzo Date: Sun, 22 May 2022 00:08:03 +0000 (+0200) Subject: fix: don't add a cast to text literals X-Git-Tag: 3.1~73 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=adbc6fbe2b6b46d604751279f92e2d39530e8fb1;p=thirdparty%2Fpsycopg.git fix: don't add a cast to text literals Normally cast is not added to unknown literals, and the text oid is not usually used. However, when it is, it causes problems. Often Postgres wants just a literal, not a cast type; for instance, this is not valid: NOTIFY foo: 'bar'::text; --- diff --git a/psycopg/psycopg/_transform.py b/psycopg/psycopg/_transform.py index 91577f92d..200f31c45 100644 --- a/psycopg/psycopg/_transform.py +++ b/psycopg/psycopg/_transform.py @@ -14,7 +14,7 @@ from . import errors as e from .abc import Buffer, LoadFunc, AdaptContext, PyFormat, DumperKey from .rows import Row, RowMaker from ._compat import TypeAlias -from .postgres import INVALID_OID +from .postgres import INVALID_OID, TEXT_OID from ._encodings import pgconn_encoding if TYPE_CHECKING: @@ -189,11 +189,11 @@ class Transformer(AdaptContext): def as_literal(self, obj: Any) -> Buffer: dumper = self.get_dumper(obj, PyFormat.TEXT) rv = dumper.quote(obj) - # If the result is quoted, and the oid not unknown, + # If the result is quoted, and the oid not unknown or text, # add an explicit type cast. # Check the last char because the first one might be 'E'. oid = dumper.oid - if oid and rv and rv[-1] == b"'"[0]: + if oid and rv and rv[-1] == b"'"[0] and oid != TEXT_OID: try: type_sql = self._oid_types[oid] except KeyError: diff --git a/psycopg_c/psycopg_c/_psycopg/transform.pyx b/psycopg_c/psycopg_c/_psycopg/transform.pyx index e290c3adb..8f2b5f3ee 100644 --- a/psycopg_c/psycopg_c/_psycopg/transform.pyx +++ b/psycopg_c/psycopg_c/_psycopg/transform.pyx @@ -215,10 +215,10 @@ cdef class Transformer: rv = dumper.quote(obj) oid = dumper.oid - # If the result is quoted and the oid not unknown, + # If the result is quoted and the oid not unknown or text, # add an explicit type cast. # Check the last char because the first one might be 'E'. - if oid and rv and rv[-1] == 39: + if oid and oid != oids.TEXT_OID and rv and rv[-1] == 39: if self._oid_types is None: self._oid_types = {} type_ptr = PyDict_GetItem(self._oid_types, oid) diff --git a/tests/test_adapt.py b/tests/test_adapt.py index e43038c9b..7437a71ab 100644 --- a/tests/test_adapt.py +++ b/tests/test_adapt.py @@ -142,7 +142,7 @@ def test_dumper_protocol(conn): assert cur.fetchone()[0] == "hellohello" cur = conn.execute("select %s", [["hi", "ha"]]) assert cur.fetchone()[0] == ["hihi", "haha"] - assert sql.Literal("hello").as_string(conn) == "'qelloqello'::text" + assert sql.Literal("hello").as_string(conn) == "'qelloqello'" def test_loader_protocol(conn): diff --git a/tests/test_sql.py b/tests/test_sql.py index 304bcc779..43a58aa27 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -356,6 +356,10 @@ class TestLiteral: == "'{\"2000-01-01 00:00:00\"}'::timestamp[]" ) + def test_text_literal(self, conn): + conn.adapters.register_dumper(str, StrDumper) + assert sql.Literal("foo").as_string(conn) == "'foo'" + @pytest.mark.parametrize("name", ["a-b", f"{eur}", "order", "foo bar"]) def test_invalid_name(self, conn, name): conn.execute(