]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add RecursiveDumper, RecursiveLoader
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 8 Jun 2021 20:27:40 +0000 (21:27 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 25 Jun 2021 15:16:26 +0000 (16:16 +0100)
psycopg3/psycopg3/adapt.py
psycopg3/psycopg3/types/array.py
psycopg3/psycopg3/types/composite.py
psycopg3/psycopg3/types/range.py

index 43f2c6c1b8561fec3ee89373be44c3f45bfe6dda..360201d20189b9fb27d1a592c6a86a4df5742fb2 100644 (file)
@@ -300,3 +300,19 @@ else:
     from . import _transform
 
     Transformer = _transform.Transformer
+
+
+class RecursiveDumper(Dumper):
+    """Dumper with a transformer to help dumping recursive types."""
+
+    def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+        super().__init__(cls, context)
+        self._tx = Transformer(context)
+
+
+class RecursiveLoader(Loader):
+    """Loader with a transformer to help loading recursive types."""
+
+    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+        super().__init__(oid, context)
+        self._tx = Transformer(context)
index 9d133d899bc65d2ee99d246336697ea79b25bd7b..5090eb593a4322ebd153fd3a8355cb57791894a6 100644 (file)
@@ -12,9 +12,9 @@ from typing import cast
 from .. import pq
 from .. import errors as e
 from ..oids import postgres_types, TEXT_OID, TEXT_ARRAY_OID, INVALID_OID
-from ..adapt import Buffer, Dumper, Loader, Transformer
+from ..adapt import Dumper, RecursiveDumper, RecursiveLoader
 from ..adapt import Format as Pg3Format
-from ..proto import AdaptContext
+from ..proto import AdaptContext, Buffer
 from .._struct import pack_len, unpack_len
 from .._typeinfo import TypeInfo
 
@@ -30,10 +30,9 @@ _unpack_dim = cast(
 )
 
 
-class BaseListDumper(Dumper):
+class BaseListDumper(RecursiveDumper):
     def __init__(self, cls: type, context: Optional[AdaptContext] = None):
         super().__init__(cls, context)
-        self._tx = Transformer(context)
         self.sub_dumper: Optional[Dumper] = None
         self._types = context.adapters.types if context else postgres_types
 
@@ -220,13 +219,9 @@ class ListBinaryDumper(BaseListDumper):
         return b"".join(data)
 
 
-class BaseArrayLoader(Loader):
+class BaseArrayLoader(RecursiveLoader):
     base_oid: int
 
-    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
-        super().__init__(oid, context)
-        self._tx = Transformer(context)
-
 
 class ArrayLoader(BaseArrayLoader):
 
index 56a3d41e2e9f837d08159a303817a6480ba2300c..5245e9d2c27cd0922c6e84f332b787c933b47303 100644 (file)
@@ -12,8 +12,8 @@ from typing import Sequence, Tuple, Type
 
 from .. import pq
 from ..oids import TEXT_OID
-from ..adapt import Buffer, Format, Dumper, Loader, Transformer
-from ..proto import AdaptContext
+from ..adapt import Format, RecursiveDumper, RecursiveLoader
+from ..proto import AdaptContext, Buffer
 from .._struct import unpack_len
 from .._typeinfo import CompositeInfo
 
@@ -23,14 +23,10 @@ _unpack_oidlen = cast(
 )
 
 
-class SequenceDumper(Dumper):
+class SequenceDumper(RecursiveDumper):
 
     format = pq.Format.TEXT
 
-    def __init__(self, cls: type, context: Optional[AdaptContext] = None):
-        super().__init__(cls, context)
-        self._tx = Transformer(context)
-
     def _dump_sequence(
         self, obj: Sequence[Any], start: bytes, end: bytes, sep: bytes
     ) -> bytes:
@@ -71,14 +67,10 @@ class TupleDumper(SequenceDumper):
         return self._dump_sequence(obj, b"(", b")", b",")
 
 
-class BaseCompositeLoader(Loader):
+class BaseCompositeLoader(RecursiveLoader):
 
     format = pq.Format.TEXT
 
-    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
-        super().__init__(oid, context)
-        self._tx = Transformer(context)
-
     def _parse_record(self, data: bytes) -> Iterator[Optional[bytes]]:
         """
         Split a non-empty representation of a composite type into components.
@@ -122,15 +114,11 @@ class RecordLoader(BaseCompositeLoader):
         )
 
 
-class RecordBinaryLoader(Loader):
+class RecordBinaryLoader(RecursiveLoader):
 
     format = pq.Format.BINARY
     _types_set = False
 
-    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
-        super().__init__(oid, context)
-        self._tx = Transformer(context)
-
     def load(self, data: Buffer) -> Tuple[Any, ...]:
         if not self._types_set:
             self._config_types(data)
index 09ea3ccca08644dd9d42150d9aed6e33a826d43a..0ebecbd55bafc344ea0914a7291c4cd53b7d4034 100644 (file)
@@ -11,8 +11,8 @@ from datetime import date, datetime
 
 from ..pq import Format
 from ..oids import postgres_types as builtins, INVALID_OID
-from ..adapt import Buffer, Dumper, Loader, Format as Pg3Format, Transformer
-from ..proto import AdaptContext
+from ..adapt import Dumper, RecursiveLoader, Format as Pg3Format
+from ..proto import AdaptContext, Buffer
 from .._struct import unpack_len
 from .._typeinfo import RangeInfo
 
@@ -310,15 +310,11 @@ class RangeLoader(BaseCompositeLoader, Generic[T]):
         return Range(min, max, bounds)
 
 
-class RangeBinaryLoader(Loader, Generic[T]):
+class RangeBinaryLoader(RecursiveLoader, Generic[T]):
 
     format = Format.BINARY
     subtype_oid: int
 
-    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
-        super().__init__(oid, context)
-        self._tx = Transformer(context)
-
     def load(self, data: Buffer) -> Range[T]:
         head = data[0]
         if head & RANGE_EMPTY: