# the length of the result columns
self._row_loaders: List[LoadFunc] = []
+ # mapping oid -> type sql representation
+ self._oid_types: Dict[int, bytes] = {}
+
self._encoding = ""
@classmethod
return out
+ def as_literal(self, obj: Any) -> Buffer:
+ dumper = self.get_dumper(obj, PyFormat.TEXT)
+ rv = dumper.quote(obj)
+ # If the result is quoted, and the oid not unknown,
+ # add an explicit type cast.
+ # Check the last char because the first one might be 'E'.
+ if dumper.oid and rv and rv[-1] == b"'"[0]:
+ try:
+ type_sql = self._oid_types[dumper.oid]
+ except KeyError:
+ ti = self.adapters.types.get(dumper.oid)
+ if ti:
+ type_sql = ti.regtype.encode(self.encoding)
+ if dumper.oid == ti.array_oid:
+ type_sql += b"[]"
+ else:
+ type_sql = b""
+ self._oid_types[dumper.oid] = type_sql
+
+ if type_sql:
+ rv = b"%s::%s" % (rv, type_sql)
+
+ return rv
+
def get_dumper(self, obj: Any, format: PyFormat) -> "Dumper":
"""
Return a Dumper instance to dump *obj*.
) -> Sequence[Optional[Buffer]]:
...
+ def as_literal(self, obj: Any) -> Buffer:
+ ...
+
def get_dumper(self, obj: Any, format: PyFormat) -> Dumper:
...
import codecs
import string
from abc import ABC, abstractmethod
-from typing import Any, Dict, Iterator, Iterable, List
-from typing import Optional, Sequence, Union, Tuple
+from typing import Any, Iterator, Iterable, List, Optional, Sequence, Union
from .pq import Escaping
from .abc import AdaptContext
"""
- _names_cache: Dict[Tuple[str, str], bytes] = {}
-
def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
tx = Transformer.from_context(context)
- dumper = tx.get_dumper(self._obj, PyFormat.TEXT)
- rv = dumper.quote(self._obj)
- # If the result is quoted and the oid not unknown,
- # add an explicit type cast.
- if rv[-1] == b"'"[0] and dumper.oid:
- ti = tx.adapters.types.get(dumper.oid)
- if ti:
- try:
- type_name = self._names_cache[ti.regtype, tx.encoding]
- except KeyError:
- type_name = ti.regtype.encode(tx.encoding)
- self._names_cache[ti.regtype, tx.encoding] = type_name
- if dumper.oid == ti.array_oid:
- type_name += b"[]"
- rv = b"%s::%s" % (rv, type_name)
- return rv
+ return tx.as_literal(self._obj)
class Placeholder(Composable):
def dump_sequence(
self, params: Sequence[Any], formats: Sequence[PyFormat]
) -> Sequence[Optional[abc.Buffer]]: ...
+ def as_literal(self, obj: Any) -> abc.Buffer: ...
def get_dumper(self, obj: Any, format: PyFormat) -> abc.Dumper: ...
def load_rows(self, row0: int, row1: int, make_row: RowMaker[Row]) -> List[Row]: ...
def load_row(self, row: int, make_row: RowMaker[Row]) -> Optional[Row]: ...
cdef list _row_dumpers
cdef list _row_loaders
+ cdef dict _oid_types
+
def __cinit__(self, context: Optional["AdaptContext"] = None):
if context is not None:
self.adapters = context.adapters
self._row_loaders = loaders
+ cpdef as_literal(self, obj):
+ cdef PyObject *row_dumper = self.get_row_dumper(
+ <PyObject *>obj, <PyObject *>PG_TEXT)
+
+ if (<RowDumper>row_dumper).cdumper is not None:
+ dumper = (<RowDumper>row_dumper).cdumper
+ else:
+ dumper = (<RowDumper>row_dumper).pydumper
+
+ rv = dumper.quote(obj)
+ oid = dumper.oid
+ # If the result is quoted and the oid not unknown,
+ # add an explicit type cast.
+ # Check the last char because the first one might be 'E'.
+ if oid and rv and rv[-1] == 39:
+ if self._oid_types is None:
+ self._oid_types = {}
+ type_ptr = PyDict_GetItem(<object>self._oid_types, oid)
+ if type_ptr == NULL:
+ type_sql = b""
+ ti = self.adapters.types.get(oid)
+ if ti is not None:
+ type_sql = ti.regtype.encode(self.encoding)
+ if oid == ti.array_oid:
+ type_sql += b"[]"
+
+ type_ptr = <PyObject *>type_sql
+ PyDict_SetItem(<object>self._oid_types, oid, type_sql)
+
+ if <object>type_ptr:
+ rv = b"%s::%s" % (rv, <object>type_ptr)
+
+ return rv
+
def get_dumper(self, obj, format) -> "Dumper":
cdef PyObject *row_dumper = self.get_row_dumper(
<PyObject *>obj, <PyObject *>format)