]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Dropped format parameter from register
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 30 Dec 2020 21:02:02 +0000 (22:02 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 8 Jan 2021 01:26:53 +0000 (02:26 +0100)
Now the loaders/dumper declare their format themselves.

psycopg3/psycopg3/adapt.py
psycopg3/psycopg3/types/array.py
psycopg3/psycopg3/types/composite.py
psycopg3/psycopg3/types/numeric.py
psycopg3/psycopg3/types/range.py
tests/test_adapt.py

index b32442b0669792ae29a0ae5f78188b889fadb651..28358d016315f5c58cebe3434417d2ce01e80671 100644 (file)
@@ -64,20 +64,18 @@ class Dumper(ABC):
 
     @classmethod
     def register(
-        cls,
-        src: Union[type, str],
-        context: Optional[AdaptContext] = None,
-        format: Format = Format.TEXT,
+        cls, src: Union[type, str], context: Optional[AdaptContext] = None
     ) -> None:
         """
         Configure *context* to use this dumper to convert object of type *src*.
         """
         adapters = context.adapters if context else global_adapters
-        adapters.register_dumper(src, cls, format=format)
+        adapters.register_dumper(src, cls)
 
     @classmethod
     def text(cls, src: Union[type, str]) -> Callable[[DumperType], DumperType]:
         def text_(dumper: DumperType) -> DumperType:
+            assert dumper.format == Format.TEXT
             dumper.register(src)
             return dumper
 
@@ -88,7 +86,8 @@ class Dumper(ABC):
         cls, src: Union[type, str]
     ) -> Callable[[DumperType], DumperType]:
         def binary_(dumper: DumperType) -> DumperType:
-            dumper.register(src, format=Format.BINARY)
+            assert dumper.format == Format.BINARY
+            dumper.register(src)
             return dumper
 
         return binary_
@@ -113,20 +112,18 @@ class Loader(ABC):
 
     @classmethod
     def register(
-        cls,
-        oid: int,
-        context: Optional[AdaptContext] = None,
-        format: Format = Format.TEXT,
+        cls, oid: int, context: Optional[AdaptContext] = None
     ) -> None:
         """
         Configure *context* to use this loader to convert values with OID *oid*.
         """
         adapters = context.adapters if context else global_adapters
-        adapters.register_loader(oid, cls, format=format)
+        adapters.register_loader(oid, cls)
 
     @classmethod
     def text(cls, oid: int) -> Callable[[LoaderType], LoaderType]:
         def text_(loader: LoaderType) -> LoaderType:
+            assert loader.format == Format.TEXT
             loader.register(oid)
             return loader
 
@@ -135,7 +132,8 @@ class Loader(ABC):
     @classmethod
     def binary(cls, oid: int) -> Callable[[LoaderType], LoaderType]:
         def binary_(loader: LoaderType) -> LoaderType:
-            loader.register(oid, format=Format.BINARY)
+            assert loader.format == Format.BINARY
+            loader.register(oid)
             return loader
 
         return binary_
