]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Remove the cls attribute from the Dumper protocol
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 10 Jul 2021 16:41:44 +0000 (18:41 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 10 Jul 2021 17:08:15 +0000 (19:08 +0200)
Where used, in recursive types, get_key() can be used instead.

psycopg/psycopg/_transform.py
psycopg/psycopg/proto.py
psycopg/psycopg/types/array.py
psycopg/psycopg/types/datetime.py
psycopg/psycopg/types/range.py
psycopg_c/psycopg_c/_psycopg/adapt.pyx
tests/typing_example.py

index 1d472380843432f5039de60ac911ea0a7bfe3f10..d2351e2b34eee01e20e2e27c0baf4ea39c303630 100644 (file)
@@ -4,7 +4,7 @@ Helper object to transform values between Python and PostgreSQL
 
 # Copyright (C) 2020-2021 The Psycopg Team
 
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, List, Optional, Sequence, Tuple
 from typing import DefaultDict, TYPE_CHECKING
 from collections import defaultdict
 
@@ -12,7 +12,7 @@ from . import pq
 from . import errors as e
 from .oids import INVALID_OID
 from .rows import Row, RowMaker
-from .proto import LoadFunc, AdaptContext, PyFormat
+from .proto import LoadFunc, AdaptContext, PyFormat, DumperKey
 
 if TYPE_CHECKING:
     from .pq.proto import PGresult
@@ -20,11 +20,8 @@ if TYPE_CHECKING:
     from .proto import Dumper
     from .connection import BaseConnection
 
-DumperKey = Union[type, Tuple[type, ...]]
 DumperCache = Dict[DumperKey, "Dumper"]
-
-LoaderKey = int
-LoaderCache = Dict[LoaderKey, "Loader"]
+LoaderCache = Dict[int, "Loader"]
 
 
 class Transformer(AdaptContext):
index 8c1b574feedc49c3368290e90ed0711c1e7f0982..b0c92286a291df4fe22caf2e712649f91686cf8c 100644 (file)
@@ -29,6 +29,9 @@ Query = Union[str, bytes, "Composable"]
 Params = Union[Sequence[Any], Mapping[str, Any]]
 ConnectionType = TypeVar("ConnectionType", bound="BaseConnection[Any]")
 
+# TODO: make it recursive when mypy will support it
+# DumperKey = Union[type, Tuple[Union[type, "DumperKey"]]]
+DumperKey = Union[type, Tuple[type, ...]]
 
 # Waiting protocol types
 
@@ -73,7 +76,6 @@ class AdaptContext(Protocol):
 class Dumper(Protocol):
     format: pq.Format
     oid: int
-    cls: type
 
     def __init__(self, cls: type, context: Optional[AdaptContext] = None):
         ...
index ad998dae5e7bcd1a92f707cac8582878dcb814cf..67ff2692d82c05421c8250ac7fe930ebb524ecc8 100644 (file)
@@ -13,7 +13,7 @@ from .. import pq
 from .. import errors as e
 from ..oids import postgres_types, TEXT_OID, TEXT_ARRAY_OID, INVALID_OID
 from ..adapt import RecursiveDumper, RecursiveLoader, PyFormat
-from ..proto import Dumper, AdaptContext, Buffer
+from ..proto import AdaptContext, Buffer, Dumper, DumperKey
 from .._struct import pack_len, unpack_len
 from .._typeinfo import TypeInfo
 
@@ -35,11 +35,11 @@ class BaseListDumper(RecursiveDumper):
         self.sub_dumper: Optional[Dumper] = None
         self._types = context.adapters.types if context else postgres_types
 
-    def get_key(self, obj: List[Any], format: PyFormat) -> Tuple[type, ...]:
+    def get_key(self, obj: List[Any], format: PyFormat) -> DumperKey:
         item = self._find_list_element(obj)
         if item is not None:
             sd = self._tx.get_dumper(item, format)
-            return (self.cls, sd.cls)
+            return (self.cls, sd.get_key(item, format))  # type: ignore
         else:
             return (self.cls,)
 
index 793001085aa46b94e1e1a6a87675e2563468d570..8717bce2e41bea2d001326ef4f2426347947996c 100644 (file)
@@ -8,13 +8,13 @@ import re
 import sys
 import struct
 from datetime import date, datetime, time, timedelta, timezone
-from typing import Any, Callable, cast, Optional, Tuple, Union, TYPE_CHECKING
+from typing import Any, Callable, cast, Optional, Tuple, TYPE_CHECKING
 
 from ..pq import Format
 from .._tz import get_tzinfo
 from ..oids import postgres_types as builtins
 from ..adapt import Buffer, Dumper, Loader, PyFormat
-from ..proto import AdaptContext
+from ..proto import AdaptContext, DumperKey
 from ..errors import InterfaceError, DataError
 from .._struct import pack_int4, pack_int8, unpack_int4, unpack_int8
 
@@ -62,7 +62,7 @@ class DateBinaryDumper(Dumper):
 
 
 class _BaseTimeDumper(Dumper):
-    def get_key(self, obj: time, format: PyFormat) -> Union[type, Tuple[type]]:
+    def get_key(self, obj: time, format: PyFormat) -> DumperKey:
         # Use (cls,) to report the need to upgrade to a dumper for timetz (the
         # Frankenstein of the data types).
         if not obj.tzinfo:
@@ -131,9 +131,7 @@ class TimeTzBinaryDumper(_BaseTimeDumper):
 
 
 class _BaseDatetimeDumper(Dumper):
-    def get_key(
-        self, obj: datetime, format: PyFormat
-    ) -> Union[type, Tuple[type]]:
+    def get_key(self, obj: datetime, format: PyFormat) -> DumperKey:
         # Use (cls,) to report the need to upgrade (downgrade, actually) to a
         # dumper for naive timestamp.
         if obj.tzinfo:
index 00a624eacd605cc13706d4c57778f9fa7715d63e..e929ba00519e3194261c7242c0729c76ba5bfe12 100644 (file)
@@ -5,7 +5,7 @@ Support for range types adaptation.
 # Copyright (C) 2020-2021 The Psycopg Team
 
 import re
-from typing import Any, Dict, Generic, Optional, Tuple, TypeVar, Type, Union
+from typing import Any, Dict, Generic, Optional, TypeVar, Type, Union
 from typing import cast
 from decimal import Decimal
 from datetime import date, datetime
@@ -13,7 +13,7 @@ from datetime import date, datetime
 from ..pq import Format
 from ..oids import postgres_types as builtins, INVALID_OID
 from ..adapt import RecursiveDumper, RecursiveLoader, PyFormat
-from ..proto import Dumper, AdaptContext, Buffer
+from ..proto import AdaptContext, Buffer, Dumper, DumperKey
 from .._struct import pack_len, unpack_len
 from .._typeinfo import RangeInfo as RangeInfo  # exported here
 
@@ -259,9 +259,7 @@ class BaseRangeDumper(RecursiveDumper):
         self._types = context.adapters.types if context else builtins
         self._adapt_format = PyFormat.from_pq(self.format)
 
-    def get_key(
-        self, obj: Range[Any], format: PyFormat
-    ) -> Union[type, Tuple[type, ...]]:
+    def get_key(self, obj: Range[Any], format: PyFormat) -> DumperKey:
         # If we are a subclass whose oid is specified we don't need upgrade
         if self.oid != INVALID_OID:
             return self.cls
@@ -269,7 +267,7 @@ class BaseRangeDumper(RecursiveDumper):
         item = self._get_item(obj)
         if item is not None:
             sd = self._tx.get_dumper(item, self._adapt_format)
-            return (self.cls, sd.cls)
+            return (self.cls, sd.get_key(item, format))  # type: ignore
         else:
             return (self.cls,)
 
index ed30dc69b56502243e8b2de63578c990cef40e5f..771745d03b15fe19f59ae912c5021fb23825a810 100644 (file)
@@ -98,10 +98,10 @@ cdef class CDumper:
 
         return rv
 
-    cdef object get_key(self, object obj, object format):
+    cpdef object get_key(self, object obj, object format):
         return self.cls
 
-    cdef object upgrade(self, object obj, object format):
+    cpdef object upgrade(self, object obj, object format):
         return self
 
     @classmethod
index 7e3d6cf09bf61a6b605895fbe29a7d3518728085..42be3826297b9156d7e8b53b1b97a01a58c265bf 100644 (file)
@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import Any, Callable, Optional, Sequence, Tuple
+from typing import Any, Callable, Optional, Sequence, Tuple, Union
 
 from psycopg import AnyCursor, Connection, Cursor, ServerCursor, connect
 from psycopg import pq
@@ -99,7 +99,7 @@ class MyStrDumper:
     format = pq.Format.TEXT
 
     def __init__(self, cls: type, context: Optional[AdaptContext] = None):
-        self.cls = cls
+        self._cls = cls
         self.oid = 25  # text
 
     def dump(self, obj: str) -> bytes:
@@ -111,7 +111,13 @@ class MyStrDumper:
         return b"'%s'" % esc.escape_string(value.replace(b"h", b"q"))
 
     def get_key(self, obj: str, format: PyFormat) -> type:
-        return self.cls
+        return self._cls
 
     def upgrade(self, obj: str, format: PyFormat) -> "MyStrDumper":
         return self
+
+
+# This should be the definition of psycopg.adapt.DumperKey, but mypy doesn't
+# support recursive types. When it will, this statement will give an error
+# (unused type: ignore) so we can fix our definition.
+_DumperKey = Union[type, Tuple[Union[type, "_DumperKey"]]]  # type: ignore