]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix: annotate every dumper to return Optional[Buffer]
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 19 Sep 2022 00:42:34 +0000 (01:42 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 1 Jun 2024 11:07:21 +0000 (13:07 +0200)
Even if these classes never return None, this allows to create
subclasses returning None without making Mypy unhappy.

Similarly, annotate quote() methods as returning Buffer to allow
subclasses to return other types.

21 files changed:
psycopg/psycopg/dbapi20.py
psycopg/psycopg/types/array.py
psycopg/psycopg/types/bool.py
psycopg/psycopg/types/composite.py
psycopg/psycopg/types/datetime.py
psycopg/psycopg/types/enum.py
psycopg/psycopg/types/json.py
psycopg/psycopg/types/multirange.py
psycopg/psycopg/types/net.py
psycopg/psycopg/types/none.py
psycopg/psycopg/types/numeric.py
psycopg/psycopg/types/range.py
psycopg/psycopg/types/shapely.py
psycopg/psycopg/types/string.py
psycopg/psycopg/types/uuid.py
psycopg_c/psycopg_c/_psycopg/adapt.pyx
psycopg_c/psycopg_c/types/bool.pyx
psycopg_c/psycopg_c/types/numeric.pyx
psycopg_c/psycopg_c/types/string.pyx
tests/test_copy.py
tests/test_copy_async.py

index aab4c822f3f628a5bebadcf14d81eee24f2e1e55..919b0506c79c976d6a4154c47fac6a8c3cb2fd07 100644 (file)
@@ -7,7 +7,7 @@ Compatibility objects with DBAPI 2.0
 import time
 import datetime as dt
 from math import floor
-from typing import Any, Sequence, Union
+from typing import Any, Optional, Sequence, Union
 
 from . import _oids
 from .abc import AdaptContext, Buffer
@@ -76,7 +76,7 @@ class Binary:
 
 
 class BinaryBinaryDumper(BytesBinaryDumper):
-    def dump(self, obj: Union[Buffer, Binary]) -> Buffer:
+    def dump(self, obj: Union[Buffer, Binary]) -> Optional[Buffer]:
         if isinstance(obj, Binary):
             return super().dump(obj.obj)
         else:
@@ -84,7 +84,7 @@ class BinaryBinaryDumper(BytesBinaryDumper):
 
 
 class BinaryTextDumper(BytesDumper):
-    def dump(self, obj: Union[Buffer, Binary]) -> Buffer:
+    def dump(self, obj: Union[Buffer, Binary]) -> Optional[Buffer]:
         if isinstance(obj, Binary):
             return super().dump(obj.obj)
         else:
index d4a5dc0b2caa68c9fb8cfcc0a65d316b02aec2cd..23afc7d9f83597799095ccab24fffcabe3697976 100644 (file)
@@ -153,7 +153,7 @@ class ListDumper(BaseListDumper):
     # backslash-escaped.
     _re_esc = re.compile(rb'(["\\])')
 
-    def dump(self, obj: List[Any]) -> bytes:
+    def dump(self, obj: List[Any]) -> Optional[Buffer]:
         tokens: List[Buffer] = []
         needs_quotes = _get_needs_quotes_regexp(self.delimiter).search
 
@@ -245,7 +245,7 @@ class ListBinaryDumper(BaseListDumper):
 
         return dumper
 
-    def dump(self, obj: List[Any]) -> bytes:
+    def dump(self, obj: List[Any]) -> Optional[Buffer]:
         # Postgres won't take unknown for element oid: fall back on text
         sub_oid = self.sub_dumper and self.sub_dumper.oid or TEXT_OID
 
index 2ad53c2f4168ce9d6f2203aba9c18be06abfd151..c056793541e5a16193220bc225dde33c7765e4cc 100644 (file)
@@ -5,6 +5,7 @@ Adapters for booleans.
 # Copyright (C) 2020 The Psycopg Team
 
 from .. import _oids
+from typing import Optional
 from ..pq import Format
 from ..abc import AdaptContext
 from ..adapt import Buffer, Dumper, Loader
@@ -13,10 +14,10 @@ from ..adapt import Buffer, Dumper, Loader
 class BoolDumper(Dumper):
     oid = _oids.BOOL_OID
 
-    def dump(self, obj: bool) -> bytes:
+    def dump(self, obj: bool) -> Optional[Buffer]:
         return b"t" if obj else b"f"
 
-    def quote(self, obj: bool) -> bytes:
+    def quote(self, obj: bool) -> Buffer:
         return b"true" if obj else b"false"
 
 
@@ -24,7 +25,7 @@ class BoolBinaryDumper(Dumper):
     format = Format.BINARY
     oid = _oids.BOOL_OID
 
-    def dump(self, obj: bool) -> bytes:
+    def dump(self, obj: bool) -> Optional[Buffer]:
         return b"\x01" if obj else b"\x00"
 
 
index d116273c40fdbbe9f3e0e0b6362e462b3cb8db3e..d3f41b675bab4bcb3591bdd7694581abc1d04124 100644 (file)
@@ -14,7 +14,7 @@ from .. import pq
 from .. import abc
 from .. import sql
 from .. import postgres
-from ..adapt import Transformer, PyFormat, RecursiveDumper, Loader, Dumper
+from ..adapt import Transformer, PyFormat, RecursiveDumper, Loader, Dumper, Buffer
 from .._oids import TEXT_OID
 from .._compat import cache
 from .._struct import pack_len, unpack_len
@@ -117,7 +117,7 @@ class TupleDumper(SequenceDumper):
     # Should be this, but it doesn't work
     # oid = _oids.RECORD_OID
 
-    def dump(self, obj: Tuple[Any, ...]) -> bytes:
+    def dump(self, obj: Tuple[Any, ...]) -> Optional[Buffer]:
         return self._dump_sequence(obj, b"(", b")", b",")
 
 
@@ -140,7 +140,7 @@ class TupleBinaryDumper(Dumper):
         nfields = len(self._field_types)
         self._formats = (PyFormat.from_pq(self.format),) * nfields
 
-    def dump(self, obj: Tuple[Any, ...]) -> bytearray:
+    def dump(self, obj: Tuple[Any, ...]) -> Optional[Buffer]:
         out = bytearray(pack_len(len(obj)))
         adapted = self._tx.dump_sequence(obj, self._formats)
         for i in range(len(obj)):
index e5629a9ddc2dc919d5fa248c30f56075f0c1fa52..b3bfa1b8be00e90c950c8aa3cbc6e1e103be38c1 100644 (file)
@@ -40,7 +40,7 @@ _py_date_min_days = date.min.toordinal()
 class DateDumper(Dumper):
     oid = _oids.DATE_OID
 
-    def dump(self, obj: date) -> bytes:
+    def dump(self, obj: date) -> Optional[Buffer]:
         # NOTE: whatever the PostgreSQL DateStyle input format (DMY, MDY, YMD)
         # the YYYY-MM-DD is always understood correctly.
         return str(obj).encode()
@@ -50,7 +50,7 @@ class DateBinaryDumper(Dumper):
     format = Format.BINARY
     oid = _oids.DATE_OID
 
-    def dump(self, obj: date) -> bytes:
+    def dump(self, obj: date) -> Optional[Buffer]:
         days = obj.toordinal() - _pg_date_epoch_days
         return pack_int4(days)
 
@@ -77,7 +77,7 @@ class _BaseTimeDumper(Dumper):
 
 
 class _BaseTimeTextDumper(_BaseTimeDumper):
-    def dump(self, obj: time) -> bytes:
+    def dump(self, obj: time) -> Optional[Buffer]:
         return str(obj).encode()
 
 
@@ -94,7 +94,7 @@ class TimeDumper(_BaseTimeTextDumper):
 class TimeTzDumper(_BaseTimeTextDumper):
     oid = _oids.TIMETZ_OID
 
-    def dump(self, obj: time) -> bytes:
+    def dump(self, obj: time) -> Optional[Buffer]:
         self._get_offset(obj)
         return super().dump(obj)
 
@@ -103,7 +103,7 @@ class TimeBinaryDumper(_BaseTimeDumper):
     format = Format.BINARY
     oid = _oids.TIME_OID
 
-    def dump(self, obj: time) -> bytes:
+    def dump(self, obj: time) -> Optional[Buffer]:
         us = obj.microsecond + 1_000_000 * (
             obj.second + 60 * (obj.minute + 60 * obj.hour)
         )
@@ -120,7 +120,7 @@ class TimeTzBinaryDumper(_BaseTimeDumper):
     format = Format.BINARY
     oid = _oids.TIMETZ_OID
 
-    def dump(self, obj: time) -> bytes:
+    def dump(self, obj: time) -> Optional[Buffer]:
         us = obj.microsecond + 1_000_000 * (
             obj.second + 60 * (obj.minute + 60 * obj.hour)
         )
@@ -142,7 +142,7 @@ class _BaseDatetimeDumper(Dumper):
 
 
 class _BaseDatetimeTextDumper(_BaseDatetimeDumper):
-    def dump(self, obj: datetime) -> bytes:
+    def dump(self, obj: datetime) -> Optional[Buffer]:
         # NOTE: whatever the PostgreSQL DateStyle input format (DMY, MDY, YMD)
         # the YYYY-MM-DD is always understood correctly.
         return str(obj).encode()
@@ -166,7 +166,7 @@ class DatetimeBinaryDumper(_BaseDatetimeDumper):
     format = Format.BINARY
     oid = _oids.TIMESTAMPTZ_OID
 
-    def dump(self, obj: datetime) -> bytes:
+    def dump(self, obj: datetime) -> Optional[Buffer]:
         delta = obj - _pg_datetimetz_epoch
         micros = delta.microseconds + 1_000_000 * (86_400 * delta.days + delta.seconds)
         return pack_int8(micros)
@@ -182,7 +182,7 @@ class DatetimeNoTzBinaryDumper(_BaseDatetimeDumper):
     format = Format.BINARY
     oid = _oids.TIMESTAMP_OID
 
-    def dump(self, obj: datetime) -> bytes:
+    def dump(self, obj: datetime) -> Optional[Buffer]:
         delta = obj - _pg_datetime_epoch
         micros = delta.microseconds + 1_000_000 * (86_400 * delta.days + delta.seconds)
         return pack_int8(micros)
@@ -198,7 +198,7 @@ class TimedeltaDumper(Dumper):
         else:
             self._dump_method = self._dump_any
 
-    def dump(self, obj: timedelta) -> bytes:
+    def dump(self, obj: timedelta) -> Optional[Buffer]:
         return self._dump_method(self, obj)
 
     @staticmethod
@@ -222,7 +222,7 @@ class TimedeltaBinaryDumper(Dumper):
     format = Format.BINARY
     oid = _oids.INTERVAL_OID
 
-    def dump(self, obj: timedelta) -> bytes:
+    def dump(self, obj: timedelta) -> Optional[Buffer]:
         micros = 1_000_000 * obj.seconds + obj.microseconds
         return _pack_interval(micros, obj.days, 0)
 
index 6e20dd3cef764aab3fecc893c517ac0fabc7e230..04616419d19286f3d23acb0a22832e1a646c49f0 100644 (file)
@@ -98,7 +98,7 @@ class _BaseEnumDumper(Dumper, Generic[E]):
     enum: Type[E]
     _dump_map: EnumDumpMap[E]
 
-    def dump(self, value: E) -> Buffer:
+    def dump(self, value: E) -> Optional[Buffer]:
         return self._dump_map[value]
 
 
@@ -111,7 +111,7 @@ class EnumDumper(Dumper):
         super().__init__(cls, context)
         self._encoding = conn_encoding(self.connection)
 
-    def dump(self, value: E) -> Buffer:
+    def dump(self, value: E) -> Optional[Buffer]:
         return value.name.encode(self._encoding)
 
 
index df571e43c119a0fc69bde30b42ca15c556e7e456..0f5651e71e826af53b73eba11bb411b6b419826d 100644 (file)
@@ -142,7 +142,7 @@ class _JsonDumper(Dumper):
         super().__init__(cls, context)
         self.dumps = self.__class__._dumps
 
-    def dump(self, obj: Any) -> bytes:
+    def dump(self, obj: Any) -> Optional[Buffer]:
         if isinstance(obj, _JsonWrapper):
             dumps = obj.dumps or self.dumps
             obj = obj.obj
@@ -171,7 +171,7 @@ class JsonbBinaryDumper(_JsonDumper):
     format = Format.BINARY
     oid = _oids.JSONB_OID
 
-    def dump(self, obj: Any) -> bytes:
+    def dump(self, obj: Any) -> Optional[Buffer]:
         return b"\x01" + super().dump(obj)
 
 
index 51f61d1a79445a9a7f0d33a036238afe63c15ec5..d37681009f60484588ed81a7d4e9db04e37c84b1 100644 (file)
@@ -256,7 +256,7 @@ class MultirangeDumper(BaseMultirangeDumper):
     The dumper can upgrade to one specific for a different range type.
     """
 
-    def dump(self, obj: Multirange[Any]) -> Buffer:
+    def dump(self, obj: Multirange[Any]) -> Optional[Buffer]:
         if not obj:
             return b"{}"
 
@@ -277,7 +277,7 @@ class MultirangeDumper(BaseMultirangeDumper):
 class MultirangeBinaryDumper(BaseMultirangeDumper):
     format = Format.BINARY
 
-    def dump(self, obj: Multirange[Any]) -> Buffer:
+    def dump(self, obj: Multirange[Any]) -> Optional[Buffer]:
         item = self._get_item(obj)
         if item is not None:
             dump = self._tx.get_dumper(item, self._adapt_format).dump
index 76522dcbbd3917a6d09d79e05d88554567e4d136..b8fc992040a6a5f6bf9f796d4ec2eb060865bffe 100644 (file)
@@ -52,14 +52,14 @@ class _LazyIpaddress:
 class InterfaceDumper(Dumper):
     oid = _oids.INET_OID
 
-    def dump(self, obj: Interface) -> bytes:
+    def dump(self, obj: Interface) -> Optional[Buffer]:
         return str(obj).encode()
 
 
 class NetworkDumper(Dumper):
     oid = _oids.CIDR_OID
 
-    def dump(self, obj: Network) -> bytes:
+    def dump(self, obj: Network) -> Optional[Buffer]:
         return str(obj).encode()
 
 
@@ -69,7 +69,7 @@ class _AIBinaryDumper(Dumper):
 
 
 class AddressBinaryDumper(_AIBinaryDumper):
-    def dump(self, obj: Address) -> bytes:
+    def dump(self, obj: Address) -> Optional[Buffer]:
         packed = obj.packed
         family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6
         head = bytes((family, obj.max_prefixlen, 0, len(packed)))
@@ -77,7 +77,7 @@ class AddressBinaryDumper(_AIBinaryDumper):
 
 
 class InterfaceBinaryDumper(_AIBinaryDumper):
-    def dump(self, obj: Interface) -> bytes:
+    def dump(self, obj: Interface) -> Optional[Buffer]:
         packed = obj.packed
         family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6
         head = bytes((family, obj.network.prefixlen, 0, len(packed)))
@@ -94,7 +94,7 @@ class InetBinaryDumper(_AIBinaryDumper, _LazyIpaddress):
         super().__init__(cls, context)
         self._ensure_module()
 
-    def dump(self, obj: Union[Address, Interface]) -> bytes:
+    def dump(self, obj: Union[Address, Interface]) -> Optional[Buffer]:
         packed = obj.packed
         family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6
         if isinstance(obj, (IPv4Interface, IPv6Interface)):
@@ -110,7 +110,7 @@ class NetworkBinaryDumper(Dumper):
     format = Format.BINARY
     oid = _oids.CIDR_OID
 
-    def dump(self, obj: Network) -> bytes:
+    def dump(self, obj: Network) -> Optional[Buffer]:
         packed = obj.network_address.packed
         family = PGSQL_AF_INET if obj.version == 4 else PGSQL_AF_INET6
         head = bytes((family, obj.prefixlen, 1, len(packed)))
index 2ab857ccabf0ead816ff115b206a3a56dd8ce4cf..b65f6315b2ffee1d6a117a64f38efe4fd718a96c 100644 (file)
@@ -4,7 +4,9 @@ Adapters for None.
 
 # Copyright (C) 2020 The Psycopg Team
 
-from ..abc import AdaptContext, NoneType
+from typing import Optional
+
+from ..abc import AdaptContext, NoneType, Buffer
 from ..adapt import Dumper
 
 
@@ -14,10 +16,10 @@ class NoneDumper(Dumper):
     quote(), so it can be used in sql composition.
     """
 
-    def dump(self, obj: None) -> bytes:
+    def dump(self, obj: None) -> Optional[Buffer]:
         raise NotImplementedError("NULL is passed to Postgres in other ways")
 
-    def quote(self, obj: None) -> bytes:
+    def quote(self, obj: None) -> Buffer:
         return b"NULL"
 
 
index 1817740fd6bd504af9b70661c5f7430d899652a5..4bfef8e2fbe24db40267077234c52a52b28371f7 100644 (file)
@@ -38,16 +38,18 @@ if TYPE_CHECKING:
 
 
 class _IntDumper(Dumper):
-    def dump(self, obj: Any) -> Buffer:
+    def dump(self, obj: Any) -> Optional[Buffer]:
         return str(obj).encode()
 
     def quote(self, obj: Any) -> Buffer:
         value = self.dump(obj)
+        if value is None:
+            return b"NULL"
         return value if obj >= 0 else b" " + value
 
 
 class _IntOrSubclassDumper(_IntDumper):
-    def dump(self, obj: Any) -> Buffer:
+    def dump(self, obj: Any) -> Optional[Buffer]:
         t = type(obj)
         # Convert to int in order to dump IntEnum or numpy.integer correctly
         if t is not int:
@@ -59,12 +61,16 @@ class _IntOrSubclassDumper(_IntDumper):
 class _SpecialValuesDumper(Dumper):
     _special: Dict[bytes, bytes] = {}
 
-    def dump(self, obj: Any) -> bytes:
+    def dump(self, obj: Any) -> Optional[Buffer]:
         return str(obj).encode()
 
-    def quote(self, obj: Any) -> bytes:
+    def quote(self, obj: Any) -> Buffer:
         value = self.dump(obj)
 
+        if value is None:
+            return b"NULL"
+        if not isinstance(value, bytes):
+            value = bytes(value)
         if value in self._special:
             return self._special[value]
 
@@ -89,21 +95,21 @@ class FloatBinaryDumper(Dumper):
     format = Format.BINARY
     oid = _oids.FLOAT8_OID
 
-    def dump(self, obj: float) -> bytes:
+    def dump(self, obj: float) -> Optional[Buffer]:
         return pack_float8(obj)
 
 
 class Float4BinaryDumper(FloatBinaryDumper):
     oid = _oids.FLOAT4_OID
 
-    def dump(self, obj: float) -> bytes:
+    def dump(self, obj: float) -> Optional[Buffer]:
         return pack_float4(obj)
 
 
 class DecimalDumper(_SpecialValuesDumper):
     oid = _oids.NUMERIC_OID
 
-    def dump(self, obj: Decimal) -> bytes:
+    def dump(self, obj: Decimal) -> Optional[Buffer]:
         return dump_decimal_to_text(obj)
 
     _special = {
@@ -134,7 +140,7 @@ class OidDumper(_IntOrSubclassDumper):
 
 
 class IntDumper(Dumper):
-    def dump(self, obj: Any) -> bytes:
+    def dump(self, obj: Any) -> Optional[Buffer]:
         raise TypeError(
             f"{type(self).__name__} is a dispatcher to other dumpers:"
             " dump() is not supposed to be called"
@@ -164,21 +170,21 @@ class IntDumper(Dumper):
 class Int2BinaryDumper(Int2Dumper):
     format = Format.BINARY
 
-    def dump(self, obj: int) -> bytes:
+    def dump(self, obj: int) -> Optional[Buffer]:
         return pack_int2(obj)
 
 
 class Int4BinaryDumper(Int4Dumper):
     format = Format.BINARY
 
-    def dump(self, obj: int) -> bytes:
+    def dump(self, obj: int) -> Optional[Buffer]:
         return pack_int4(obj)
 
 
 class Int8BinaryDumper(Int8Dumper):
     format = Format.BINARY
 
-    def dump(self, obj: int) -> bytes:
+    def dump(self, obj: int) -> Optional[Buffer]:
         return pack_int8(obj)
 
 
@@ -190,14 +196,14 @@ BIT_PER_PGDIGIT = log(2) / log(10_000)
 class IntNumericBinaryDumper(IntNumericDumper):
     format = Format.BINARY
 
-    def dump(self, obj: int) -> Buffer:
+    def dump(self, obj: int) -> Optional[Buffer]:
         return dump_int_to_numeric_binary(obj)
 
 
 class OidBinaryDumper(OidDumper):
     format = Format.BINARY
 
-    def dump(self, obj: int) -> bytes:
+    def dump(self, obj: int) -> Optional[Buffer]:
         return pack_uint4(obj)
 
 
@@ -350,7 +356,7 @@ class DecimalBinaryDumper(Dumper):
     format = Format.BINARY
     oid = _oids.NUMERIC_OID
 
-    def dump(self, obj: Decimal) -> Buffer:
+    def dump(self, obj: Decimal) -> Optional[Buffer]:
         return dump_decimal_to_numeric_binary(obj)
 
 
@@ -379,11 +385,13 @@ class _MixedNumericDumper(Dumper, ABC):
                 _MixedNumericDumper.int_classes = int
 
     @abstractmethod
-    def dump(self, obj: Union[Decimal, int, "numpy.integer[Any]"]) -> Buffer: ...
+    def dump(
+        self, obj: Union[Decimal, int, "numpy.integer[Any]"]
+    ) -> Optional[Buffer]: ...
 
 
 class NumericDumper(_MixedNumericDumper):
-    def dump(self, obj: Union[Decimal, int, "numpy.integer[Any]"]) -> Buffer:
+    def dump(self, obj: Union[Decimal, int, "numpy.integer[Any]"]) -> Optional[Buffer]:
         if isinstance(obj, self.int_classes):
             return str(obj).encode()
         elif isinstance(obj, Decimal):
@@ -397,7 +405,7 @@ class NumericDumper(_MixedNumericDumper):
 class NumericBinaryDumper(_MixedNumericDumper):
     format = Format.BINARY
 
-    def dump(self, obj: Union[Decimal, int, "numpy.integer[Any]"]) -> Buffer:
+    def dump(self, obj: Union[Decimal, int, "numpy.integer[Any]"]) -> Optional[Buffer]:
         if type(obj) is int:
             return dump_int_to_numeric_binary(obj)
         elif isinstance(obj, Decimal):
index e320e42037193cdbbd112ef99a7d20de8a7a016a..c3d472621dca90804613b24098258360fc1014c8 100644 (file)
@@ -354,7 +354,7 @@ class RangeDumper(BaseRangeDumper):
     The dumper can upgrade to one specific for a different range type.
     """
 
-    def dump(self, obj: Range[Any]) -> Buffer:
+    def dump(self, obj: Range[Any]) -> Optional[Buffer]:
         item = self._get_item(obj)
         if item is not None:
             dump = self._tx.get_dumper(item, self._adapt_format).dump
@@ -399,7 +399,7 @@ _re_esc = re.compile(rb"([\\\"])")
 class RangeBinaryDumper(BaseRangeDumper):
     format = Format.BINARY
 
-    def dump(self, obj: Range[Any]) -> Buffer:
+    def dump(self, obj: Range[Any]) -> Optional[Buffer]:
         item = self._get_item(obj)
         if item is not None:
             dump = self._tx.get_dumper(item, self._adapt_format).dump
index e9387007af1f81a355ac97fb6cbe041068d8da54..5714454137e2759db3f9b33aac183587db66e3a4 100644 (file)
@@ -43,12 +43,12 @@ class GeometryLoader(Loader):
 class BaseGeometryBinaryDumper(Dumper):
     format = Format.BINARY
 
-    def dump(self, obj: "BaseGeometry") -> bytes:
+    def dump(self, obj: "BaseGeometry") -> Optional[Buffer]:
         return dumps(obj)  # type: ignore
 
 
 class BaseGeometryDumper(Dumper):
-    def dump(self, obj: "BaseGeometry") -> bytes:
+    def dump(self, obj: "BaseGeometry") -> Optional[Buffer]:
         return dumps(obj, hex=True).encode()  # type: ignore
 
 
index a0adc5650412d27b4d721b4bfeefb92ea723fcca..88548180aac385fc9f047d1a65528aedb54b6d3b 100644 (file)
@@ -33,7 +33,7 @@ class _StrBinaryDumper(_BaseStrDumper):
 
     format = Format.BINARY
 
-    def dump(self, obj: str) -> bytes:
+    def dump(self, obj: str) -> Optional[Buffer]:
         # the server will raise DataError subclass if the string contains 0x00
         return obj.encode(self._encoding)
 
@@ -45,7 +45,7 @@ class _StrDumper(_BaseStrDumper):
     Subclasses shall specify the oids of real types (text, varchar, name...).
     """
 
-    def dump(self, obj: str) -> bytes:
+    def dump(self, obj: str) -> Optional[Buffer]:
         if "\x00" in obj:
             raise DataError("PostgreSQL text fields cannot contain NUL (0x00) bytes")
         else:
@@ -132,11 +132,13 @@ class BytesDumper(Dumper):
         super().__init__(cls, context)
         self._esc = Escaping(self.connection.pgconn if self.connection else None)
 
-    def dump(self, obj: Buffer) -> Buffer:
+    def dump(self, obj: Buffer) -> Optional[Buffer]:
         return self._esc.escape_bytea(obj)
 
-    def quote(self, obj: Buffer) -> bytes:
+    def quote(self, obj: Buffer) -> Buffer:
         escaped = self.dump(obj)
+        if escaped is None:
+            return b"NULL"
 
         # We cannot use the base quoting because escape_bytea already returns
         # the quotes content. if scs is off it will escape the backslashes in
@@ -165,7 +167,7 @@ class BytesBinaryDumper(Dumper):
     format = Format.BINARY
     oid = _oids.BYTEA_OID
 
-    def dump(self, obj: Buffer) -> Buffer:
+    def dump(self, obj: Buffer) -> Optional[Buffer]:
         return obj
 
 
index 6e0d8156b9b2e44192d0bdb9e41738d92ac74a84..df4509ee595b4831bec2043e8d2ec2ab54d0c6b2 100644 (file)
@@ -21,14 +21,14 @@ UUID: Callable[..., "uuid.UUID"] = None  # type: ignore[assignment]
 class UUIDDumper(Dumper):
     oid = _oids.UUID_OID
 
-    def dump(self, obj: "uuid.UUID") -> bytes:
+    def dump(self, obj: "uuid.UUID") -> Optional[Buffer]:
         return obj.hex.encode()
 
 
 class UUIDBinaryDumper(UUIDDumper):
     format = Format.BINARY
 
-    def dump(self, obj: "uuid.UUID") -> bytes:
+    def dump(self, obj: "uuid.UUID") -> Optional[Buffer]:
         return obj.bytes
 
 
index cfd90741a989be52a84af19030728ff48fbf1052..e32ef19d1d104dfd6b1e1b21a5ee1993d3d127bf 100644 (file)
@@ -58,14 +58,14 @@ cdef class CDumper:
         """
         raise NotImplementedError()
 
-    def dump(self, obj):
+    def dump(self, obj) -> Optional[Buffer]:
         """Return the Postgres representation of *obj* as Python array of bytes"""
         cdef rv = PyByteArray_FromStringAndSize("", 0)
         cdef Py_ssize_t length = self.cdump(obj, rv, 0)
         PyByteArray_Resize(rv, length)
         return rv
 
-    def quote(self, obj):
+    def quote(self, obj) -> Buffer:
         cdef char *ptr
         cdef char *ptr_out
         cdef Py_ssize_t length
index 86cf88e948a5ab01b2546afc3a6a96cb02dc7952..5b3a7c387f30b132e1b6f30b6b0842fe92b0a014 100644 (file)
@@ -28,7 +28,7 @@ cdef class BoolDumper(CDumper):
 
         return 1
 
-    def quote(self, obj: bool) -> bytes:
+    def quote(self, obj: bool) -> Optional[Buffer]:
         if obj is True:
             return b"true"
         elif obj is False:
index f9580fd84a98a078b79e590283a8b4845090a1aa..f2f2c2ac1e5216c29cd1e128d9faf26abde124d4 100644 (file)
@@ -55,7 +55,7 @@ cdef class _IntDumper(CDumper):
     cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
         return dump_int_to_text(obj, rv, offset)
 
-    def quote(self, obj) -> bytearray:
+    def quote(self, obj) -> Optional[Buffer]:
         cdef Py_ssize_t length
 
         rv = PyByteArray_FromStringAndSize("", 0)
@@ -311,7 +311,7 @@ cdef class _FloatDumper(CDumper):
         PyMem_Free(out)
         return length
 
-    def quote(self, obj) -> bytes:
+    def quote(self, obj) -> Optional[Buffer]:
         value = bytes(self.dump(obj))
         cdef PyObject *ptr = PyDict_GetItem(_special_float, value)
         if ptr != NULL:
@@ -417,7 +417,7 @@ cdef class DecimalDumper(CDumper):
     cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
         return dump_decimal_to_text(obj, rv, offset)
 
-    def quote(self, obj) -> bytes:
+    def quote(self, obj) -> Optional[Buffer]:
         value = bytes(self.dump(obj))
         cdef PyObject *ptr = PyDict_GetItem(_special_decimal, value)
         if ptr != NULL:
index da18b015a25564f5282be789a8940433bb5205cb..170b17b7db36d5af83a6535702f6f19c559cddaa 100644 (file)
@@ -215,7 +215,7 @@ cdef class BytesDumper(CDumper):
         libpq.PQfreemem(out)
         return len_out
 
-    def quote(self, obj):
+    def quote(self, obj) -> Buffer:
         cdef size_t len_out
         cdef unsigned char *out
         cdef char *ptr
index 596b837f177f9dd303bbf3c05203a6a96112cb21..286e6e4408ab5ec29d6f5bcbb5f292eda0e23345 100644 (file)
@@ -308,7 +308,9 @@ def test_subclass_adapter(conn, format):
     class MyStrDumper(BaseDumper):
 
         def dump(self, obj):
-            return super().dump(obj) * 2
+            rv = super().dump(obj)
+            assert rv
+            return bytes(rv) * 2
 
     conn.adapters.register_dumper(str, MyStrDumper)
 
index 6d1d4956614ee0b1b048ae20b85760354b578eac..b06e880ae5ba771351716b3e28de6f02422fb053 100644 (file)
@@ -317,7 +317,9 @@ async def test_subclass_adapter(aconn, format):
 
     class MyStrDumper(BaseDumper):
         def dump(self, obj):
-            return super().dump(obj) * 2
+            rv = super().dump(obj)
+            assert rv
+            return bytes(rv) * 2
 
     aconn.adapters.register_dumper(str, MyStrDumper)