]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix(sql): fix `sql.Literal` with invalid type names
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 20 Mar 2022 15:04:54 +0000 (16:04 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 10 May 2022 17:13:26 +0000 (19:13 +0200)
Make sure that `TypeInfo.alt_name` is always populated, even if with a
copy of `name`.

psycopg/psycopg/_typeinfo.py
psycopg/psycopg/sql.py
tests/test_sql.py

index 20eb0d4dba2b4276c3828bbbae41f631abea7287..8ca0896fafd5324349cc311cc8577254d0d6fb0c 100644 (file)
@@ -36,13 +36,14 @@ class TypeInfo:
         name: str,
         oid: int,
         array_oid: int,
+        *,
         alt_name: str = "",
         delimiter: str = ",",
     ):
         self.name = name
         self.oid = oid
         self.array_oid = array_oid
-        self.alt_name = alt_name
+        self.alt_name = alt_name or name
         self.delimiter = delimiter
 
     def __repr__(self) -> str:
index 4e0d0d3dece6fe6f1e3cb8fa066dd1bdb0b08df4..bcf604bb0e01394adb29027d50e229a3a28779b8 100644 (file)
@@ -7,7 +7,8 @@ SQL composition utility module
 import codecs
 import string
 from abc import ABC, abstractmethod
-from typing import Any, Iterator, Iterable, List, Optional, Sequence, Union
+from typing import Any, Dict, Iterator, Iterable, List
+from typing import Optional, Sequence, Union, Tuple
 
 from .pq import Escaping
 from .abc import AdaptContext
@@ -389,6 +390,8 @@ class Literal(Composable):
 
     """
 
+    _names_cache: Dict[Tuple[str, str], bytes] = {}
+
     def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
         tx = Transformer.from_context(context)
         dumper = tx.get_dumper(self._obj, PyFormat.TEXT)
@@ -398,8 +401,12 @@ class Literal(Composable):
         if rv[-1] == b"'"[0] and dumper.oid:
             ti = tx.adapters.types.get(dumper.oid)
             if ti:
-                # TODO: ugly encoding just to be decoded by as_string()
-                rv = b"%s::%s" % (rv, ti.name.encode(tx.encoding))
+                try:
+                    type_name = self._names_cache[ti.alt_name, tx.encoding]
+                except KeyError:
+                    type_name = ti.alt_name.encode(tx.encoding)
+                    self._names_cache[ti.alt_name, tx.encoding] = type_name
+                rv = b"%s::%s" % (rv, type_name)
         return rv
 
 
index 98629b3501e285f83830a75d694db041c28e9efd..d8ae837035f70de2cab5651daed62501ca2ae115 100644 (file)
@@ -10,6 +10,8 @@ import pytest
 from psycopg import pq, sql, ProgrammingError
 from psycopg.adapt import PyFormat
 from psycopg._encodings import py2pgenc
+from psycopg.types import TypeInfo
+from psycopg.types.string import StrDumper
 
 eur = "\u20ac"
 
@@ -337,6 +339,35 @@ class TestLiteral:
         with pytest.raises(ProgrammingError):
             sql.Literal(Foo()).as_string(conn)
 
+    @pytest.mark.parametrize("name", ["a-b", f"{eur}", "order"])
+    def test_invalid_name(self, conn, name):
+        conn.execute(
+            f"""
+            set client_encoding to utf8;
+            create type "{name}";
+            create function invin(cstring) returns "{name}"
+                language internal immutable strict as 'textin';
+            create function invout("{name}") returns cstring
+                language internal immutable strict as 'textout';
+            create type "{name}" (input=invin, output=invout, like=text);
+            """
+        )
+        info = TypeInfo.fetch(conn, f'"{name}"')
+
+        class InvDumper(StrDumper):
+            oid = info.oid
+
+            def dump(self, obj):
+                rv = super().dump(obj)
+                return b"%s-inv" % rv
+
+        info.register(conn)
+        conn.adapters.register_dumper(str, InvDumper)
+
+        assert sql.Literal("hello").as_string(conn) == f"'hello-inv'::\"{name}\""
+        cur = conn.execute(sql.SQL("select {}").format("hello"))
+        assert cur.fetchone()[0] == "hello-inv"
+
 
 class TestSQL:
     def test_class(self):