name: str,
oid: int,
array_oid: int,
+ *,
alt_name: str = "",
delimiter: str = ",",
):
self.name = name
self.oid = oid
self.array_oid = array_oid
- self.alt_name = alt_name
+ self.alt_name = alt_name or name
self.delimiter = delimiter
def __repr__(self) -> str:
import codecs
import string
from abc import ABC, abstractmethod
-from typing import Any, Iterator, Iterable, List, Optional, Sequence, Union
+from typing import Any, Dict, Iterator, Iterable, List
+from typing import Optional, Sequence, Union, Tuple
from .pq import Escaping
from .abc import AdaptContext
"""
+ _names_cache: Dict[Tuple[str, str], bytes] = {}
+
def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
tx = Transformer.from_context(context)
dumper = tx.get_dumper(self._obj, PyFormat.TEXT)
if rv[-1] == b"'"[0] and dumper.oid:
ti = tx.adapters.types.get(dumper.oid)
if ti:
- # TODO: ugly encoding just to be decoded by as_string()
- rv = b"%s::%s" % (rv, ti.name.encode(tx.encoding))
+ try:
+ type_name = self._names_cache[ti.alt_name, tx.encoding]
+ except KeyError:
+ type_name = ti.alt_name.encode(tx.encoding)
+ self._names_cache[ti.alt_name, tx.encoding] = type_name
+ rv = b"%s::%s" % (rv, type_name)
return rv
from psycopg import pq, sql, ProgrammingError
from psycopg.adapt import PyFormat
from psycopg._encodings import py2pgenc
+from psycopg.types import TypeInfo
+from psycopg.types.string import StrDumper
eur = "\u20ac"
with pytest.raises(ProgrammingError):
sql.Literal(Foo()).as_string(conn)
+ @pytest.mark.parametrize("name", ["a-b", f"{eur}", "order"])
+ def test_invalid_name(self, conn, name):
+ conn.execute(
+ f"""
+ set client_encoding to utf8;
+ create type "{name}";
+ create function invin(cstring) returns "{name}"
+ language internal immutable strict as 'textin';
+ create function invout("{name}") returns cstring
+ language internal immutable strict as 'textout';
+ create type "{name}" (input=invin, output=invout, like=text);
+ """
+ )
+ info = TypeInfo.fetch(conn, f'"{name}"')
+
+ class InvDumper(StrDumper):
+ oid = info.oid
+
+ def dump(self, obj):
+ rv = super().dump(obj)
+ return b"%s-inv" % rv
+
+ info.register(conn)
+ conn.adapters.register_dumper(str, InvDumper)
+
+ assert sql.Literal("hello").as_string(conn) == f"'hello-inv'::\"{name}\""
+ cur = conn.execute(sql.SQL("select {}").format("hello"))
+ assert cur.fetchone()[0] == "hello-inv"
+
class TestSQL:
def test_class(self):