@@ -167,10 +165,7 @@ class AdaptersMap:
             self._own_loaders = True
 
     def register_dumper(
-        self,
-        src: Union[type, str],
-        dumper: Type[Dumper],
-        format: Format = Format.TEXT,
+        self, src: Union[type, str], dumper: Type[Dumper]
     ) -> None:
         """
         Configure the context to use *dumper* to convert object of type *src*.
@@ -184,11 +179,9 @@ class AdaptersMap:
             self._dumpers = self._dumpers.copy()
             self._own_dumpers = True
 
-        self._dumpers[src, format] = dumper
+        self._dumpers[src, dumper.format] = dumper
 
-    def register_loader(
-        self, oid: int, loader: Type[Loader], format: Format = Format.TEXT
-    ) -> None:
+    def register_loader(self, oid: int, loader: Type[Loader]) -> None:
         """
         Configure the context to use *loader* to convert data of oid *oid*.
         """
@@ -201,7 +194,7 @@ class AdaptersMap:
             self._loaders = self._loaders.copy()
             self._own_loaders = True
 
-        self._loaders[oid, format] = loader
+        self._loaders[oid, loader.format] = loader
 
 
 global_adapters = AdaptersMap()
index dcbec8efa143ec67b2e162b72ed10700bce5d1b1..7b5c366ce8511fb38980d1217eb27a7a67a5a20d 100644 (file)
@@ -280,13 +280,10 @@ def register(
     if not name:
         name = f"oid{base_oid}"
 
-    for format, base in (
-        (Format.TEXT, ArrayLoader),
-        (Format.BINARY, ArrayBinaryLoader),
-    ):
+    for base in (ArrayLoader, ArrayBinaryLoader):
         lname = f"{name.title()}Array{'Binary' if format else ''}Loader"
         loader: Type[Loader] = type(lname, (base,), {"base_oid": base_oid})
-        loader.register(array_oid, context=context, format=format)
+        loader.register(array_oid, context=context)
 
 
 def register_all_arrays() -> None:
index e5065c5edbf2ca13b453ac96f67fae9e9ea88b8a..17b3f8068319bf3c7523c8d536e5bf6b29c2142d 100644 (file)
@@ -89,7 +89,7 @@ class CompositeInfo(TypeInfo):
                 "fields_types": [f.type_oid for f in self.fields],
             },
         )
-        loader.register(self.oid, context=context, format=Format.TEXT)
+        loader.register(self.oid, context=context)
 
         # generate and register a customized binary loader
         loader = type(
@@ -97,7 +97,7 @@ class CompositeInfo(TypeInfo):
             (CompositeBinaryLoader,),
             {"factory": factory},
         )
-        loader.register(self.oid, context=context, format=Format.BINARY)
+        loader.register(self.oid, context=context)
 
         if self.array_oid:
             array.register(
index 476b16cd000af444df6eca1e9e995a8fb0fdc3fb..91ba9f066a32838289257764c1b0e8474486a0c2 100644 (file)
@@ -120,6 +120,7 @@ class FloatBinaryDumper(Dumper):
 
 @Dumper.text(Decimal)
 class DecimalDumper(SpecialValuesDumper):
+
     _oid = builtins["numeric"].oid
 
     _special = {
index c480c688b2ee41492f7aaa820d48b6fae5dcd009..c7c0caa4afebcd79f0bf0223d16498386cdfa9ff 100644 (file)
@@ -409,7 +409,7 @@ class RangeInfo(TypeInfo):
         dumper: Type[Dumper] = type(
             f"{self.name.title()}Dumper", (RangeDumper,), {"_oid": self.oid}
         )
-        dumper.register(range_class, context=context, format=Format.TEXT)
+        dumper.register(range_class, context=context)
 
         # generate and register a customized text loader
         loader: Type[Loader] = type(
@@ -417,7 +417,7 @@ class RangeInfo(TypeInfo):
             (RangeLoader,),
             {"cls": range_class, "subtype_oid": self.subtype_oid},
         )
-        loader.register(self.oid, context=context, format=Format.TEXT)
+        loader.register(self.oid, context=context)
 
         if self.array_oid:
             array.register(
index da0bea7eb45a402d825752bd62d0bcdac000d089..17288caef3bc9585e1fb674499b302e6590b0163 100644 (file)
@@ -37,7 +37,7 @@ def test_quote(data, result):
 
 def test_dump_connection_ctx(conn):
     make_dumper("t").register(str, conn)
-    make_bin_dumper("b").register(str, conn, format=Format.BINARY)
+    make_bin_dumper("b").register(str, conn)
 
     cur = conn.cursor()
     cur.execute("select %s, %b", ["hello", "world"])
@@ -46,11 +46,11 @@ def test_dump_connection_ctx(conn):
 
 def test_dump_cursor_ctx(conn):
     make_dumper("t").register(str, conn)
-    make_bin_dumper("b").register(str, conn, format=Format.BINARY)
+    make_bin_dumper("b").register(str, conn)
 
     cur = conn.cursor()
     make_dumper("tc").register(str, cur)
-    make_bin_dumper("bc").register(str, cur, format=Format.BINARY)
+    make_bin_dumper("bc").register(str, cur)
 
     cur.execute("select %s, %b", ["hello", "world"])
     assert cur.fetchone() == ("hellotc", "worldbc")
@@ -88,7 +88,7 @@ def test_cast(data, format, type, result):
 
 def test_load_connection_ctx(conn):
     make_loader("t").register(TEXT_OID, conn)
-    make_bin_loader("b").register(TEXT_OID, conn, format=Format.BINARY)
+    make_bin_loader("b").register(TEXT_OID, conn)
 
     r = conn.cursor().execute("select 'hello'::text").fetchone()
     assert r == ("hellot",)
@@ -98,11 +98,11 @@ def test_load_connection_ctx(conn):
 
 def test_load_cursor_ctx(conn):
     make_loader("t").register(TEXT_OID, conn)
-    make_bin_loader("b").register(TEXT_OID, conn, format=Format.BINARY)
+    make_bin_loader("b").register(TEXT_OID, conn)
 
     cur = conn.cursor()
     make_loader("tc").register(TEXT_OID, cur)
-    make_bin_loader("bc").register(TEXT_OID, cur, format=Format.BINARY)
+    make_bin_loader("bc").register(TEXT_OID, cur)
 
     r = cur.execute("select 'hello'::text").fetchone()
     assert r == ("hellotc",)
@@ -125,10 +125,10 @@ def test_load_cursor_ctx(conn):
 @pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
 def test_load_cursor_ctx_nested(conn, sql, obj, fmt_out):
     cur = conn.cursor(format=fmt_out)
-    if format == Format.TEXT:
-        make_loader("c").register(TEXT_OID, cur, format=fmt_out)
+    if fmt_out == Format.TEXT:
+        make_loader("c").register(TEXT_OID, cur)
     else:
-        make_bin_loader("c").register(TEXT_OID, cur, format=fmt_out)
+        make_bin_loader("c").register(TEXT_OID, cur)
 
     cur.execute(f"select {sql}")
     res = cur.fetchone()[0]