From fe8c379cf1b7b188038c87ce310eaec572922004 Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Sun, 20 Mar 2022 19:24:53 +0100 Subject: [PATCH] feat: add `Transformer.as_literal()` to convert literals to sql This allows to keep a cache oid -> representation that will be invalidated automatically if encoding or oid mapping changes, as well as optimising the function in C. --- psycopg/psycopg/_transform.py | 27 ++++++++++++++++ psycopg/psycopg/abc.py | 3 ++ psycopg/psycopg/sql.py | 22 ++----------- psycopg_c/psycopg_c/_psycopg.pyi | 1 + psycopg_c/psycopg_c/_psycopg/transform.pyx | 36 ++++++++++++++++++++++ 5 files changed, 69 insertions(+), 20 deletions(-) diff --git a/psycopg/psycopg/_transform.py b/psycopg/psycopg/_transform.py index 76a9ca489..8509c93e3 100644 --- a/psycopg/psycopg/_transform.py +++ b/psycopg/psycopg/_transform.py @@ -76,6 +76,9 @@ class Transformer(AdaptContext): # the length of the result columns self._row_loaders: List[LoadFunc] = [] + # mapping oid -> type sql representation + self._oid_types: Dict[int, bytes] = {} + self._encoding = "" @classmethod @@ -181,6 +184,30 @@ class Transformer(AdaptContext): return out + def as_literal(self, obj: Any) -> Buffer: + dumper = self.get_dumper(obj, PyFormat.TEXT) + rv = dumper.quote(obj) + # If the result is quoted, and the oid not unknown, + # add an explicit type cast. + # Check the last char because the first one might be 'E'. + if dumper.oid and rv and rv[-1] == b"'"[0]: + try: + type_sql = self._oid_types[dumper.oid] + except KeyError: + ti = self.adapters.types.get(dumper.oid) + if ti: + type_sql = ti.regtype.encode(self.encoding) + if dumper.oid == ti.array_oid: + type_sql += b"[]" + else: + type_sql = b"" + self._oid_types[dumper.oid] = type_sql + + if type_sql: + rv = b"%s::%s" % (rv, type_sql) + + return rv + def get_dumper(self, obj: Any, format: PyFormat) -> "Dumper": """ Return a Dumper instance to dump *obj*. diff --git a/psycopg/psycopg/abc.py b/psycopg/psycopg/abc.py index 4716ab550..f4af55ac0 100644 --- a/psycopg/psycopg/abc.py +++ b/psycopg/psycopg/abc.py @@ -236,6 +236,9 @@ class Transformer(Protocol): ) -> Sequence[Optional[Buffer]]: ... + def as_literal(self, obj: Any) -> Buffer: + ... + def get_dumper(self, obj: Any, format: PyFormat) -> Dumper: ... diff --git a/psycopg/psycopg/sql.py b/psycopg/psycopg/sql.py index 9726af8b5..c111e2b74 100644 --- a/psycopg/psycopg/sql.py +++ b/psycopg/psycopg/sql.py @@ -7,8 +7,7 @@ SQL composition utility module import codecs import string from abc import ABC, abstractmethod -from typing import Any, Dict, Iterator, Iterable, List -from typing import Optional, Sequence, Union, Tuple +from typing import Any, Iterator, Iterable, List, Optional, Sequence, Union from .pq import Escaping from .abc import AdaptContext @@ -390,26 +389,9 @@ class Literal(Composable): """ - _names_cache: Dict[Tuple[str, str], bytes] = {} - def as_bytes(self, context: Optional[AdaptContext]) -> bytes: tx = Transformer.from_context(context) - dumper = tx.get_dumper(self._obj, PyFormat.TEXT) - rv = dumper.quote(self._obj) - # If the result is quoted and the oid not unknown, - # add an explicit type cast. - if rv[-1] == b"'"[0] and dumper.oid: - ti = tx.adapters.types.get(dumper.oid) - if ti: - try: - type_name = self._names_cache[ti.regtype, tx.encoding] - except KeyError: - type_name = ti.regtype.encode(tx.encoding) - self._names_cache[ti.regtype, tx.encoding] = type_name - if dumper.oid == ti.array_oid: - type_name += b"[]" - rv = b"%s::%s" % (rv, type_name) - return rv + return tx.as_literal(self._obj) class Placeholder(Composable): diff --git a/psycopg_c/psycopg_c/_psycopg.pyi b/psycopg_c/psycopg_c/_psycopg.pyi index cd150e3a1..a215eee58 100644 --- a/psycopg_c/psycopg_c/_psycopg.pyi +++ b/psycopg_c/psycopg_c/_psycopg.pyi @@ -44,6 +44,7 @@ class Transformer(abc.AdaptContext): def dump_sequence( self, params: Sequence[Any], formats: Sequence[PyFormat] ) -> Sequence[Optional[abc.Buffer]]: ... + def as_literal(self, obj: Any) -> abc.Buffer: ... def get_dumper(self, obj: Any, format: PyFormat) -> abc.Dumper: ... def load_rows(self, row0: int, row1: int, make_row: RowMaker[Row]) -> List[Row]: ... def load_row(self, row: int, make_row: RowMaker[Row]) -> Optional[Row]: ... diff --git a/psycopg_c/psycopg_c/_psycopg/transform.pyx b/psycopg_c/psycopg_c/_psycopg/transform.pyx index f7f0ffbda..c6e98355c 100644 --- a/psycopg_c/psycopg_c/_psycopg/transform.pyx +++ b/psycopg_c/psycopg_c/_psycopg/transform.pyx @@ -95,6 +95,8 @@ cdef class Transformer: cdef list _row_dumpers cdef list _row_loaders + cdef dict _oid_types + def __cinit__(self, context: Optional["AdaptContext"] = None): if context is not None: self.adapters = context.adapters @@ -202,6 +204,40 @@ cdef class Transformer: self._row_loaders = loaders + cpdef as_literal(self, obj): + cdef PyObject *row_dumper = self.get_row_dumper( + obj, PG_TEXT) + + if (row_dumper).cdumper is not None: + dumper = (row_dumper).cdumper + else: + dumper = (row_dumper).pydumper + + rv = dumper.quote(obj) + oid = dumper.oid + # If the result is quoted and the oid not unknown, + # add an explicit type cast. + # Check the last char because the first one might be 'E'. + if oid and rv and rv[-1] == 39: + if self._oid_types is None: + self._oid_types = {} + type_ptr = PyDict_GetItem(self._oid_types, oid) + if type_ptr == NULL: + type_sql = b"" + ti = self.adapters.types.get(oid) + if ti is not None: + type_sql = ti.regtype.encode(self.encoding) + if oid == ti.array_oid: + type_sql += b"[]" + + type_ptr = type_sql + PyDict_SetItem(self._oid_types, oid, type_sql) + + if type_ptr: + rv = b"%s::%s" % (rv, type_ptr) + + return rv + def get_dumper(self, obj, format) -> "Dumper": cdef PyObject *row_dumper = self.get_row_dumper( obj, format) -- 2.47.2