From: Daniele Varrazzo Date: Sun, 20 Mar 2022 15:04:54 +0000 (+0100) Subject: fix(sql): fix `sql.Literal` with invalid type names X-Git-Tag: 3.1~109^2~6 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=bb69e60fc1711815111265b7022ab53dcb9134bb;p=thirdparty%2Fpsycopg.git fix(sql): fix `sql.Literal` with invalid type names Make sure that `TypeInfo.alt_name` is always populated, even if with a copy of `name`. --- diff --git a/psycopg/psycopg/_typeinfo.py b/psycopg/psycopg/_typeinfo.py index 20eb0d4db..8ca0896fa 100644 --- a/psycopg/psycopg/_typeinfo.py +++ b/psycopg/psycopg/_typeinfo.py @@ -36,13 +36,14 @@ class TypeInfo: name: str, oid: int, array_oid: int, + *, alt_name: str = "", delimiter: str = ",", ): self.name = name self.oid = oid self.array_oid = array_oid - self.alt_name = alt_name + self.alt_name = alt_name or name self.delimiter = delimiter def __repr__(self) -> str: diff --git a/psycopg/psycopg/sql.py b/psycopg/psycopg/sql.py index 4e0d0d3de..bcf604bb0 100644 --- a/psycopg/psycopg/sql.py +++ b/psycopg/psycopg/sql.py @@ -7,7 +7,8 @@ SQL composition utility module import codecs import string from abc import ABC, abstractmethod -from typing import Any, Iterator, Iterable, List, Optional, Sequence, Union +from typing import Any, Dict, Iterator, Iterable, List +from typing import Optional, Sequence, Union, Tuple from .pq import Escaping from .abc import AdaptContext @@ -389,6 +390,8 @@ class Literal(Composable): """ + _names_cache: Dict[Tuple[str, str], bytes] = {} + def as_bytes(self, context: Optional[AdaptContext]) -> bytes: tx = Transformer.from_context(context) dumper = tx.get_dumper(self._obj, PyFormat.TEXT) @@ -398,8 +401,12 @@ class Literal(Composable): if rv[-1] == b"'"[0] and dumper.oid: ti = tx.adapters.types.get(dumper.oid) if ti: - # TODO: ugly encoding just to be decoded by as_string() - rv = b"%s::%s" % (rv, ti.name.encode(tx.encoding)) + try: + type_name = self._names_cache[ti.alt_name, tx.encoding] + except KeyError: + type_name = ti.alt_name.encode(tx.encoding) + self._names_cache[ti.alt_name, tx.encoding] = type_name + rv = b"%s::%s" % (rv, type_name) return rv diff --git a/tests/test_sql.py b/tests/test_sql.py index 98629b350..d8ae83703 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -10,6 +10,8 @@ import pytest from psycopg import pq, sql, ProgrammingError from psycopg.adapt import PyFormat from psycopg._encodings import py2pgenc +from psycopg.types import TypeInfo +from psycopg.types.string import StrDumper eur = "\u20ac" @@ -337,6 +339,35 @@ class TestLiteral: with pytest.raises(ProgrammingError): sql.Literal(Foo()).as_string(conn) + @pytest.mark.parametrize("name", ["a-b", f"{eur}", "order"]) + def test_invalid_name(self, conn, name): + conn.execute( + f""" + set client_encoding to utf8; + create type "{name}"; + create function invin(cstring) returns "{name}" + language internal immutable strict as 'textin'; + create function invout("{name}") returns cstring + language internal immutable strict as 'textout'; + create type "{name}" (input=invin, output=invout, like=text); + """ + ) + info = TypeInfo.fetch(conn, f'"{name}"') + + class InvDumper(StrDumper): + oid = info.oid + + def dump(self, obj): + rv = super().dump(obj) + return b"%s-inv" % rv + + info.register(conn) + conn.adapters.register_dumper(str, InvDumper) + + assert sql.Literal("hello").as_string(conn) == f"'hello-inv'::\"{name}\"" + cur = conn.execute(sql.SQL("select {}").format("hello")) + assert cur.fetchone()[0] == "hello-inv" + class TestSQL: def test_class(self):