]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
feat: add `Transformer.as_literal()` to convert literals to sql
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 20 Mar 2022 18:24:53 +0000 (19:24 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 10 May 2022 17:13:26 +0000 (19:13 +0200)
This allows to keep a cache oid -> representation that will be
invalidated automatically if encoding or oid mapping changes, as well as
optimising the function in C.

psycopg/psycopg/_transform.py
psycopg/psycopg/abc.py
psycopg/psycopg/sql.py
psycopg_c/psycopg_c/_psycopg.pyi
psycopg_c/psycopg_c/_psycopg/transform.pyx

index 76a9ca48988fdd38bed84eb5c3da7b5f6bc627fc..8509c93e3ae99e13a0bbd2526b71553619b13f1b 100644 (file)
@@ -76,6 +76,9 @@ class Transformer(AdaptContext):
         # the length of the result columns
         self._row_loaders: List[LoadFunc] = []
 
+        # mapping oid -> type sql representation
+        self._oid_types: Dict[int, bytes] = {}
+
         self._encoding = ""
 
     @classmethod
@@ -181,6 +184,30 @@ class Transformer(AdaptContext):
 
         return out
 
+    def as_literal(self, obj: Any) -> Buffer:
+        dumper = self.get_dumper(obj, PyFormat.TEXT)
+        rv = dumper.quote(obj)
+        # If the result is quoted, and the oid not unknown,
+        # add an explicit type cast.
+        # Check the last char because the first one might be 'E'.
+        if dumper.oid and rv and rv[-1] == b"'"[0]:
+            try:
+                type_sql = self._oid_types[dumper.oid]
+            except KeyError:
+                ti = self.adapters.types.get(dumper.oid)
+                if ti:
+                    type_sql = ti.regtype.encode(self.encoding)
+                    if dumper.oid == ti.array_oid:
+                        type_sql += b"[]"
+                else:
+                    type_sql = b""
+                self._oid_types[dumper.oid] = type_sql
+
+            if type_sql:
+                rv = b"%s::%s" % (rv, type_sql)
+
+        return rv
+
     def get_dumper(self, obj: Any, format: PyFormat) -> "Dumper":
         """
         Return a Dumper instance to dump *obj*.
index 4716ab5504fd47be064dab297a0cbe04b75008ca..f4af55ac0758cfb714e236117a709af41a8caf8a 100644 (file)
@@ -236,6 +236,9 @@ class Transformer(Protocol):
     ) -> Sequence[Optional[Buffer]]:
         ...
 
+    def as_literal(self, obj: Any) -> Buffer:
+        ...
+
     def get_dumper(self, obj: Any, format: PyFormat) -> Dumper:
         ...
 
index 9726af8b5f342808e145413ef7c9e722cfdeb32f..c111e2b74fe1c8ca237051bf1eb26505cfba103a 100644 (file)
@@ -7,8 +7,7 @@ SQL composition utility module
 import codecs
 import string
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Iterator, Iterable, List
-from typing import Optional, Sequence, Union, Tuple
+from typing import Any, Iterator, Iterable, List, Optional, Sequence, Union
 
 from .pq import Escaping
 from .abc import AdaptContext
@@ -390,26 +389,9 @@ class Literal(Composable):
 
     """
 
-    _names_cache: Dict[Tuple[str, str], bytes] = {}
-
     def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
         tx = Transformer.from_context(context)
-        dumper = tx.get_dumper(self._obj, PyFormat.TEXT)
-        rv = dumper.quote(self._obj)
-        # If the result is quoted and the oid not unknown,
-        # add an explicit type cast.
-        if rv[-1] == b"'"[0] and dumper.oid:
-            ti = tx.adapters.types.get(dumper.oid)
-            if ti:
-                try:
-                    type_name = self._names_cache[ti.regtype, tx.encoding]
-                except KeyError:
-                    type_name = ti.regtype.encode(tx.encoding)
-                    self._names_cache[ti.regtype, tx.encoding] = type_name
-                if dumper.oid == ti.array_oid:
-                    type_name += b"[]"
-                rv = b"%s::%s" % (rv, type_name)
-        return rv
+        return tx.as_literal(self._obj)
 
 
 class Placeholder(Composable):
index cd150e3a19d4fcea25eda9ae30187a07b2af6f2d..a215eee581cdd9fc8364d76ef42bba7b6d74f6d5 100644 (file)
@@ -44,6 +44,7 @@ class Transformer(abc.AdaptContext):
     def dump_sequence(
         self, params: Sequence[Any], formats: Sequence[PyFormat]
     ) -> Sequence[Optional[abc.Buffer]]: ...
+    def as_literal(self, obj: Any) -> abc.Buffer: ...
     def get_dumper(self, obj: Any, format: PyFormat) -> abc.Dumper: ...
     def load_rows(self, row0: int, row1: int, make_row: RowMaker[Row]) -> List[Row]: ...
     def load_row(self, row: int, make_row: RowMaker[Row]) -> Optional[Row]: ...
index f7f0ffbdacf3198c5ef86b146b9211608c062407..c6e98355cccd64d9744a868cbb95c63c67b918f0 100644 (file)
@@ -95,6 +95,8 @@ cdef class Transformer:
     cdef list _row_dumpers
     cdef list _row_loaders
 
+    cdef dict _oid_types
+
     def __cinit__(self, context: Optional["AdaptContext"] = None):
         if context is not None:
             self.adapters = context.adapters
@@ -202,6 +204,40 @@ cdef class Transformer:
 
         self._row_loaders = loaders
 
+    cpdef as_literal(self, obj):
+        cdef PyObject *row_dumper = self.get_row_dumper(
+            <PyObject *>obj, <PyObject *>PG_TEXT)
+
+        if (<RowDumper>row_dumper).cdumper is not None:
+            dumper = (<RowDumper>row_dumper).cdumper
+        else:
+            dumper = (<RowDumper>row_dumper).pydumper
+
+        rv = dumper.quote(obj)
+        oid = dumper.oid
+        # If the result is quoted and the oid not unknown,
+        # add an explicit type cast.
+        # Check the last char because the first one might be 'E'.
+        if oid and rv and rv[-1] == 39:
+            if self._oid_types is None:
+                self._oid_types = {}
+            type_ptr = PyDict_GetItem(<object>self._oid_types, oid)
+            if type_ptr == NULL:
+                type_sql = b""
+                ti = self.adapters.types.get(oid)
+                if ti is not None:
+                    type_sql = ti.regtype.encode(self.encoding)
+                    if oid == ti.array_oid:
+                        type_sql += b"[]"
+
+                type_ptr = <PyObject *>type_sql
+                PyDict_SetItem(<object>self._oid_types, oid, type_sql)
+
+            if <object>type_ptr:
+                rv = b"%s::%s" % (rv, <object>type_ptr)
+
+        return rv
+
     def get_dumper(self, obj, format) -> "Dumper":
         cdef PyObject *row_dumper = self.get_row_dumper(
             <PyObject *>obj, <PyObject *>format)