]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Dropped MaybeOid type and optional second return value from Dumper.dump()
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 6 Aug 2020 02:18:53 +0000 (03:18 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 23 Aug 2020 18:24:02 +0000 (19:24 +0100)
12 files changed:
psycopg3/psycopg3/adapt.py
psycopg3/psycopg3/dbapi20.py
psycopg3/psycopg3/proto.py
psycopg3/psycopg3/transform.py
psycopg3/psycopg3/types/array.py
psycopg3/psycopg3/types/composite.py
psycopg3/psycopg3/types/numeric.py
psycopg3/psycopg3/types/text.py
psycopg3_c/psycopg3_c/_psycopg3.pyi
psycopg3_c/psycopg3_c/transform.pyx
tests/test_adapt.py
tests/types/test_array.py

index 2b43b1265851a3f5a1f0224efaa1ea2494883f6c..89429ac3a341b59d48092d88e77a3bac5d06cd98 100644 (file)
@@ -4,7 +4,7 @@ Entry point into the adaptation system.
 
 # Copyright (C) 2020 The Psycopg Team
 
-from typing import Any, Callable, Optional, Tuple, Type, Union
+from typing import Any, Callable, Optional, Type
 
 from . import pq
 from . import proto
@@ -26,7 +26,7 @@ class Dumper:
         self.context = context
         self.connection = _connection_from_context(context)
 
-    def dump(self, obj: Any) -> Union[bytes, Tuple[bytes, int]]:
+    def dump(self, obj: Any) -> bytes:
         raise NotImplementedError()
 
     @property
index 2c0763478e8dae3b98665bc5e9a70834142d5ecb..941bdb6b12ee780c5a3d77f597b3f558dc8086d9 100644 (file)
@@ -7,11 +7,13 @@ Compatibility objects with DBAPI 2.0
 import time
 import datetime as dt
 from math import floor
-from typing import Any, Sequence, Tuple
+from typing import Any, Sequence
 
 from .types.oids import builtins
 from .adapt import Dumper
 
+BYTEA_OID = builtins["bytea"].oid
+
 
 class DBAPITypeObject:
     def __init__(self, name: str, type_names: Sequence[str]):
@@ -52,12 +54,16 @@ class Binary:
 
 @Dumper.text(Binary)
 class TextBinaryDumper(Dumper):
-    def dump(self, obj: Binary) -> Tuple[bytes, int]:
-        rv = obj.obj
-        if not isinstance(rv, bytes):
-            rv = bytes(rv)
+    def dump(self, obj: Binary) -> bytes:
+        wrapped = obj.obj
+        if isinstance(wrapped, bytes):
+            return wrapped
+        else:
+            return bytes(wrapped)
 
-        return rv, builtins["bytea"].oid
+    @property
+    def oid(self) -> int:
+        return BYTEA_OID
 
 
 def Date(year: int, month: int, day: int) -> dt.date:
index 0b996adc48262c4863d88d5c7be85ee5503705e6..51df11fd08e8d8fcfeb6a298722d1dba3fa5ede0 100644 (file)
@@ -36,8 +36,7 @@ PQGen = Generator[Tuple[int, "Wait"], "Ready", RV]
 
 AdaptContext = Union[None, "BaseConnection", "BaseCursor", "Transformer"]
 
-MaybeOid = Union[Optional[bytes], Tuple[Optional[bytes], int]]
-DumpFunc = Callable[[Any], MaybeOid]
+DumpFunc = Callable[[Any], bytes]
 DumperType = Type["Dumper"]
 DumpersMap = Dict[Tuple[type, Format], DumperType]
 
@@ -85,7 +84,10 @@ class Transformer(Protocol):
     def types_sequence(self) -> List[int]:
         ...
 
-    def dump(self, obj: None, format: Format = Format.TEXT) -> MaybeOid:
+    def dump(self, obj: Any, format: Format = Format.TEXT) -> Optional[bytes]:
+        ...
+
+    def get_dumper(self, obj: Any, format: Format) -> "Dumper":
         ...
 
     def lookup_dumper(self, src: type, format: Format) -> DumperType:
index 2e8e41722dfe717f4e78a543788cea2b9f79ae9c..f106ef4cc698c1dc47c97da153b437cb88ec6cf8 100644 (file)
@@ -11,7 +11,7 @@ from typing import TYPE_CHECKING
 from . import errors as e
 from . import pq
 from .proto import AdaptContext, DumpersMap, DumperType
-from .proto import LoadFunc, LoadersMap, LoaderType, MaybeOid
+from .proto import LoadFunc, LoadersMap, LoaderType
 from .cursor import BaseCursor
 from .connection import BaseConnection
 from .types.oids import builtins, INVALID_OID
@@ -151,12 +151,8 @@ class Transformer:
 
         for var, fmt in zip(objs, formats):
             if var is not None:
-                dumper = self.get_dumper(type(var), fmt)
-                data = dumper.dump(var)
-                if isinstance(data, tuple):
-                    data = data[0]
-
-                out.append(data)
+                dumper = self.get_dumper(var, fmt)
+                out.append(dumper.dump(var))
                 oids.append(dumper.oid)
             else:
                 out.append(None)
@@ -169,22 +165,21 @@ class Transformer:
         return self._oids
 
     # TODO: drop?
-    def dump(self, obj: None, format: Format = Format.TEXT) -> MaybeOid:
+    def dump(self, obj: Any, format: Format = Format.TEXT) -> Optional[bytes]:
         if obj is None:
-            return None, TEXT_OID
+            return None
 
-        dumper = self.get_dumper(type(obj), format)
-        return dumper.dump(obj)
+        return self.get_dumper(obj, format).dump(obj)
 
-    def get_dumper(self, src: type, format: Format) -> "Dumper":
-        key = (src, format)
+    def get_dumper(self, obj: Any, format: Format) -> "Dumper":
+        key = (type(obj), format)
         try:
             return self._dumpers_cache[key]
         except KeyError:
             pass
 
-        dumper_cls = self.lookup_dumper(src, format)
-        self._dumpers_cache[key] = dumper = dumper_cls(src, self)
+        dumper_cls = self.lookup_dumper(*key)
+        self._dumpers_cache[key] = dumper = dumper_cls(key[0], self)
         return dumper
 
     def lookup_dumper(self, src: type, format: Format) -> DumperType:
index edc1edaf3cda6525776b5974561d29b3c81af674..daffd254392f2f9be43071e0e415b30a6bda655e 100644 (file)
@@ -6,7 +6,7 @@ Adapters for arrays
 
 import re
 import struct
-from typing import Any, Generator, List, Optional, Tuple
+from typing import Any, Generator, List, Optional
 
 from .. import errors as e
 from ..adapt import Format, Dumper, Loader, Transformer
@@ -21,9 +21,13 @@ class BaseListDumper(Dumper):
     def __init__(self, src: type, context: AdaptContext = None):
         super().__init__(src, context)
         self._tx = Transformer(context)
-        self._oid = 0
+        self._array_oid = 0
 
-    def _array_oid(self, base_oid: int) -> int:
+    @property
+    def oid(self) -> int:
+        return self._array_oid or TEXT_ARRAY_OID
+
+    def _get_array_oid(self, base_oid: int) -> int:
         """
         Return the oid of the array from the oid of the base item.
 
@@ -60,10 +64,13 @@ class TextListDumper(BaseListDumper):
     # backslash-escaped.
     _re_escape = re.compile(br'(["\\])')
 
-    def dump(self, obj: List[Any]) -> Tuple[bytes, int]:
+    def dump(self, obj: List[Any]) -> bytes:
         tokens: List[bytes] = []
+        oid: Optional[int] = None
 
         def dump_list(obj: List[Any]) -> None:
+            nonlocal oid
+
             if not obj:
                 tokens.append(b"{}")
                 return
@@ -72,29 +79,16 @@ class TextListDumper(BaseListDumper):
             for item in obj:
                 if isinstance(item, list):
                     dump_list(item)
-                elif item is None:
-                    tokens.append(b"NULL")
+                elif item is not None:
+                    dumper = self._tx.get_dumper(item, Format.TEXT)
+                    ad = dumper.dump(item)
+                    if self._re_needs_quotes.search(ad) is not None:
+                        ad = b'"' + self._re_escape.sub(br"\\\1", ad) + b'"'
+                    tokens.append(ad)
+                    if oid is None:
+                        oid = dumper.oid
                 else:
-                    ad = self._tx.dump(item)
-                    if isinstance(ad, tuple):
-                        if not self._oid:
-                            self._oid = ad[1]
-                            got_type = type(item)
-                        elif self._oid != ad[1]:
-                            raise e.DataError(
-                                f"array contains different types,"
-                                f" at least {got_type} and {type(item)}"
-                            )
-                        ad = ad[0]
-
-                    if ad is not None:
-                        if self._re_needs_quotes.search(ad) is not None:
-                            ad = (
-                                b'"' + self._re_escape.sub(br"\\\1", ad) + b'"'
-                            )
-                        tokens.append(ad)
-                    else:
-                        tokens.append(b"NULL")
+                    tokens.append(b"NULL")
 
                 tokens.append(b",")
 
@@ -102,22 +96,22 @@ class TextListDumper(BaseListDumper):
 
         dump_list(obj)
 
-        return b"".join(tokens), self._array_oid(self._oid)
+        if oid is not None:
+            self._array_oid = self._get_array_oid(oid)
 
-    @property
-    def oid(self) -> int:
-        return self._array_oid(self._oid) if self._oid else TEXT_ARRAY_OID
+        return b"".join(tokens)
 
 
 @Dumper.binary(list)
 class BinaryListDumper(BaseListDumper):
-    def dump(self, obj: List[Any]) -> Tuple[bytes, int]:
+    def dump(self, obj: List[Any]) -> bytes:
         if not obj:
-            return _struct_head.pack(0, 0, TEXT_OID), TEXT_ARRAY_OID
+            return _struct_head.pack(0, 0, TEXT_OID)
 
         data: List[bytes] = [b"", b""]  # placeholders to avoid a resize
         dims: List[int] = []
         hasnull = 0
+        oid: Optional[int] = None
 
         def calc_dims(L: List[Any]) -> None:
             if isinstance(L, self.src):
@@ -129,29 +123,22 @@ class BinaryListDumper(BaseListDumper):
         calc_dims(obj)
 
         def dump_list(L: List[Any], dim: int) -> None:
-            nonlocal hasnull
+            nonlocal oid, hasnull
             if len(L) != dims[dim]:
                 raise e.DataError("nested lists have inconsistent lengths")
 
             if dim == len(dims) - 1:
                 for item in L:
-                    ad = self._tx.dump(item, Format.BINARY)
-                    if isinstance(ad, tuple):
-                        if not self._oid:
-                            self._oid = ad[1]
-                            got_type = type(item)
-                        elif self._oid != ad[1]:
-                            raise e.DataError(
-                                f"array contains different types,"
-                                f" at least {got_type} and {type(item)}"
-                            )
-                        ad = ad[0]
-                    if ad is None:
-                        hasnull = 1
-                        data.append(b"\xff\xff\xff\xff")
-                    else:
+                    if item is not None:
+                        dumper = self._tx.get_dumper(item, Format.BINARY)
+                        ad = dumper.dump(item)
                         data.append(_struct_len.pack(len(ad)))
                         data.append(ad)
+                        if oid is None:
+                            oid = dumper.oid
+                    else:
+                        hasnull = 1
+                        data.append(b"\xff\xff\xff\xff")
             else:
                 for item in L:
                     if not isinstance(item, self.src):
@@ -162,16 +149,14 @@ class BinaryListDumper(BaseListDumper):
 
         dump_list(obj, 0)
 
-        if not self._oid:
-            self._oid = TEXT_OID
+        if oid is None:
+            oid = TEXT_OID
 
-        data[0] = _struct_head.pack(len(dims), hasnull, self._oid)
-        data[1] = b"".join(_struct_dim.pack(dim, 1) for dim in dims)
-        return b"".join(data), self._array_oid(self._oid)
+        self._array_oid = self._get_array_oid(oid)
 
-    @property
-    def oid(self) -> int:
-        return self._array_oid(self._oid) if self._oid else TEXT_ARRAY_OID
+        data[0] = _struct_head.pack(len(dims), hasnull, oid)
+        data[1] = b"".join(_struct_dim.pack(dim, 1) for dim in dims)
+        return b"".join(data)
 
 
 class BaseArrayLoader(Loader):
index 3b29c34622788e3bf804a0378ba8d09b20e1ff52..c6b037ff139247dd7572ca53876002a9656f84f4 100644 (file)
@@ -127,9 +127,9 @@ class TextTupleDumper(Dumper):
         super().__init__(src, context)
         self._tx = Transformer(context)
 
-    def dump(self, obj: Tuple[Any, ...]) -> Tuple[bytes, int]:
+    def dump(self, obj: Tuple[Any, ...]) -> bytes:
         if not obj:
-            return b"()", TEXT_OID
+            return b"()"
 
         parts = [b"("]
 
@@ -138,13 +138,8 @@ class TextTupleDumper(Dumper):
                 parts.append(b",")
                 continue
 
-            ad = self._tx.dump(item)
-            if isinstance(ad, tuple):
-                ad = ad[0]
-            if ad is None:
-                parts.append(b",")
-                continue
-
+            dumper = self._tx.get_dumper(item, Format.TEXT)
+            ad = dumper.dump(item)
             if self._re_needs_quotes.search(ad) is not None:
                 ad = b'"' + self._re_escape.sub(br"\1\1", ad) + b'"'
 
@@ -153,7 +148,7 @@ class TextTupleDumper(Dumper):
 
         parts[-1] = b")"
 
-        return b"".join(parts), TEXT_OID
+        return b"".join(parts)
 
     _re_needs_quotes = re.compile(
         br"""(?xi)
index 82bf981869e9f56454ac03755803ecaa0fd89de6..de021d86c82537d00786241620212bb3c356ccb0 100644 (file)
@@ -7,7 +7,6 @@ Adapers for numeric types.
 import codecs
 import struct
 from decimal import Decimal
-from typing import Tuple
 
 from ..adapt import Dumper, Loader
 from .oids import builtins
@@ -29,9 +28,9 @@ _float8_struct = struct.Struct("!d")
 
 @Dumper.text(int)
 class TextIntDumper(Dumper):
-    def dump(self, obj: int) -> Tuple[bytes, int]:
+    def dump(self, obj: int) -> bytes:
         # We don't know the size of it, so we have to return a type big enough
-        return _encode(str(obj))[0], NUMERIC_OID
+        return _encode(str(obj))[0]
 
     @property
     def oid(self) -> int:
@@ -40,9 +39,9 @@ class TextIntDumper(Dumper):
 
 @Dumper.text(float)
 class TextFloatDumper(Dumper):
-    def dump(self, obj: float) -> Tuple[bytes, int]:
+    def dump(self, obj: float) -> bytes:
         # Float can't be bigger than this instead
-        return _encode(str(obj))[0], FLOAT8_OID
+        return _encode(str(obj))[0]
 
     @property
     def oid(self) -> int:
@@ -51,28 +50,18 @@ class TextFloatDumper(Dumper):
 
 @Dumper.text(Decimal)
 class TextDecimalDumper(Dumper):
-    def dump(self, obj: Decimal) -> Tuple[bytes, int]:
-        return _encode(str(obj))[0], NUMERIC_OID
+    def dump(self, obj: Decimal) -> bytes:
+        return _encode(str(obj))[0]
 
     @property
     def oid(self) -> int:
         return NUMERIC_OID
 
 
-_bool_dump = {
-    True: (b"t", builtins["bool"].oid),
-    False: (b"f", builtins["bool"].oid),
-}
-_bool_binary_dump = {
-    True: (b"\x01", builtins["bool"].oid),
-    False: (b"\x00", builtins["bool"].oid),
-}
-
-
 @Dumper.text(bool)
 class TextBoolDumper(Dumper):
-    def dump(self, obj: bool) -> Tuple[bytes, int]:
-        return _bool_dump[obj]
+    def dump(self, obj: bool) -> bytes:
+        return b"t" if obj else b"f"
 
     @property
     def oid(self) -> int:
@@ -81,8 +70,8 @@ class TextBoolDumper(Dumper):
 
 @Dumper.binary(bool)
 class BinaryBoolDumper(Dumper):
-    def dump(self, obj: bool) -> Tuple[bytes, int]:
-        return _bool_binary_dump[obj]
+    def dump(self, obj: bool) -> bytes:
+        return b"\x01" if obj else b"\x00"
 
     @property
     def oid(self) -> int:
index 5773db38a251920ac3040e5033cdbb2cbf82280a..0353ce1080e27fda01eb74e9823b97a146bc808a 100644 (file)
@@ -5,7 +5,7 @@ Adapters for textual types.
 # Copyright (C) 2020 The Psycopg Team
 
 import codecs
-from typing import Optional, Tuple, Union
+from typing import Optional, Union
 
 from ..adapt import Dumper, Loader
 from ..proto import AdaptContext, EncodeFunc, DecodeFunc
@@ -90,8 +90,8 @@ class BytesDumper(Dumper):
             self.connection.pgconn if self.connection is not None else None
         )
 
