# Copyright (C) 2020-2021 The Psycopg Team
-from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
-from typing import cast, DefaultDict, TYPE_CHECKING
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+from typing import DefaultDict, TYPE_CHECKING
from collections import defaultdict
from . import pq
from .pq.proto import PGresult
from .adapt import Dumper, Loader, AdaptersMap
from .connection import BaseConnection
- from .types.array import BaseListDumper
-DumperKey = Union[type, Tuple[type, type]]
+DumperKey = Union[type, Tuple[type, ...]]
DumperCache = Dict[DumperKey, "Dumper"]
LoaderKey = int
return ps, tuple(ts), fs
def get_dumper(self, obj: Any, format: Format) -> "Dumper":
- # Fast path: return a Dumper class already instantiated from the same type
- cls = type(obj)
- if cls is not list:
- key: DumperKey = cls
- else:
- # TODO: Can be probably generalised to handle other recursive types
- subobj = self._find_list_element(obj)
- key = (cls, type(subobj))
+ """
+ Return a Dumper instance to dump *obj*.
+ """
+ # Normally, the type of the object dictates how to dump it
+ key = type(obj)
+ # Reuse an existing Dumper class for objects of the same type
cache = self._dumpers_cache[format]
try:
- return cache[key]
+ dumper = cache[key]
except KeyError:
- pass
-
- # When dumping a string with %s we may refer to any type actually,
- # but the user surely passed a text format
- if cls is str and format == Format.AUTO:
- format = Format.TEXT
-
- sub_dumper = None
- if cls is list:
- # It's not possible to declare an empty unknown array, so force text
- if subobj is None:
- format = Format.TEXT
-
- # If we are dumping a list it's the sub-object which should dictate
- # what format to use.
- else:
- sub_dumper = self.get_dumper(subobj, format)
- format = Format.from_pq(sub_dumper.format)
-
- dcls = self._adapters.get_dumper(cls, format)
- if not dcls:
- raise e.ProgrammingError(
- f"cannot adapt type {cls.__name__}"
- f" to format {Format(format).name}"
- )
+ # If it's the first time we see this type, look for a dumper
+ # configured for it.
+ dcls = self.adapters.get_dumper(key, format)
+ cache[key] = dumper = dcls(key, self)
- d = dcls(cls, self)
- if sub_dumper:
- cast("BaseListDumper", d).set_sub_dumper(sub_dumper)
+ # Check if the dumper requires an upgrade to handle this specific value
+ key1 = dumper.get_key(obj, format)
+ if key1 is key:
+ return dumper
- cache[key] = d
- return d
+ # If it doesn't ask the dumper to create its own upgraded version
+ try:
+ return cache[key1]
+ except KeyError:
+ dumper = cache[key1] = dumper.upgrade(obj, format)
+ return dumper
def load_rows(self, row0: int, row1: int) -> List[Tuple[Any, ...]]:
res = self._pgresult
raise e.InterfaceError("unknown oid loader not found")
loader = self._loaders_cache[format][oid] = loader_cls(oid, self)
return loader
-
- def _find_list_element(
- self, L: List[Any], seen: Optional[Set[int]] = None
- ) -> Any:
- """
- Find the first non-null element of an eventually nested list
- """
- if not seen:
- seen = set()
- if id(L) in seen:
- raise e.DataError("cannot dump a recursive list")
-
- seen.add(id(L))
-
- for it in L:
- if type(it) is list:
- subit = self._find_list_element(it, seen)
- if subit is not None:
- return subit
- elif it is not None:
- return it
-
- return None
# Copyright (C) 2020-2021 The Psycopg Team
from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Type, TypeVar, Union
-from typing import cast, TYPE_CHECKING
+from typing import Any, Dict, List, Optional, Type, Tuple, Union
+from typing import cast, TYPE_CHECKING, TypeVar
from . import pq
from . import proto
+from . import errors as e
from ._enums import Format as Format
from .oids import builtins
from .proto import AdaptContext, Buffer as Buffer
esc = pq.Escaping()
return b"'%s'" % esc.escape_string(value)
+ def get_key(
+ self, obj: Any, format: Format
+ ) -> Union[type, Tuple[type, ...]]:
+ """Return an alternative key to upgrade the dumper to represent *obj*
+
+ Normally the type of the object is all it takes to define how to dump
+ the object to the database. In a few cases this is not enough. Example
+
+ - Python int could be several Postgres types: int2, int4, int8, numeric
+ - Python lists should be dumped according to the type they contain
+ to convert them to e.g. array of strings, array of ints (which?...)
+
+ In these cases a Dumper can implement `get_key()` and return a new
+ class, or sequence of classes, that can be used to indentify the same
+ dumper again.
+
+ If a Dumper implements `get_key()` it should also implmement
+ `upgrade()`.
+ """
+ return self.cls
+
+ def upgrade(self, obj: Any, format: Format) -> "Dumper":
+ """Return a new dumper to manage *obj*.
+
+ Once `Transformer.get_dumper()` has been notified that this Dumper
+ class cannot handle *obj* itself it will invoke `upgrade()`, which
+ should return a new `Dumper` instance, and will be reused for every
+ objects for which `get_key()` returns the same result.
+ """
+ return self
+
@classmethod
def register(
this_cls, cls: Union[type, str], context: Optional[AdaptContext] = None
self._loaders[fmt][oid] = loader
- def get_dumper(self, cls: type, format: Format) -> Optional[Type[Dumper]]:
+ def get_dumper(self, cls: type, format: Format) -> Type[Dumper]:
"""
Return the dumper class for the given type and format.
- Return None if not found.
+ Raise ProgrammingError if a class is not available.
"""
if format == Format.AUTO:
- dmaps = [
- self._dumpers[pq.Format.BINARY],
- self._dumpers[pq.Format.TEXT],
- ]
+ # When dumping a string with %s we may refer to any type actually,
+ # but the user surely passed a text format
+ if cls is str:
+ dmaps = [self._dumpers[pq.Format.TEXT]]
+ else:
+ dmaps = [
+ self._dumpers[pq.Format.BINARY],
+ self._dumpers[pq.Format.TEXT],
+ ]
elif format == Format.BINARY:
dmaps = [self._dumpers[pq.Format.BINARY]]
elif format == Format.TEXT:
d = dmap[scls] = dmap.pop(fqn)
return d
- return None
+ raise e.ProgrammingError(
+ f"cannot adapt type {cls.__name__}"
+ f" to format {Format(format).name}"
+ )
def get_loader(
self, oid: int, format: pq.Format
from . import range
# Wrapper objects
-from .numeric import Int2, Int4, Int8, Oid
+from .numeric import Int2, Int4, Int8, IntNumeric, Oid
from .json import Json, Jsonb
from .range import Range, Int4Range, Int8Range, DecimalRange
from .range import DateRange, DateTimeRange, DateTimeTZRange
)
from .numeric import (
IntDumper,
+ IntBinaryDumper,
FloatDumper,
FloatBinaryDumper,
DecimalDumper,
Int2Dumper,
Int4Dumper,
Int8Dumper,
+ IntNumericDumper,
OidDumper,
Int2BinaryDumper,
Int4BinaryDumper,
Int8BinaryDumper,
+ IntNumericBinaryDumper,
OidBinaryDumper,
IntLoader,
Int2BinaryLoader,
ByteaBinaryLoader.register("bytea", ctx)
IntDumper.register(int, ctx)
+ IntBinaryDumper.register(int, ctx)
FloatDumper.register(float, ctx)
- Int8BinaryDumper.register(int, ctx)
FloatBinaryDumper.register(float, ctx)
DecimalDumper.register("decimal.Decimal", ctx)
Int2Dumper.register(Int2, ctx)
Int4Dumper.register(Int4, ctx)
Int8Dumper.register(Int8, ctx)
+ IntNumericDumper.register(IntNumeric, ctx)
OidDumper.register(Oid, ctx)
Int2BinaryDumper.register(Int2, ctx)
Int4BinaryDumper.register(Int4, ctx)
Int8BinaryDumper.register(Int8, ctx)
+ IntNumericBinaryDumper.register(IntNumeric, ctx)
OidBinaryDumper.register(Oid, ctx)
IntLoader.register("int2", ctx)
IntLoader.register("int4", ctx)
import re
import struct
-from typing import Any, Iterator, List, Optional, Type
+from typing import Any, Iterator, List, Optional, Set, Tuple, Type
from .. import pq
-from .._enums import Format
from .. import errors as e
from ..oids import builtins, TEXT_OID, TEXT_ARRAY_OID, INVALID_OID
from ..adapt import Buffer, Dumper, Loader, Transformer
+from ..adapt import Format as Pg3Format
from ..proto import AdaptContext
class BaseListDumper(Dumper):
def __init__(self, cls: type, context: Optional[AdaptContext] = None):
super().__init__(cls, context)
- tx = Transformer(context)
- fmt = Format.from_pq(self.format)
- self.set_sub_dumper(tx.get_dumper("", fmt))
+ self._tx = Transformer(context)
+ fmt = Pg3Format.from_pq(self.format)
+ self.sub_dumper = self._tx.get_dumper("", fmt)
+ self.sub_oid = TEXT_OID
+
+ def get_key(self, obj: List[Any], format: Pg3Format) -> Tuple[type, ...]:
+ item = self._find_list_element(obj)
+ if item is not None:
+ sd = self._tx.get_dumper(item, format)
+ return (self.cls, sd.cls)
+ else:
+ return (self.cls,)
+
+ def upgrade(self, obj: List[Any], format: Pg3Format) -> "BaseListDumper":
+ item = self._find_list_element(obj)
+ if item is None:
+ return ListDumper(self.cls, self._tx)
+
+ sd = self._tx.get_dumper(item, format)
+ dcls = ListDumper if sd.format == pq.Format.TEXT else ListBinaryDumper
+ dumper = dcls(self.cls, self._tx)
+ dumper.sub_dumper = sd
- def set_sub_dumper(self, dumper: Dumper) -> None:
- self.sub_dumper = dumper
# We consider an array of unknowns as unknown, so we can dump empty
# lists or lists containing only None elements. However Postgres won't
# take unknown for element oid (in binary; in text it doesn't matter)
- if dumper.oid != INVALID_OID:
- self.oid = self._get_array_oid(dumper.oid)
- self.sub_oid = dumper.oid
+ if sd.oid != INVALID_OID:
+ dumper.oid = self._get_array_oid(sd.oid)
+ dumper.sub_oid = sd.oid
else:
- self.oid = INVALID_OID
- self.sub_oid = TEXT_OID
+ dumper.oid = INVALID_OID
+ dumper.sub_oid = TEXT_OID
+
+ return dumper
+
+ def _find_list_element(self, L: List[Any]) -> Any:
+ """
+ Find the first non-null element of an eventually nested list
+ """
+ it = self._flatiter(L, set())
+ try:
+ item = next(it)
+ except StopIteration:
+ return None
+
+ if not isinstance(item, int):
+ return item
+
+ imax = max((i if i >= 0 else -i - 1 for i in it), default=0)
+ imax = max(item if item >= 0 else -item, imax)
+ return imax
+
+ def _flatiter(self, L: List[Any], seen: Set[int]) -> Any:
+ if id(L) in seen:
+ raise e.DataError("cannot dump a recursive list")
+
+ seen.add(id(L))
+
+ for item in L:
+ if type(item) is list:
+ for subit in self._flatiter(item, seen):
+ yield subit
+ elif item is not None:
+ yield item
+
+ return None
def _get_array_oid(self, base_oid: int) -> int:
"""
# Copyright (C) 2020-2021 The Psycopg Team
import struct
-from typing import Any, Callable, Dict, Tuple, cast
+from typing import Any, Callable, Dict, Optional, Tuple, cast
from decimal import Decimal
+from .. import proto
from ..pq import Format
from ..oids import builtins
-from ..adapt import Buffer, Dumper, Loader
+from ..adapt import Buffer, Dumper, Loader, Transformer
+from ..adapt import Format as Pg3Format
_PackInt = Callable[[int], bytes]
_PackFloat = Callable[[float], bytes]
return super().__new__(cls, arg) # type: ignore
+class IntNumeric(int):
+ def __new__(cls, arg: int) -> "IntNumeric":
+ return super().__new__(cls, arg) # type: ignore
+
+
class Oid(int):
def __new__(cls, arg: int) -> "Oid":
return super().__new__(cls, arg) # type: ignore
+class IntDumper(Dumper):
+
+ format = Format.TEXT
+
+ def __init__(
+ self, cls: type, context: Optional[proto.AdaptContext] = None
+ ):
+ super().__init__(cls, context)
+ self._tx = Transformer(context)
+
+ def dump(self, obj: Any) -> bytes:
+ raise TypeError(
+ "dispatcher to find the int subclass: not supposed to be called"
+ )
+
+ def get_key(cls, obj: int, format: Pg3Format) -> type:
+ if -(2 ** 31) <= obj < 2 ** 31:
+ if -(2 ** 15) <= obj < 2 ** 15:
+ return Int2
+ else:
+ return Int4
+ else:
+ if -(2 ** 63) <= obj < 2 ** 63:
+ return Int8
+ else:
+ return IntNumeric
+
+ def upgrade(self, obj: int, format: Pg3Format) -> Dumper:
+ sample: Any
+ if -(2 ** 31) <= obj < 2 ** 31:
+ if -(2 ** 15) <= obj < 2 ** 15:
+ sample = INT2_SAMPLE
+ else:
+ sample = INT4_SAMPLE
+ else:
+ if -(2 ** 63) <= obj < 2 ** 63:
+ sample = INT8_SAMPLE
+ else:
+ sample = INTNUMERIC_SAMPLE
+
+ return self._tx.get_dumper(sample, format)
+
+
+class IntBinaryDumper(IntDumper):
+ format = Format.BINARY
+
+
+INT2_SAMPLE = Int2(0)
+INT4_SAMPLE = Int4(0)
+INT8_SAMPLE = Int8(0)
+INTNUMERIC_SAMPLE = IntNumeric(0)
+
+
class NumberDumper(Dumper):
format = Format.TEXT
return value if obj >= 0 else b" " + value
-class IntDumper(NumberDumper):
- _oid = builtins["int8"].oid
-
-
class FloatDumper(SpecialValuesDumper):
format = Format.TEXT
_oid = builtins["int8"].oid
+class IntNumericDumper(NumberDumper):
+ _oid = builtins["numeric"].oid
+
+
class OidDumper(NumberDumper):
_oid = builtins["oid"].oid
return _pack_int8(obj)
+class IntNumericBinaryDumper(IntNumericDumper):
+
+ format = Format.BINARY
+
+ def dump(self, obj: int) -> bytes:
+ raise NotImplementedError
+
+
class OidBinaryDumper(OidDumper):
format = Format.BINARY
@pytest.mark.parametrize(
"data, format, result, type",
[
- (1, Format.TEXT, b"1", "int8"),
+ (1, Format.TEXT, b"1", "int2"),
("hello", Format.TEXT, b"hello", "text"),
("hello", Format.BINARY, b"hello", "text"),
],
t = Transformer(conn)
fmt_in = Format.from_pq(fmt_out)
dint = t.get_dumper([0], fmt_in)
- assert dint.oid == builtins["int8"].array_oid
- assert dint.sub_oid == builtins["int8"].oid
+ assert dint.oid == builtins["int2"].array_oid
+ assert dint.sub_oid == builtins["int2"].oid
dstr = t.get_dumper([""], fmt_in)
assert dstr.oid == (
L = []
L.append(L)
with pytest.raises(psycopg3.DataError):
- assert t.get_dumper(L, fmt_out)
+ assert t.get_dumper(L, fmt_in)
def test_string_connection_ctx(conn):
)
cur = conn.execute("select parameter_types from pg_prepared_statements")
(rec,) = cur.fetchall()
- assert rec[0] == ["date", "bigint", "numeric"]
+ assert rec[0] == ["date", "smallint", "numeric"]
def test_evict_lru(conn):
"select parameter_types from pg_prepared_statements order by prepare_time",
prepare=False,
)
- assert cur.fetchall() == [(["text"],), (["date"],), (["bigint"],)]
+ assert cur.fetchall() == [(["text"],), (["date"],), (["smallint"],)]
def test_untyped_json(conn):
"select parameter_types from pg_prepared_statements"
)
(rec,) = await cur.fetchall()
- assert rec[0] == ["date", "bigint", "numeric"]
+ assert rec[0] == ["date", "smallint", "numeric"]
async def test_evict_lru(aconn):
"select parameter_types from pg_prepared_statements order by prepare_time",
prepare=False,
)
- assert await cur.fetchall() == [(["text"],), (["date"],), (["bigint"],)]
+ assert await cur.fetchall() == [(["text"],), (["date"],), (["smallint"],)]
async def test_untyped_json(aconn):
{"hi": 0, "there": "a"},
b"select $1 $2 $1",
[pq.Format.BINARY, pq.Format.TEXT],
- [b"\x00" * 8, b"a"],
+ [b"\x00" * 2, b"a"],
),
],
)
assert res[1] == [val]
-@pytest.mark.xfail
@pytest.mark.parametrize(
"array, type", [([1, 32767], "int2"), ([1, 32768], "int4")]
)
def test_array_mixed_numbers(array, type):
- # TODO: must use the type accommodating the largest/highest precision
tx = Transformer()
dumper = tx.get_dumper(array, Format.BINARY)
dumper.dump(array)
import pytest
-import psycopg3
from psycopg3 import pq
from psycopg3 import sql
from psycopg3.oids import builtins
@pytest.mark.parametrize(
"val, expr",
[
- (0, "'0'::integer"),
- (1, "'1'::integer"),
- (-1, "'-1'::integer"),
+ (0, "'0'::smallint"),
+ (1, "'1'::smallint"),
+ (-1, "'-1'::smallint"),
(42, "'42'::smallint"),
(-42, "'-42'::smallint"),
- (int(2 ** 63 - 1), "'9223372036854775807'::bigint"),
- (int(-(2 ** 63)), "'-9223372036854775808'::bigint"),
- (0, "'0'::oid"),
- (4294967295, "'4294967295'::oid"),
+ (int(2 ** 15 - 1), f"'{2 ** 15 - 1}'::smallint"),
+ (int(-(2 ** 15)), f"'{-2 ** 15}'::smallint"),
+ (int(2 ** 15), f"'{2 ** 15}'::integer"),
+ (int(-(2 ** 15) - 1), f"'{-2 ** 15 - 1}'::integer"),
+ (int(2 ** 31 - 1), f"'{2 ** 31 - 1}'::integer"),
+ (int(-(2 ** 31)), f"'{-2 ** 31}'::integer"),
+ (int(2 ** 31), f"'{2 ** 31}'::bigint"),
+ (int(-(2 ** 31) - 1), f"'{-2 ** 31 - 1}'::bigint"),
+ (int(2 ** 63 - 1), f"'{2 ** 63 - 1}'::bigint"),
+ (int(-(2 ** 63)), f"'{-2 ** 63}'::bigint"),
+ (int(2 ** 63), f"'{2 ** 63}'::numeric"),
+ (int(-(2 ** 63) - 1), f"'{-2 ** 63 - 1}'::numeric"),
],
)
-@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY])
+@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
def test_dump_int_subtypes(conn, val, expr, fmt_in):
- tname = builtins[expr.rsplit(":", 1)[-1]].name.title()
- assert tname in "Int2 Int4 Int8 Oid".split()
- Type = getattr(psycopg3.types.numeric, tname)
+ if fmt_in in (Format.AUTO, Format.BINARY) and "numeric" in expr:
+ pytest.xfail("binary numeric not implemented")
cur = conn.cursor()
- cur.execute(
- f"select pg_typeof({expr}) = pg_typeof(%{fmt_in})", (Type(val),)
- )
+ cur.execute(f"select pg_typeof({expr}) = pg_typeof(%{fmt_in})", (val,))
assert cur.fetchone()[0] is True
- cur.execute(f"select {expr} = %{fmt_in}", (Type(val),))
+ cur.execute(f"select {expr} = %{fmt_in}", (val,))
assert cur.fetchone()[0] is True