]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
lint: fix type hints with disabled bytearray/memoryview/bytes equivalence
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 4 Nov 2022 01:12:01 +0000 (02:12 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 4 Nov 2022 14:06:37 +0000 (15:06 +0100)
This changeset keeps the code compatible with the current Mypy 0.981,
while also passing most of the checks that Mypy 0.990 enforces when
bytes/bytearray/memoryview equivalence is disabled.

18 files changed:
psycopg/psycopg/_struct.py
psycopg/psycopg/_transform.py
psycopg/psycopg/abc.py
psycopg/psycopg/copy.py
psycopg/psycopg/generators.py
psycopg/psycopg/pq/abc.py
psycopg/psycopg/pq/pq_ctypes.py
psycopg/psycopg/types/array.py
psycopg/psycopg/types/composite.py
psycopg/psycopg/types/datetime.py
psycopg/psycopg/types/hstore.py
psycopg/psycopg/types/json.py
psycopg/psycopg/types/multirange.py
psycopg/psycopg/types/numeric.py
psycopg/psycopg/types/range.py
psycopg/psycopg/types/string.py
psycopg_c/psycopg_c/_psycopg.pyi
psycopg_c/psycopg_c/_psycopg/transform.pyx

index 5ccb7703b63530d906731bb973bd652d13ee42df..191c4013dbe101f12653de1fd5dc18a2d8dbcf4b 100644 (file)
@@ -12,9 +12,9 @@ from . import errors as e
 from ._compat import Protocol, TypeAlias
 
 PackInt: TypeAlias = Callable[[int], bytes]
-UnpackInt: TypeAlias = Callable[[bytes], Tuple[int]]
+UnpackInt: TypeAlias = Callable[[Buffer], Tuple[int]]
 PackFloat: TypeAlias = Callable[[float], bytes]
-UnpackFloat: TypeAlias = Callable[[bytes], Tuple[float]]
+UnpackFloat: TypeAlias = Callable[[Buffer], Tuple[float]]
 
 
 class UnpackLen(Protocol):
index 86cce1913988ca5f650ddff51e737f664e8c3c42..0e9dcaf7a2ce30fc1808a395b8eb6aff8dd1a69c 100644 (file)
@@ -194,7 +194,7 @@ class Transformer(AdaptContext):
 
         return out
 
-    def as_literal(self, obj: Any) -> Buffer:
+    def as_literal(self, obj: Any) -> bytes:
         dumper = self.get_dumper(obj, PY_TEXT)
         rv = dumper.quote(obj)
         # If the result is quoted, and the oid not unknown or text,
@@ -221,6 +221,8 @@ class Transformer(AdaptContext):
             if type_sql:
                 rv = b"%s::%s" % (rv, type_sql)
 
+        if not isinstance(rv, bytes):
+            rv = bytes(rv)
         return rv
 
     def get_dumper(self, obj: Any, format: PyFormat) -> "Dumper":
@@ -321,7 +323,7 @@ class Transformer(AdaptContext):
 
         return make_row(record)
 
-    def load_sequence(self, record: Sequence[Optional[bytes]]) -> Tuple[Any, ...]:
+    def load_sequence(self, record: Sequence[Optional[Buffer]]) -> Tuple[Any, ...]:
         if len(self._row_loaders) != len(record):
             raise e.ProgrammingError(
                 f"cannot load sequence of {len(record)} items:"
index 1edde3db3cab72e2bb4fe82bc2443226cc52b045..570a22363247f470101d6fa017ad4cf8dfdaebaf 100644 (file)
@@ -51,8 +51,8 @@ PQGen: TypeAlias = Generator["Wait", "Ready", RV]
 
 # Adaptation types
 
-DumpFunc: TypeAlias = Callable[[Any], bytes]
-LoadFunc: TypeAlias = Callable[[bytes], Any]
+DumpFunc: TypeAlias = Callable[[Any], Buffer]
+LoadFunc: TypeAlias = Callable[[Buffer], Any]
 
 
 class AdaptContext(Protocol):
@@ -238,7 +238,7 @@ class Transformer(Protocol):
     ) -> Sequence[Optional[Buffer]]:
         ...
 
-    def as_literal(self, obj: Any) -> Buffer:
+    def as_literal(self, obj: Any) -> bytes:
         ...
 
     def get_dumper(self, obj: Any, format: PyFormat) -> Dumper:
@@ -250,7 +250,7 @@ class Transformer(Protocol):
     def load_row(self, row: int, make_row: "RowMaker[Row]") -> Optional["Row"]:
         ...
 
-    def load_sequence(self, record: Sequence[Optional[bytes]]) -> Tuple[Any, ...]:
+    def load_sequence(self, record: Sequence[Optional[Buffer]]) -> Tuple[Any, ...]:
         ...
 
     def get_loader(self, oid: int, format: pq.Format) -> Loader:
index ff36431f90d04c7b8afb1365c494a0e01f235efe..1313a9a71341857723a01b899c7c35e997b14a9d 100644 (file)
@@ -392,7 +392,7 @@ class QueuedLibpqDriver(LibpqWriter):
     def __init__(self, cursor: "Cursor[Any]"):
         super().__init__(cursor)
 
-        self._queue: queue.Queue[bytes] = queue.Queue(maxsize=QUEUE_SIZE)
+        self._queue: queue.Queue[Buffer] = queue.Queue(maxsize=QUEUE_SIZE)
         self._worker: Optional[threading.Thread] = None
         self._worker_error: Optional[BaseException] = None
 
@@ -599,7 +599,7 @@ class AsyncQueuedLibpqWriter(AsyncLibpqWriter):
     def __init__(self, cursor: "AsyncCursor[Any]"):
         super().__init__(cursor)
 
-        self._queue: asyncio.Queue[bytes] = asyncio.Queue(maxsize=QUEUE_SIZE)
+        self._queue: asyncio.Queue[Buffer] = asyncio.Queue(maxsize=QUEUE_SIZE)
         self._worker: Optional[asyncio.Future[None]] = None
 
     async def worker(self) -> None:
@@ -652,19 +652,19 @@ class Formatter(ABC):
         self._row_mode = False  # true if the user is using write_row()
 
     @abstractmethod
-    def parse_row(self, data: bytes) -> Optional[Tuple[Any, ...]]:
+    def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]:
         ...
 
     @abstractmethod
