From: Daniele Varrazzo Date: Wed, 20 Jan 2021 01:46:45 +0000 (+0100) Subject: Adapt the dumper used to the value of the objects X-Git-Tag: 3.0.dev0~143 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=a246e3f294f43bc5bed27ec765b838ccff69be78;p=thirdparty%2Fpsycopg.git Adapt the dumper used to the value of the objects Added a second dispatch to allow a dumper to upgrade to a specialised version. Currently used to dump int to the smallest Postgres type holding that value and to dump lists of objects into typed arrays. This change allows queries to be written more naturally, as no ``::int`` cast should be needed anymore, e.g. in date + int or jsonb ->> int. Only Python implementation; C version yet to be implemented. --- diff --git a/psycopg3/psycopg3/_transform.py b/psycopg3/psycopg3/_transform.py index f0b61c335..a8e956fa2 100644 --- a/psycopg3/psycopg3/_transform.py +++ b/psycopg3/psycopg3/_transform.py @@ -4,8 +4,8 @@ Helper object to transform values between Python and PostgreSQL # Copyright (C) 2020-2021 The Psycopg Team -from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union -from typing import cast, DefaultDict, TYPE_CHECKING +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import DefaultDict, TYPE_CHECKING from collections import defaultdict from . 
import pq @@ -18,9 +18,8 @@ if TYPE_CHECKING: from .pq.proto import PGresult from .adapt import Dumper, Loader, AdaptersMap from .connection import BaseConnection - from .types.array import BaseListDumper -DumperKey = Union[type, Tuple[type, type]] +DumperKey = Union[type, Tuple[type, ...]] DumperCache = Dict[DumperKey, "Dumper"] LoaderKey = int @@ -130,51 +129,33 @@ class Transformer(AdaptContext): return ps, tuple(ts), fs def get_dumper(self, obj: Any, format: Format) -> "Dumper": - # Fast path: return a Dumper class already instantiated from the same type - cls = type(obj) - if cls is not list: - key: DumperKey = cls - else: - # TODO: Can be probably generalised to handle other recursive types - subobj = self._find_list_element(obj) - key = (cls, type(subobj)) + """ + Return a Dumper instance to dump *obj*. + """ + # Normally, the type of the object dictates how to dump it + key = type(obj) + # Reuse an existing Dumper class for objects of the same type cache = self._dumpers_cache[format] try: - return cache[key] + dumper = cache[key] except KeyError: - pass - - # When dumping a string with %s we may refer to any type actually, - # but the user surely passed a text format - if cls is str and format == Format.AUTO: - format = Format.TEXT - - sub_dumper = None - if cls is list: - # It's not possible to declare an empty unknown array, so force text - if subobj is None: - format = Format.TEXT - - # If we are dumping a list it's the sub-object which should dictate - # what format to use. - else: - sub_dumper = self.get_dumper(subobj, format) - format = Format.from_pq(sub_dumper.format) - - dcls = self._adapters.get_dumper(cls, format) - if not dcls: - raise e.ProgrammingError( - f"cannot adapt type {cls.__name__}" - f" to format {Format(format).name}" - ) + # If it's the first time we see this type, look for a dumper + # configured for it. 
+ dcls = self.adapters.get_dumper(key, format) + cache[key] = dumper = dcls(key, self) - d = dcls(cls, self) - if sub_dumper: - cast("BaseListDumper", d).set_sub_dumper(sub_dumper) + # Check if the dumper requires an upgrade to handle this specific value + key1 = dumper.get_key(obj, format) + if key1 is key: + return dumper - cache[key] = d - return d + # If it doesn't ask the dumper to create its own upgraded version + try: + return cache[key1] + except KeyError: + dumper = cache[key1] = dumper.upgrade(obj, format) + return dumper def load_rows(self, row0: int, row1: int) -> List[Tuple[Any, ...]]: res = self._pgresult @@ -241,26 +222,3 @@ class Transformer(AdaptContext): raise e.InterfaceError("unknown oid loader not found") loader = self._loaders_cache[format][oid] = loader_cls(oid, self) return loader - - def _find_list_element( - self, L: List[Any], seen: Optional[Set[int]] = None - ) -> Any: - """ - Find the first non-null element of an eventually nested list - """ - if not seen: - seen = set() - if id(L) in seen: - raise e.DataError("cannot dump a recursive list") - - seen.add(id(L)) - - for it in L: - if type(it) is list: - subit = self._find_list_element(it, seen) - if subit is not None: - return subit - elif it is not None: - return it - - return None diff --git a/psycopg3/psycopg3/adapt.py b/psycopg3/psycopg3/adapt.py index 7ae7cda76..fc8a9cbbf 100644 --- a/psycopg3/psycopg3/adapt.py +++ b/psycopg3/psycopg3/adapt.py @@ -5,11 +5,12 @@ Entry point into the adaptation system. # Copyright (C) 2020-2021 The Psycopg Team from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Type, TypeVar, Union -from typing import cast, TYPE_CHECKING +from typing import Any, Dict, List, Optional, Type, Tuple, Union +from typing import cast, TYPE_CHECKING, TypeVar from . import pq from . import proto +from . 
import errors as e from ._enums import Format as Format from .oids import builtins from .proto import AdaptContext, Buffer as Buffer @@ -56,6 +57,37 @@ class Dumper(ABC): esc = pq.Escaping() return b"'%s'" % esc.escape_string(value) + def get_key( + self, obj: Any, format: Format + ) -> Union[type, Tuple[type, ...]]: + """Return an alternative key to upgrade the dumper to represent *obj* + + Normally the type of the object is all it takes to define how to dump + the object to the database. In a few cases this is not enough. Example + + - Python int could be several Postgres types: int2, int4, int8, numeric + - Python lists should be dumped according to the type they contain + to convert them to e.g. array of strings, array of ints (which?...) + + In these cases a Dumper can implement `get_key()` and return a new + class, or sequence of classes, that can be used to identify the same + dumper again. + + If a Dumper implements `get_key()` it should also implement + `upgrade()`. + """ + return self.cls + + def upgrade(self, obj: Any, format: Format) -> "Dumper": + """Return a new dumper to manage *obj*. + + Once `Transformer.get_dumper()` has been notified that this Dumper + class cannot handle *obj* itself it will invoke `upgrade()`, which + should return a new `Dumper` instance, and will be reused for every + object for which `get_key()` returns the same result. + """ + return self + @classmethod def register( this_cls, cls: Union[type, str], context: Optional[AdaptContext] = None @@ -171,17 +203,22 @@ class AdaptersMap(AdaptContext): self._loaders[fmt][oid] = loader - def get_dumper(self, cls: type, format: Format) -> Optional[Type[Dumper]]: + def get_dumper(self, cls: type, format: Format) -> Type[Dumper]: """ Return the dumper class for the given type and format. - Return None if not found. + Raise ProgrammingError if a class is not available. 
""" if format == Format.AUTO: - dmaps = [ - self._dumpers[pq.Format.BINARY], - self._dumpers[pq.Format.TEXT], - ] + # When dumping a string with %s we may refer to any type actually, + # but the user surely passed a text format + if cls is str: + dmaps = [self._dumpers[pq.Format.TEXT]] + else: + dmaps = [ + self._dumpers[pq.Format.BINARY], + self._dumpers[pq.Format.TEXT], + ] elif format == Format.BINARY: dmaps = [self._dumpers[pq.Format.BINARY]] elif format == Format.TEXT: @@ -203,7 +240,10 @@ class AdaptersMap(AdaptContext): d = dmap[scls] = dmap.pop(fqn) return d - return None + raise e.ProgrammingError( + f"cannot adapt type {cls.__name__}" + f" to format {Format(format).name}" + ) def get_loader( self, oid: int, format: pq.Format diff --git a/psycopg3/psycopg3/types/__init__.py b/psycopg3/psycopg3/types/__init__.py index 85aa130d6..1bc7fed4a 100644 --- a/psycopg3/psycopg3/types/__init__.py +++ b/psycopg3/psycopg3/types/__init__.py @@ -12,7 +12,7 @@ from . import array, composite from . 
import range # Wrapper objects -from .numeric import Int2, Int4, Int8, Oid +from .numeric import Int2, Int4, Int8, IntNumeric, Oid from .json import Json, Jsonb from .range import Range, Int4Range, Int8Range, DecimalRange from .range import DateRange, DateTimeRange, DateTimeTZRange @@ -34,16 +34,19 @@ from .text import ( ) from .numeric import ( IntDumper, + IntBinaryDumper, FloatDumper, FloatBinaryDumper, DecimalDumper, Int2Dumper, Int4Dumper, Int8Dumper, + IntNumericDumper, OidDumper, Int2BinaryDumper, Int4BinaryDumper, Int8BinaryDumper, + IntNumericBinaryDumper, OidBinaryDumper, IntLoader, Int2BinaryLoader, @@ -148,17 +151,19 @@ def register_default_globals(ctx: AdaptContext) -> None: ByteaBinaryLoader.register("bytea", ctx) IntDumper.register(int, ctx) + IntBinaryDumper.register(int, ctx) FloatDumper.register(float, ctx) - Int8BinaryDumper.register(int, ctx) FloatBinaryDumper.register(float, ctx) DecimalDumper.register("decimal.Decimal", ctx) Int2Dumper.register(Int2, ctx) Int4Dumper.register(Int4, ctx) Int8Dumper.register(Int8, ctx) + IntNumericDumper.register(IntNumeric, ctx) OidDumper.register(Oid, ctx) Int2BinaryDumper.register(Int2, ctx) Int4BinaryDumper.register(Int4, ctx) Int8BinaryDumper.register(Int8, ctx) + IntNumericBinaryDumper.register(IntNumeric, ctx) OidBinaryDumper.register(Oid, ctx) IntLoader.register("int2", ctx) IntLoader.register("int4", ctx) diff --git a/psycopg3/psycopg3/types/array.py b/psycopg3/psycopg3/types/array.py index 357f7432e..ec95b379d 100644 --- a/psycopg3/psycopg3/types/array.py +++ b/psycopg3/psycopg3/types/array.py @@ -6,34 +6,85 @@ Adapters for arrays import re import struct -from typing import Any, Iterator, List, Optional, Type +from typing import Any, Iterator, List, Optional, Set, Tuple, Type from .. import pq -from .._enums import Format from .. 
import errors as e from ..oids import builtins, TEXT_OID, TEXT_ARRAY_OID, INVALID_OID from ..adapt import Buffer, Dumper, Loader, Transformer +from ..adapt import Format as Pg3Format from ..proto import AdaptContext class BaseListDumper(Dumper): def __init__(self, cls: type, context: Optional[AdaptContext] = None): super().__init__(cls, context) - tx = Transformer(context) - fmt = Format.from_pq(self.format) - self.set_sub_dumper(tx.get_dumper("", fmt)) + self._tx = Transformer(context) + fmt = Pg3Format.from_pq(self.format) + self.sub_dumper = self._tx.get_dumper("", fmt) + self.sub_oid = TEXT_OID + + def get_key(self, obj: List[Any], format: Pg3Format) -> Tuple[type, ...]: + item = self._find_list_element(obj) + if item is not None: + sd = self._tx.get_dumper(item, format) + return (self.cls, sd.cls) + else: + return (self.cls,) + + def upgrade(self, obj: List[Any], format: Pg3Format) -> "BaseListDumper": + item = self._find_list_element(obj) + if item is None: + return ListDumper(self.cls, self._tx) + + sd = self._tx.get_dumper(item, format) + dcls = ListDumper if sd.format == pq.Format.TEXT else ListBinaryDumper + dumper = dcls(self.cls, self._tx) + dumper.sub_dumper = sd - def set_sub_dumper(self, dumper: Dumper) -> None: - self.sub_dumper = dumper # We consider an array of unknowns as unknown, so we can dump empty # lists or lists containing only None elements. 
However Postgres won't # take unknown for element oid (in binary; in text it doesn't matter) - if dumper.oid != INVALID_OID: - self.oid = self._get_array_oid(dumper.oid) - self.sub_oid = dumper.oid + if sd.oid != INVALID_OID: + dumper.oid = self._get_array_oid(sd.oid) + dumper.sub_oid = sd.oid else: - self.oid = INVALID_OID - self.sub_oid = TEXT_OID + dumper.oid = INVALID_OID + dumper.sub_oid = TEXT_OID + + return dumper + + def _find_list_element(self, L: List[Any]) -> Any: + """ + Find the first non-null element of an eventually nested list + """ + it = self._flatiter(L, set()) + try: + item = next(it) + except StopIteration: + return None + + if not isinstance(item, int): + return item + + imax = max((i if i >= 0 else -i - 1 for i in it), default=0) + imax = max(item if item >= 0 else -item, imax) + return imax + + def _flatiter(self, L: List[Any], seen: Set[int]) -> Any: + if id(L) in seen: + raise e.DataError("cannot dump a recursive list") + + seen.add(id(L)) + + for item in L: + if type(item) is list: + for subit in self._flatiter(item, seen): + yield subit + elif item is not None: + yield item + + return None def _get_array_oid(self, base_oid: int) -> int: """ diff --git a/psycopg3/psycopg3/types/numeric.py b/psycopg3/psycopg3/types/numeric.py index a8fd5e7cc..c5e969a61 100644 --- a/psycopg3/psycopg3/types/numeric.py +++ b/psycopg3/psycopg3/types/numeric.py @@ -5,12 +5,14 @@ Adapers for numeric types. # Copyright (C) 2020-2021 The Psycopg Team import struct -from typing import Any, Callable, Dict, Tuple, cast +from typing import Any, Callable, Dict, Optional, Tuple, cast from decimal import Decimal +from .. 
import proto from ..pq import Format from ..oids import builtins -from ..adapt import Buffer, Dumper, Loader +from ..adapt import Buffer, Dumper, Loader, Transformer +from ..adapt import Format as Pg3Format _PackInt = Callable[[int], bytes] _PackFloat = Callable[[float], bytes] @@ -48,11 +50,69 @@ class Int8(int): return super().__new__(cls, arg) # type: ignore +class IntNumeric(int): + def __new__(cls, arg: int) -> "IntNumeric": + return super().__new__(cls, arg) # type: ignore + + class Oid(int): def __new__(cls, arg: int) -> "Oid": return super().__new__(cls, arg) # type: ignore +class IntDumper(Dumper): + + format = Format.TEXT + + def __init__( + self, cls: type, context: Optional[proto.AdaptContext] = None + ): + super().__init__(cls, context) + self._tx = Transformer(context) + + def dump(self, obj: Any) -> bytes: + raise TypeError( + "dispatcher to find the int subclass: not supposed to be called" + ) + + def get_key(cls, obj: int, format: Pg3Format) -> type: + if -(2 ** 31) <= obj < 2 ** 31: + if -(2 ** 15) <= obj < 2 ** 15: + return Int2 + else: + return Int4 + else: + if -(2 ** 63) <= obj < 2 ** 63: + return Int8 + else: + return IntNumeric + + def upgrade(self, obj: int, format: Pg3Format) -> Dumper: + sample: Any + if -(2 ** 31) <= obj < 2 ** 31: + if -(2 ** 15) <= obj < 2 ** 15: + sample = INT2_SAMPLE + else: + sample = INT4_SAMPLE + else: + if -(2 ** 63) <= obj < 2 ** 63: + sample = INT8_SAMPLE + else: + sample = INTNUMERIC_SAMPLE + + return self._tx.get_dumper(sample, format) + + +class IntBinaryDumper(IntDumper): + format = Format.BINARY + + +INT2_SAMPLE = Int2(0) +INT4_SAMPLE = Int4(0) +INT8_SAMPLE = Int8(0) +INTNUMERIC_SAMPLE = IntNumeric(0) + + class NumberDumper(Dumper): format = Format.TEXT @@ -78,10 +138,6 @@ class SpecialValuesDumper(NumberDumper): return value if obj >= 0 else b" " + value -class IntDumper(NumberDumper): - _oid = builtins["int8"].oid - - class FloatDumper(SpecialValuesDumper): format = Format.TEXT @@ -126,6 +182,10 @@ class 
Int8Dumper(NumberDumper): _oid = builtins["int8"].oid +class IntNumericDumper(NumberDumper): + _oid = builtins["numeric"].oid + + class OidDumper(NumberDumper): _oid = builtins["oid"].oid @@ -154,6 +214,14 @@ class Int8BinaryDumper(Int8Dumper): return _pack_int8(obj) +class IntNumericBinaryDumper(IntNumericDumper): + + format = Format.BINARY + + def dump(self, obj: int) -> bytes: + raise NotImplementedError + + class OidBinaryDumper(OidDumper): format = Format.BINARY diff --git a/tests/test_adapt.py b/tests/test_adapt.py index 5c6023675..a32ea2e4f 100644 --- a/tests/test_adapt.py +++ b/tests/test_adapt.py @@ -9,7 +9,7 @@ from psycopg3.oids import builtins, TEXT_OID @pytest.mark.parametrize( "data, format, result, type", [ - (1, Format.TEXT, b"1", "int8"), + (1, Format.TEXT, b"1", "int2"), ("hello", Format.TEXT, b"hello", "text"), ("hello", Format.BINARY, b"hello", "text"), ], @@ -181,8 +181,8 @@ def test_array_dumper(conn, fmt_out): t = Transformer(conn) fmt_in = Format.from_pq(fmt_out) dint = t.get_dumper([0], fmt_in) - assert dint.oid == builtins["int8"].array_oid - assert dint.sub_oid == builtins["int8"].oid + assert dint.oid == builtins["int2"].array_oid + assert dint.sub_oid == builtins["int2"].oid dstr = t.get_dumper([""], fmt_in) assert dstr.oid == ( @@ -202,7 +202,7 @@ def test_array_dumper(conn, fmt_out): L = [] L.append(L) with pytest.raises(psycopg3.DataError): - assert t.get_dumper(L, fmt_out) + assert t.get_dumper(L, fmt_in) def test_string_connection_ctx(conn): diff --git a/tests/test_prepared.py b/tests/test_prepared.py index f53a4b2e6..bfe561148 100644 --- a/tests/test_prepared.py +++ b/tests/test_prepared.py @@ -131,7 +131,7 @@ def test_params_types(conn): ) cur = conn.execute("select parameter_types from pg_prepared_statements") (rec,) = cur.fetchall() - assert rec[0] == ["date", "bigint", "numeric"] + assert rec[0] == ["date", "smallint", "numeric"] def test_evict_lru(conn): @@ -180,7 +180,7 @@ def test_different_types(conn): "select 
parameter_types from pg_prepared_statements order by prepare_time", prepare=False, ) - assert cur.fetchall() == [(["text"],), (["date"],), (["bigint"],)] + assert cur.fetchall() == [(["text"],), (["date"],), (["smallint"],)] def test_untyped_json(conn): diff --git a/tests/test_prepared_async.py b/tests/test_prepared_async.py index cb3127cd4..231dbfe1a 100644 --- a/tests/test_prepared_async.py +++ b/tests/test_prepared_async.py @@ -139,7 +139,7 @@ async def test_params_types(aconn): "select parameter_types from pg_prepared_statements" ) (rec,) = await cur.fetchall() - assert rec[0] == ["date", "bigint", "numeric"] + assert rec[0] == ["date", "smallint", "numeric"] async def test_evict_lru(aconn): @@ -190,7 +190,7 @@ async def test_different_types(aconn): "select parameter_types from pg_prepared_statements order by prepare_time", prepare=False, ) - assert await cur.fetchall() == [(["text"],), (["date"],), (["bigint"],)] + assert await cur.fetchall() == [(["text"],), (["date"],), (["smallint"],)] async def test_untyped_json(aconn): diff --git a/tests/test_query.py b/tests/test_query.py index 7eb2dcba7..3c596f5ba 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -113,7 +113,7 @@ def test_pg_query_seq(query, params, want, wformats, wparams): {"hi": 0, "there": "a"}, b"select $1 $2 $1", [pq.Format.BINARY, pq.Format.TEXT], - [b"\x00" * 8, b"a"], + [b"\x00" * 2, b"a"], ), ], ) diff --git a/tests/types/test_array.py b/tests/types/test_array.py index 46089c278..f1b19ae7a 100644 --- a/tests/types/test_array.py +++ b/tests/types/test_array.py @@ -139,12 +139,10 @@ def test_array_of_unknown_builtin(conn): assert res[1] == [val] -@pytest.mark.xfail @pytest.mark.parametrize( "array, type", [([1, 32767], "int2"), ([1, 32768], "int4")] ) def test_array_mixed_numbers(array, type): - # TODO: must use the type accommodating the largest/highest precision tx = Transformer() dumper = tx.get_dumper(array, Format.BINARY) dumper.dump(array) diff --git 
a/tests/types/test_numeric.py b/tests/types/test_numeric.py index a794b7897..e02e3db81 100644 --- a/tests/types/test_numeric.py +++ b/tests/types/test_numeric.py @@ -3,7 +3,6 @@ from math import isnan, isinf, exp import pytest -import psycopg3 from psycopg3 import pq from psycopg3 import sql from psycopg3.oids import builtins @@ -39,28 +38,33 @@ def test_dump_int(conn, val, expr, fmt_in): @pytest.mark.parametrize( "val, expr", [ - (0, "'0'::integer"), - (1, "'1'::integer"), - (-1, "'-1'::integer"), + (0, "'0'::smallint"), + (1, "'1'::smallint"), + (-1, "'-1'::smallint"), (42, "'42'::smallint"), (-42, "'-42'::smallint"), - (int(2 ** 63 - 1), "'9223372036854775807'::bigint"), - (int(-(2 ** 63)), "'-9223372036854775808'::bigint"), - (0, "'0'::oid"), - (4294967295, "'4294967295'::oid"), + (int(2 ** 15 - 1), f"'{2 ** 15 - 1}'::smallint"), + (int(-(2 ** 15)), f"'{-2 ** 15}'::smallint"), + (int(2 ** 15), f"'{2 ** 15}'::integer"), + (int(-(2 ** 15) - 1), f"'{-2 ** 15 - 1}'::integer"), + (int(2 ** 31 - 1), f"'{2 ** 31 - 1}'::integer"), + (int(-(2 ** 31)), f"'{-2 ** 31}'::integer"), + (int(2 ** 31), f"'{2 ** 31}'::bigint"), + (int(-(2 ** 31) - 1), f"'{-2 ** 31 - 1}'::bigint"), + (int(2 ** 63 - 1), f"'{2 ** 63 - 1}'::bigint"), + (int(-(2 ** 63)), f"'{-2 ** 63}'::bigint"), + (int(2 ** 63), f"'{2 ** 63}'::numeric"), + (int(-(2 ** 63) - 1), f"'{-2 ** 63 - 1}'::numeric"), ], ) -@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY]) +@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY]) def test_dump_int_subtypes(conn, val, expr, fmt_in): - tname = builtins[expr.rsplit(":", 1)[-1]].name.title() - assert tname in "Int2 Int4 Int8 Oid".split() - Type = getattr(psycopg3.types.numeric, tname) + if fmt_in in (Format.AUTO, Format.BINARY) and "numeric" in expr: + pytest.xfail("binary numeric not implemented") cur = conn.cursor() - cur.execute( - f"select pg_typeof({expr}) = pg_typeof(%{fmt_in})", (Type(val),) - ) + cur.execute(f"select pg_typeof({expr}) = 
pg_typeof(%{fmt_in})", (val,)) assert cur.fetchone()[0] is True - cur.execute(f"select {expr} = %{fmt_in}", (Type(val),)) + cur.execute(f"select {expr} = %{fmt_in}", (val,)) assert cur.fetchone()[0] is True