-    def dump(self, obj: bytes) -> Tuple[bytes, int]:
-        return self.esc.escape_bytea(obj), BYTEA_OID
+    def dump(self, obj: bytes) -> bytes:
+        return self.esc.escape_bytea(obj)
 
     @property
     def oid(self) -> int:
@@ -100,8 +100,8 @@ class BytesDumper(Dumper):
 
 @Dumper.binary(bytes)
 class BinaryBytesDumper(Dumper):
-    def dump(self, b: bytes) -> Tuple[bytes, int]:
-        return b, BYTEA_OID
+    def dump(self, b: bytes) -> bytes:
+        return b
 
     @property
     def oid(self) -> int:
index 4226c2dd643ef25d0fc9d207eb5a25245a067b73..f0e54b1149baa6e57611e90ec0c6696f7777f3f5 100644 (file)
@@ -10,8 +10,9 @@ information. Will submit a bug.
 import codecs
 from typing import Any, Iterable, List, Optional, Sequence, Tuple
 
+from psycopg3.adapt import Dumper
 from psycopg3.proto import AdaptContext, DumpFunc, DumpersMap, DumperType
-from psycopg3.proto import LoadFunc, LoadersMap, LoaderType, MaybeOid, PQGen
+from psycopg3.proto import LoadFunc, LoadersMap, LoaderType, PQGen
 from psycopg3.connection import BaseConnection
 from psycopg3 import pq
 