-    def write(self, buffer: Union[Buffer, str]) -> bytes:
+    def write(self, buffer: Union[Buffer, str]) -> Buffer:
         ...
 
     @abstractmethod
-    def write_row(self, row: Sequence[Any]) -> bytes:
+    def write_row(self, row: Sequence[Any]) -> Buffer:
         ...
 
     @abstractmethod
-    def end(self) -> bytes:
+    def end(self) -> Buffer:
         ...
 
 
@@ -676,7 +676,7 @@ class TextFormatter(Formatter):
         super().__init__(transformer)
         self._encoding = encoding
 
-    def parse_row(self, data: bytes) -> Optional[Tuple[Any, ...]]:
+    def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]:
         if data:
             return parse_row_text(data, self.transformer)
         else:
@@ -687,7 +687,7 @@ class TextFormatter(Formatter):
         self._signature_sent = True
         return data
 
-    def write_row(self, row: Sequence[Any]) -> bytes:
+    def write_row(self, row: Sequence[Any]) -> Buffer:
         # Note down that we are writing in row mode: it means we will have
         # to take care of the end-of-copy marker too
         self._row_mode = True
@@ -699,7 +699,7 @@ class TextFormatter(Formatter):
         else:
             return b""
 
-    def end(self) -> bytes:
+    def end(self) -> Buffer:
         buffer, self._write_buffer = self._write_buffer, bytearray()
         return buffer
 
@@ -721,7 +721,7 @@ class BinaryFormatter(Formatter):
         super().__init__(transformer)
         self._signature_sent = False
 
