From 3ecbd60d8044b109cd4a4db8687b9cf8d16b57f3 Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Wed, 4 Nov 2020 18:22:49 +0100 Subject: [PATCH] Dropped dumper param of Dumper.register Just use `cls`, as the class method is available to all the subclasses. --- psycopg3/psycopg3/adapt.py | 19 ++++++------------- tests/test_adapt.py | 12 ++++++------ 2 files changed, 12 insertions(+), 19 deletions(-) diff --git a/psycopg3/psycopg3/adapt.py b/psycopg3/psycopg3/adapt.py index 5343e3236..b911e74d7 100644 --- a/psycopg3/psycopg3/adapt.py +++ b/psycopg3/psycopg3/adapt.py @@ -47,32 +47,25 @@ class Dumper: def register( cls, src: type, - dumper: DumperType, context: AdaptContext = None, format: Format = Format.TEXT, - ) -> DumperType: + ) -> None: if not isinstance(src, type): raise TypeError( f"dumpers should be registered on classes, got {src} instead" ) - if not (isinstance(dumper, type)): - raise TypeError(f"dumpers should be classes, got {dumper} instead") - where = context.dumpers if context is not None else Dumper.globals - where[src, format] = dumper - return dumper + where[src, format] = cls @classmethod - def register_binary( - cls, src: type, dumper: DumperType, context: AdaptContext = None - ) -> DumperType: - return cls.register(src, dumper, context, format=Format.BINARY) + def register_binary(cls, src: type, context: AdaptContext = None) -> None: + cls.register(src, context, format=Format.BINARY) @classmethod def text(cls, src: type) -> Callable[[DumperType], DumperType]: def text_(dumper: DumperType) -> DumperType: - cls.register(src, dumper) + dumper.register(src) return dumper return text_ @@ -80,7 +73,7 @@ class Dumper: @classmethod def binary(cls, src: type) -> Callable[[DumperType], DumperType]: def binary_(dumper: DumperType) -> DumperType: - cls.register_binary(src, dumper) + dumper.register_binary(src) return dumper return binary_ diff --git a/tests/test_adapt.py b/tests/test_adapt.py index 5787bf116..e50b22003 100644 --- a/tests/test_adapt.py +++ b/tests/test_adapt.py @@ -37,8 +37,8 @@ def test_quote(data, result): def test_dump_connection_ctx(conn): - Dumper.register(str, make_dumper("t"), conn) - Dumper.register_binary(str, make_dumper("b"), conn) + make_dumper("t").register(str, conn) + make_dumper("b").register_binary(str, conn) cur = conn.cursor() cur.execute("select %s, %b", ["hello", "world"]) @@ -46,12 +46,12 @@ def test_dump_connection_ctx(conn): def test_dump_cursor_ctx(conn): - Dumper.register(str, make_dumper("t"), conn) - Dumper.register_binary(str, make_dumper("b"), conn) + make_dumper("t").register(str, conn) + make_dumper("b").register_binary(str, conn) cur = conn.cursor() - Dumper.register(str, make_dumper("tc"), cur) - Dumper.register_binary(str, make_dumper("bc"), cur) + make_dumper("tc").register(str, cur) + make_dumper("bc").register_binary(str, cur) cur.execute("select %s, %b", ["hello", "world"]) assert cur.fetchone() == ("hellotc", "worldbc") -- 2.47.2