@@ -37,8 +38,9 @@ class Transformer:
     ) -> List[Optional[bytes]]: ...
     def types_sequence(self) -> List[int]: ...
     def dump(
-        self, obj: None, format: pq.Format = pq.Format.TEXT
-    ) -> MaybeOid: ...
+        self, obj: Any, format: pq.Format = pq.Format.TEXT
+    ) -> Optional[bytes]: ...
+    def get_dumper(self, obj: Any, format: pq.Format) -> "Dumper": ...
     def lookup_dumper(self, src: type, format: pq.Format) -> DumperType: ...
     def load_row(self, row: int) -> Optional[Tuple[Any, ...]]: ...
     def load_sequence(
index 98e2461b3ad182427d5697ed94b651db13662b82..310b1cba6612576b3da4e4155ed68fc5eda7ddd7 100644 (file)
@@ -227,12 +227,8 @@ cdef class Transformer:
 
         for var, fmt in zip(objs, formats):
             if var is not None:
-                dumper = self.get_dumper(type(var), fmt)
-                data = dumper.dump(var)
-                if isinstance(data, tuple):
-                    data = data[0]
-
-                out.append(data)
+                dumper = self.get_dumper(var, fmt)
+                out.append(dumper.dump(var))
                 oids.append(dumper.oid)
             else:
                 out.append(None)
@@ -243,22 +239,21 @@ cdef class Transformer:
     def types_sequence(self) -> List[int]:
         return self._oids
 
-    def dump(self, obj: None, format: Format = 0) -> "MaybeOid":
+    def dump(self, obj: Any, format: Format = 0) -> Optional[bytes]:
         if obj is None:
-            return None, TEXT_OID
+            return None
 
-        dumper = self.get_dumper(type(obj), format)
-        return dumper.dump(obj)
+        return self.get_dumper(obj, format).dump(obj)
 
-    def get_dumper(self, src: type, format: Format) -> "Dumper":
-        key = (src, format)
+    def get_dumper(self, obj: Any, format: Format) -> "Dumper":
+        key = (type(obj), format)
         try:
             return self._dumpers_cache[key]
         except KeyError:
             pass
 
-        dumper_cls = self.lookup_dumper(src, format)
-        self._dumpers_cache[key] = dumper = dumper_cls(src, self)
+        dumper_cls = self.lookup_dumper(*key)
+        self._dumpers_cache[key] = dumper = dumper_cls(key[0], self)
         return dumper
 
     def lookup_dumper(self, src: type, format: Format) -> "DumperType":
index 6a26fc89f5a4ee8c7cde9bc2f24025882f768211..940e025370fc7e03b8ef5e639f75fb93de6f8280 100644 (file)
@@ -8,8 +8,6 @@ TEXT_OID = builtins["text"].oid
 @pytest.mark.parametrize(
     "data, format, result, type",
     [
-        (None, Format.TEXT, None, "text"),
-        (None, Format.BINARY, None, "text"),
         (1, Format.TEXT, b"1", "numeric"),
         ("hello", Format.TEXT, b"hello", "text"),
         ("hello", Format.BINARY, b"hello", "text"),
@@ -17,12 +15,9 @@ TEXT_OID = builtins["text"].oid
 )
 def test_dump(data, format, result, type):
     t = Transformer()
-    rv = t.dump(data, format)
-    if isinstance(rv, tuple):
-        assert rv[0] == result
-        assert rv[1] == builtins[type].oid
-    else:
-        assert rv == result
+    dumper = t.get_dumper(data, format)
+    assert dumper.dump(data) == result
+    assert dumper.oid == builtins[type].oid
 
 
 def make_dumper(suffix):
index 0ad2884244ea9b7288b39d54a006e90e06e66a73..9816c2f81932036b846c1e659247d703c95bd9f5 100644 (file)
@@ -82,7 +82,7 @@ def test_dump_list_int(conn, obj, want):
         [[]],
         [[["a"]], ["b"]],
         # [["a"], [["b"]]],  # todo, but expensive (an isinstance per item)
-        [True, b"a"],
+        # [True, b"a"], # TODO expensive too
     ],
 )
 def test_bad_binary_array(input):