From: Daniele Varrazzo Date: Mon, 26 Jul 2021 00:53:18 +0000 (+0200) Subject: Fix quote of Binary wrapper X-Git-Tag: 3.0.dev2~31 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=ca4653bf00e92b2c218e7095a3bae3cfb776be18;p=thirdparty%2Fpsycopg.git Fix quote of Binary wrapper --- diff --git a/psycopg/psycopg/dbapi20.py b/psycopg/psycopg/dbapi20.py index 78f8f66fc..a0a87590e 100644 --- a/psycopg/psycopg/dbapi20.py +++ b/psycopg/psycopg/dbapi20.py @@ -7,12 +7,11 @@ Compatibility objects with DBAPI 2.0 import time import datetime as dt from math import floor -from typing import Any, Optional, Sequence +from typing import Any, Sequence from . import postgres -from .pq import Format, Escaping -from .abc import AdaptContext -from .adapt import Dumper +from .abc import Buffer +from .types.string import BytesDumper, BytesBinaryDumper class DBAPITypeObject: @@ -58,32 +57,14 @@ class Binary: return f"{self.__class__.__name__}({sobj})" -class BinaryBinaryDumper(Dumper): +class BinaryBinaryDumper(BytesBinaryDumper): + def dump(self, obj: Binary) -> Buffer: # type: ignore + return super().dump(obj.obj) - format = Format.BINARY - _oid = postgres.types["bytea"].oid - def dump(self, obj: Binary) -> bytes: - wrapped = obj.obj - if isinstance(wrapped, bytes): - return wrapped - else: - return bytes(wrapped) - - -class BinaryTextDumper(BinaryBinaryDumper): - - format = Format.TEXT - - def __init__(self, cls: type, context: Optional[AdaptContext] = None): - super().__init__(cls, context) - self._esc = Escaping( - self.connection.pgconn if self.connection else None - ) - - def dump(self, obj: Binary) -> bytes: - data = super().dump(obj) - return self._esc.escape_bytea(data) +class BinaryTextDumper(BytesDumper): + def dump(self, obj: Binary) -> Buffer: # type: ignore + return super().dump(obj.obj) def Date(year: int, month: int, day: int) -> dt.date: diff --git a/psycopg/psycopg/types/string.py b/psycopg/psycopg/types/string.py index 3833f17e2..47eb4af07 100644 --- a/psycopg/psycopg/types/string.py +++ b/psycopg/psycopg/types/string.py @@ -129,10 +129,7 @@ class BytesBinaryDumper(Dumper): format = Format.BINARY _oid = postgres.types["bytea"].oid - def dump( - self, obj: Union[bytes, bytearray, memoryview] - ) -> Union[bytes, bytearray, memoryview]: - # TODO: mypy doesn't complain, but this function has the wrong signature + def dump(self, obj: Buffer) -> Buffer: return obj diff --git a/tests/types/test_string.py b/tests/types/test_string.py index 4c8c14449..37fc064c7 100644 --- a/tests/types/test_string.py +++ b/tests/types/test_string.py @@ -224,7 +224,8 @@ def test_dump_1byte(conn, fmt_in, pytype): @pytest.mark.parametrize("scs", ["on", "off"]) -def test_quote_1byte(conn, scs): +@pytest.mark.parametrize("pytype", [bytes, bytearray, memoryview, Binary]) +def test_quote_1byte(conn, scs, pytype): messages = [] conn.add_notice_handler(lambda msg: messages.append(msg.message_primary)) conn.execute(f"set standard_conforming_strings to {scs}") @@ -233,7 +234,8 @@ def test_quote_1byte(conn, scs): cur = conn.cursor() query = sql.SQL("select {ch} = set_byte('x', 0, %s)") for i in range(0, 256): - cur.execute(query.format(ch=sql.Literal(bytes([i]))), (i,)) + obj = pytype(bytes([i])) + cur.execute(query.format(ch=sql.Literal(obj)), (i,)) assert cur.fetchone()[0] is True, i # No "nonstandard use of \\ in a string literal" warning