-    def parse_row(self, data: bytes) -> Optional[Tuple[Any, ...]]:
+    def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]:
         if not self._signature_sent:
             if data[: len(_binary_signature)] != _binary_signature:
                 raise e.DataError(
@@ -740,7 +740,7 @@ class BinaryFormatter(Formatter):
         self._signature_sent = True
         return data
 
-    def write_row(self, row: Sequence[Any]) -> bytes:
+    def write_row(self, row: Sequence[Any]) -> Buffer:
         # Note down that we are writing in row mode: it means we will have
         # to take care of the end-of-copy marker too
         self._row_mode = True
@@ -756,7 +756,7 @@ class BinaryFormatter(Formatter):
         else:
             return b""
 
-    def end(self) -> bytes:
+    def end(self) -> Buffer:
         # If we have sent no data we need to send the signature
         # and the trailer
         if not self._signature_sent:
@@ -828,17 +828,17 @@ def _format_row_binary(
     return out
 
 
-def _parse_row_text(data: bytes, tx: Transformer) -> Tuple[Any, ...]:
+def _parse_row_text(data: Buffer, tx: Transformer) -> Tuple[Any, ...]:
     if not isinstance(data, bytes):
         data = bytes(data)
-    fields = data.split(b"\t")
-    fields[-1] = fields[-1][:-1]  # drop \n
+    fields = data.split(b"\t")  # type: ignore
+    fields[-1] = fields[-1][:-1]  # type: ignore  # drop \n
     row = [None if f == b"\\N" else _load_re.sub(_load_sub, f) for f in fields]
     return tx.load_sequence(row)
 
 
-def _parse_row_binary(data: bytes, tx: Transformer) -> Tuple[Any, ...]:
-    row: List[Optional[bytes]] = []
+def _parse_row_binary(data: Buffer, tx: Transformer) -> Tuple[Any, ...]:
+    row: List[Optional[Buffer]] = []
     nfields = _unpack_int2(data, 0)[0]
     pos = 2
     for i in range(nfields):
index 96a27d56fdccbd9f4395a3de420eb70d38238690..37cba131acf522ec425d7b3ffb2c3811264dcb5b 100644 (file)
@@ -20,7 +20,7 @@ from typing import List, Optional, Union
 
 from . import pq
 from . import errors as e
-from .abc import PipelineCommand, PQGen, PQGenConn
+from .abc import Buffer, PipelineCommand, PQGen, PQGenConn
 from .pq.abc import PGconn, PGresult
 from .waiting import Wait, Ready
 from ._compat import Deque
@@ -271,7 +271,7 @@ def copy_from(pgconn: PGconn) -> PQGen[Union[memoryview, PGresult]]:
     return result
 
 
-def copy_to(pgconn: PGconn, buffer: bytes) -> PQGen[None]:
+def copy_to(pgconn: PGconn, buffer: Buffer) -> PQGen[None]:
     # Retry enqueuing data until successful.
     #
     # WARNING! This can cause an infinite loop if the buffer is too large. (see
index c59e43d593486ff50ec1396fe557a8acf17e44a9..9ee21c288437352cc1262779cf5a7ee2fadc5f31 100644 (file)
@@ -133,7 +133,7 @@ class PGconn(Protocol):
     def exec_params(
         self,
         command: bytes,
-        param_values: Optional[Sequence[Optional[bytes]]],
+        param_values: Optional[Sequence[Optional[Buffer]]],
         param_types: Optional[Sequence[int]] = None,
         param_formats: Optional[Sequence[int]] = None,
         result_format: int = Format.TEXT,
@@ -143,7 +143,7 @@ class PGconn(Protocol):
     def send_query_params(
         self,
         command: bytes,
-        param_values: Optional[Sequence[Optional[bytes]]],
+        param_values: Optional[Sequence[Optional[Buffer]]],
         param_types: Optional[Sequence[int]] = None,
         param_formats: Optional[Sequence[int]] = None,
         result_format: int = Format.TEXT,
@@ -161,7 +161,7 @@ class PGconn(Protocol):
     def send_query_prepared(
         self,
         name: bytes,
-        param_values: Optional[Sequence[Optional[bytes]]],
+        param_values: Optional[Sequence[Optional[Buffer]]],
         param_formats: Optional[Sequence[int]] = None,
         result_format: int = Format.TEXT,
     ) -> None:
@@ -178,7 +178,7 @@ class PGconn(Protocol):
     def exec_prepared(
         self,
         name: bytes,
-        param_values: Optional[Sequence[bytes]],
+        param_values: Optional[Sequence[Buffer]],
         param_formats: Optional[Sequence[int]] = None,
         result_format: int = 0,
     ) -> "PGresult":
@@ -225,7 +225,7 @@ class PGconn(Protocol):
     def notifies(self) -> Optional["PGnotify"]:
         ...
 
-    def put_copy_data(self, buffer: bytes) -> int:
+    def put_copy_data(self, buffer: Buffer) -> int:
         ...
 
     def put_copy_end(self, error: Optional[bytes] = None) -> int:
@@ -380,5 +380,5 @@ class Escaping(Protocol):
     def escape_bytea(self, data: Buffer) -> bytes:
         ...
 
-    def unescape_bytea(self, data: bytes) -> bytes:
+    def unescape_bytea(self, data: Buffer) -> bytes:
         ...
index dccb0afbd6c0ba5a8c79f1ac1e0e260286ca9a7c..51462fea85acb2d5270a8bf32be01b5b8ea15c04 100644 (file)
@@ -275,7 +275,7 @@ class PGconn:
     def exec_params(
         self,
         command: bytes,
-        param_values: Optional[Sequence[Optional[bytes]]],
+        param_values: Optional[Sequence[Optional["abc.Buffer"]]],
         param_types: Optional[Sequence[int]] = None,
         param_formats: Optional[Sequence[int]] = None,
         result_format: int = Format.TEXT,
@@ -292,7 +292,7 @@ class PGconn:
     def send_query_params(
         self,
         command: bytes,
-        param_values: Optional[Sequence[Optional[bytes]]],
+        param_values: Optional[Sequence[Optional["abc.Buffer"]]],
         param_types: Optional[Sequence[int]] = None,
         param_formats: Optional[Sequence[int]] = None,
         result_format: int = Format.TEXT,
@@ -329,7 +329,7 @@ class PGconn:
     def send_query_prepared(
         self,
         name: bytes,
-        param_values: Optional[Sequence[Optional[bytes]]],
+        param_values: Optional[Sequence[Optional["abc.Buffer"]]],
         param_formats: Optional[Sequence[int]] = None,
         result_format: int = Format.TEXT,
     ) -> None:
@@ -349,7 +349,7 @@ class PGconn:
     def _query_params_args(
         self,
         command: bytes,
-        param_values: Optional[Sequence[Optional[bytes]]],
+        param_values: Optional[Sequence[Optional["abc.Buffer"]]],
         param_types: Optional[Sequence[int]] = None,
         param_formats: Optional[Sequence[int]] = None,
         result_format: int = Format.TEXT,
@@ -364,7 +364,6 @@ class PGconn:
             aparams = (c_char_p * nparams)(
                 *(
                     # convert bytearray/memoryview to bytes
-                    # TODO: avoid copy, at least in the C implementation.
                     b
                     if b is None or isinstance(b, bytes)
                     else bytes(b)  # type: ignore[arg-type]
@@ -436,7 +435,7 @@ class PGconn:
     def exec_prepared(
         self,
         name: bytes,
-        param_values: Optional[Sequence[bytes]],
+        param_values: Optional[Sequence["abc.Buffer"]],
         param_formats: Optional[Sequence[int]] = None,
         result_format: int = 0,
     ) -> "PGresult":
@@ -447,7 +446,13 @@ class PGconn:
         alenghts: Optional[Array[c_int]]
         if param_values:
             nparams = len(param_values)
-            aparams = (c_char_p * nparams)(*param_values)
+            aparams = (c_char_p * nparams)(
+                *(
+                    # convert bytearray/memoryview to bytes
+                    b if b is None or isinstance(b, bytes) else bytes(b)
+                    for b in param_values
+                )
+            )
             alenghts = (c_int * nparams)(*(len(p) if p else 0 for p in param_values))
         else:
             nparams = 0
@@ -1050,13 +1055,15 @@ class Escaping:
         impl.PQfreemem(out)
         return rv
 
-    def unescape_bytea(self, data: bytes) -> bytes:
+    def unescape_bytea(self, data: "abc.Buffer") -> bytes:
         # not needed, but let's keep it symmetric with the escaping:
         # if a connection is passed in, it must be valid.
         if self.conn:
             self.conn._ensure_pgconn()
 
         len_out = c_size_t()
+        if not isinstance(data, bytes):
+            data = bytes(data)
         out = impl.PQunescapeBytea(
             data,
             byref(t_cast(c_ulong, len_out)),  # type: ignore[arg-type]
index 5ecbac1fa58678d3347a80b5c75e494ad6668888..566f253c8bb7d6257038630025c5833962f9ef96 100644 (file)
@@ -21,10 +21,10 @@ from .._typeinfo import TypeInfo
 
 _struct_head = struct.Struct("!III")  # ndims, hasnull, elem oid
 _pack_head = cast(Callable[[int, int, int], bytes], _struct_head.pack)
-_unpack_head = cast(Callable[[bytes], Tuple[int, int, int]], _struct_head.unpack_from)
+_unpack_head = cast(Callable[[Buffer], Tuple[int, int, int]], _struct_head.unpack_from)
 _struct_dim = struct.Struct("!II")  # dim, lower bound
 _pack_dim = cast(Callable[[int, int], bytes], _struct_dim.pack)
-_unpack_dim = cast(Callable[[bytes, int], Tuple[int, int]], _struct_dim.unpack_from)
+_unpack_dim = cast(Callable[[Buffer, int], Tuple[int, int]], _struct_dim.unpack_from)
 
 TEXT_ARRAY_OID = postgres.types["text"].array_oid
 
@@ -153,7 +153,7 @@ class ListDumper(BaseListDumper):
     _re_esc = re.compile(rb'(["\\])')
 
     def dump(self, obj: List[Any]) -> bytes:
-        tokens: List[bytes] = []
+        tokens: List[Buffer] = []
         needs_quotes = _get_needs_quotes_regexp(self.delimiter).search
 
         def dump_list(obj: List[Any]) -> None:
@@ -249,7 +249,7 @@ class ListBinaryDumper(BaseListDumper):
         if not obj:
             return _pack_head(0, 0, sub_oid)
 
-        data: List[bytes] = [b"", b""]  # placeholders to avoid a resize
+        data: List[Buffer] = [b"", b""]  # placeholders to avoid a resize
         dims: List[int] = []
         hasnull = 0
 
index 36660e324535a454a75bdb4e9fbf3a80ee6c54df..1c609c3079ec081a65d9bde774ff5f48684e5019 100644 (file)
@@ -22,7 +22,7 @@ from .._encodings import _as_python_identifier
 _struct_oidlen = struct.Struct("!Ii")
 _pack_oidlen = cast(Callable[[int, int], bytes], _struct_oidlen.pack)
 _unpack_oidlen = cast(
-    Callable[[bytes, int], Tuple[int, int]], _struct_oidlen.unpack_from
+    Callable[[Buffer, int], Tuple[int, int]], _struct_oidlen.unpack_from
 )
 
 
@@ -33,7 +33,7 @@ class SequenceDumper(RecursiveDumper):
         if not obj:
             return start + end
 
-        parts = [start]
+        parts: List[Buffer] = [start]
 
         for item in obj:
             if item is None:
@@ -100,7 +100,7 @@ class BaseCompositeLoader(Loader):
         super().__init__(oid, context)
         self._tx = Transformer(context)
 
-    def _parse_record(self, data: bytes) -> Iterator[Optional[bytes]]:
+    def _parse_record(self, data: Buffer) -> Iterator[Optional[bytes]]:
         """
         Split a non-empty representation of a composite type into components.
 
@@ -163,7 +163,7 @@ class RecordBinaryLoader(Loader):
             )
         )
 
-    def _walk_record(self, data: bytes) -> Iterator[Tuple[int, int, int]]:
+    def _walk_record(self, data: Buffer) -> Iterator[Tuple[int, int, int]]:
         """
         Yield a sequence of (oid, offset, length) for the content of the record
         """
@@ -174,7 +174,7 @@ class RecordBinaryLoader(Loader):
             yield oid, i + 8, length
             i += (8 + length) if length > 0 else 8
 
-    def _config_types(self, data: bytes) -> None:
+    def _config_types(self, data: Buffer) -> None:
         oids = [r[0] for r in self._walk_record(data)]
         self._tx.set_loader_types(oids, self.format)
 
@@ -197,7 +197,7 @@ class CompositeLoader(RecordLoader):
             *self._tx.load_sequence(tuple(self._parse_record(data[1:-1])))
         )
 
-    def _config_types(self, data: bytes) -> None:
+    def _config_types(self, data: Buffer) -> None:
         self._tx.set_loader_types(self.fields_types, self.format)
 
 
index 763a2340a6c6e2acfd8b2d9eea17e7641982b18a..f0dfe83eb0d360c8e4fe6908250e49475016900b 100644 (file)
@@ -22,12 +22,12 @@ if TYPE_CHECKING:
 
 _struct_timetz = struct.Struct("!qi")  # microseconds, sec tz offset
 _pack_timetz = cast(Callable[[int, int], bytes], _struct_timetz.pack)
-_unpack_timetz = cast(Callable[[bytes], Tuple[int, int]], _struct_timetz.unpack)
+_unpack_timetz = cast(Callable[[Buffer], Tuple[int, int]], _struct_timetz.unpack)
 
 _struct_interval = struct.Struct("!qii")  # microseconds, days, months
 _pack_interval = cast(Callable[[int, int, int], bytes], _struct_interval.pack)
 _unpack_interval = cast(
-    Callable[[bytes], Tuple[int, int, int]], _struct_interval.unpack
+    Callable[[Buffer], Tuple[int, int, int]], _struct_interval.unpack
 )
 
 utc = timezone.utc
index ed47d0d2dd154f0d464747df0637bb470bf24b98..f1467db167df7c187f9e5c13b2f2b9468699b7e1 100644 (file)
@@ -39,7 +39,7 @@ Hstore: TypeAlias = Dict[str, Optional[str]]
 
 
 class BaseHstoreDumper(RecursiveDumper):
-    def dump(self, obj: Hstore) -> bytes:
+    def dump(self, obj: Hstore) -> Buffer:
         if not obj:
             return b""
 
index 9d090ef237f5ca050d1cb9686f5e1d6416309ddc..a80e0e4e35b01aec145681ffca279c20d9f4618b 100644 (file)
@@ -15,7 +15,7 @@ from ..adapt import Buffer, Dumper, Loader, PyFormat, AdaptersMap
 from ..errors import DataError
 
 JsonDumpsFunction = Callable[[Any], str]
-JsonLoadsFunction = Callable[[Union[str, bytes, bytearray]], Any]
+JsonLoadsFunction = Callable[[Union[str, bytes]], Any]
 
 
 def set_json_dumps(
@@ -170,7 +170,7 @@ class _JsonLoader(Loader):
 
     def load(self, data: Buffer) -> Any:
         # json.loads() cannot work on memoryview.
-        if isinstance(data, memoryview):
+        if not isinstance(data, bytes):
             data = bytes(data)
         return self.loads(data)
 
@@ -195,7 +195,7 @@ class JsonbBinaryLoader(_JsonLoader):
         if data and data[0] != 1:
             raise DataError("unknown jsonb binary format: {data[0]}")
         data = data[1:]
-        if isinstance(data, memoryview):
+        if not isinstance(data, bytes):
             data = bytes(data)
         return self.loads(data)
 
index 5846edcfa8c20f5c5adea04292e19420d943de74..7b9f545706af0d8bb0f4626fef1ef6c20a9ea728 100644 (file)
@@ -224,7 +224,7 @@ class MultirangeDumper(BaseMultirangeDumper):
         else:
             dump = fail_dump
 
-        out = [b"{"]
+        out: List[Buffer] = [b"{"]
         for r in obj:
             out.append(dump_range_text(r, dump))
             out.append(b",")
@@ -243,7 +243,7 @@ class MultirangeBinaryDumper(BaseMultirangeDumper):
         else:
             dump = fail_dump
 
-        out = [pack_len(len(obj))]
+        out: List[Buffer] = [pack_len(len(obj))]
         for r in obj:
             data = dump_range_binary(r, dump)
             out.append(pack_len(len(data)))
index 48e096fa65eb98b83c0038d4c63ef846a1a4c132..1bd9329f5947bb64b5a93514f0af4588778bd11a 100644 (file)
@@ -32,7 +32,7 @@ from .._wrappers import (
 
 
 class _IntDumper(Dumper):
-    def dump(self, obj: Any) -> bytes:
+    def dump(self, obj: Any) -> Buffer:
         t = type(obj)
         if t is not int:
             # Convert to int in order to dump IntEnum correctly
@@ -43,7 +43,7 @@ class _IntDumper(Dumper):
 
         return str(obj).encode()
 
-    def quote(self, obj: Any) -> bytes:
+    def quote(self, obj: Any) -> Buffer:
         value = self.dump(obj)
         return value if obj >= 0 else b" " + value
 
@@ -321,7 +321,7 @@ for i in range(DefaultContext.prec):
     _contexts[i] = DefaultContext
 
 _unpack_numeric_head = cast(
-    Callable[[bytes], Tuple[int, int, int, int]],
+    Callable[[Buffer], Tuple[int, int, int, int]],
     struct.Struct("!HhHH").unpack_from,
 )
 _pack_numeric_head = cast(
index fb4c6e8082111105a0c6ad92a6197d359e8e51e4..ade9b37cee818ffbb95076f525e39375fe12f8ee 100644 (file)
@@ -5,7 +5,7 @@ Support for range types adaptation.
 # Copyright (C) 2020 The Psycopg Team
 
 import re
-from typing import Any, Callable, Dict, Generic, Optional, TypeVar, Type, Tuple
+from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar, Type, Tuple
 from typing import cast
 from decimal import Decimal
 from datetime import date, datetime
@@ -329,7 +329,7 @@ def dump_range_text(obj: Range[Any], dump: Callable[[Any], Buffer]) -> Buffer:
     if obj.isempty:
         return b"empty"
 
-    parts = [b"[" if obj.lower_inc else b"("]
+    parts: List[Buffer] = [b"[" if obj.lower_inc else b"("]
 
     def dump_item(item: Any) -> Buffer:
         ad = dump(item)
index 5cd0d0e072c5275c8fc4a887bbecc7d7772bb5ea..c15c3926a6401ec2f9dfe293fe5e80dfb6cd5364 100644 (file)
@@ -78,11 +78,12 @@ class TextLoader(Loader):
     def load(self, data: Buffer) -> Union[bytes, str]:
         if self._encoding:
             if isinstance(data, memoryview):
-                return bytes(data).decode(self._encoding)
-            else:
-                return data.decode(self._encoding)
+                data = bytes(data)
+            return data.decode(self._encoding)
         else:
             # return bytes for SQL_ASCII db
+            if not isinstance(data, bytes):
+                data = bytes(data)
             return data
 
 
@@ -100,10 +101,10 @@ class BytesDumper(Dumper):
         super().__init__(cls, context)
         self._esc = Escaping(self.connection.pgconn if self.connection else None)
 
-    def dump(self, obj: bytes) -> Buffer:
+    def dump(self, obj: Buffer) -> Buffer:
         return self._esc.escape_bytea(obj)
 
-    def quote(self, obj: bytes) -> bytes:
+    def quote(self, obj: Buffer) -> bytes:
         escaped = self.dump(obj)
 
         # We cannot use the base quoting because escape_bytea already returns
@@ -148,14 +149,14 @@ class ByteaLoader(Loader):
             self.__class__._escaping = Escaping()
 
     def load(self, data: Buffer) -> bytes:
-        return bytes(self._escaping.unescape_bytea(data))
+        return self._escaping.unescape_bytea(data)
 
 
 class ByteaBinaryLoader(Loader):
 
     format = Format.BINARY
 
-    def load(self, data: Buffer) -> bytes:
+    def load(self, data: Buffer) -> Buffer:
         return data
 
 
index a215eee581cdd9fc8364d76ef42bba7b6d74f6d5..0da2cab503195b20443bf2538bd504d4114018c0 100644 (file)
@@ -12,7 +12,6 @@ from typing import Any, Iterable, List, Optional, Sequence, Tuple
 from psycopg import pq
 from psycopg import abc
 from psycopg.rows import Row, RowMaker
-from psycopg.abc import PipelineCommand
 from psycopg.adapt import AdaptersMap, PyFormat
 from psycopg.pq.abc import PGconn, PGresult
 from psycopg.connection import BaseConnection
@@ -44,11 +43,13 @@ class Transformer(abc.AdaptContext):
     def dump_sequence(
         self, params: Sequence[Any], formats: Sequence[PyFormat]
     ) -> Sequence[Optional[abc.Buffer]]: ...
-    def as_literal(self, obj: Any) -> abc.Buffer: ...
+    def as_literal(self, obj: Any) -> bytes: ...
     def get_dumper(self, obj: Any, format: PyFormat) -> abc.Dumper: ...
     def load_rows(self, row0: int, row1: int, make_row: RowMaker[Row]) -> List[Row]: ...
     def load_row(self, row: int, make_row: RowMaker[Row]) -> Optional[Row]: ...
-    def load_sequence(self, record: Sequence[Optional[bytes]]) -> Tuple[Any, ...]: ...
+    def load_sequence(
+        self, record: Sequence[Optional[abc.Buffer]]
+    ) -> Tuple[Any, ...]: ...
     def get_loader(self, oid: int, format: pq.Format) -> abc.Loader: ...
 
 # Generators
@@ -58,7 +59,7 @@ def send(pgconn: PGconn) -> abc.PQGen[None]: ...
 def fetch_many(pgconn: PGconn) -> abc.PQGen[List[PGresult]]: ...
 def fetch(pgconn: PGconn) -> abc.PQGen[Optional[PGresult]]: ...
 def pipeline_communicate(
-    pgconn: PGconn, commands: Deque[PipelineCommand]
+    pgconn: PGconn, commands: Deque[abc.PipelineCommand]
 ) -> abc.PQGen[List[List[PGresult]]]: ...
 
 # Copy support
@@ -68,7 +69,7 @@ def format_row_text(
 def format_row_binary(
     row: Sequence[Any], tx: abc.Transformer, out: Optional[bytearray] = None
 ) -> bytearray: ...
-def parse_row_text(data: bytes, tx: abc.Transformer) -> Tuple[Any, ...]: ...
-def parse_row_binary(data: bytes, tx: abc.Transformer) -> Tuple[Any, ...]: ...
+def parse_row_text(data: abc.Buffer, tx: abc.Transformer) -> Tuple[Any, ...]: ...
+def parse_row_binary(data: abc.Buffer, tx: abc.Transformer) -> Tuple[Any, ...]: ...
 
 # vim: set syntax=python:
index cfa01ca8a5db5b0cb82eb67028864403e7258732..48f338dd830969e0d45c22f4fc90cc05c6056585 100644 (file)
@@ -541,7 +541,7 @@ cdef class Transformer:
                 make_row, <PyObject *>record, NULL)
         return record
 
-    cpdef object load_sequence(self, record: Sequence[Optional[bytes]]):
+    cpdef object load_sequence(self, record: Sequence[Optional[Buffer]]):
         cdef Py_ssize_t nfields = len(record)
         out = PyTuple_New(nfields)
         cdef PyObject *loader  # borrowed RowLoader