]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Fix quote of Binary wrapper
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 26 Jul 2021 00:53:18 +0000 (02:53 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 26 Jul 2021 00:53:18 +0000 (02:53 +0200)
psycopg/psycopg/dbapi20.py
psycopg/psycopg/types/string.py
tests/types/test_string.py

index 78f8f66fc923bf6ba22a3b7f5c80d224cf307bbf..a0a87590e8370b050624a2b3782b2748095352af 100644 (file)
@@ -7,12 +7,11 @@ Compatibility objects with DBAPI 2.0
 import time
 import datetime as dt
 from math import floor
-from typing import Any, Optional, Sequence
+from typing import Any, Sequence
 
 from . import postgres
-from .pq import Format, Escaping
-from .abc import AdaptContext
-from .adapt import Dumper
+from .abc import Buffer
+from .types.string import BytesDumper, BytesBinaryDumper
 
 
 class DBAPITypeObject:
@@ -58,32 +57,14 @@ class Binary:
         return f"{self.__class__.__name__}({sobj})"
 
 
-class BinaryBinaryDumper(Dumper):
+class BinaryBinaryDumper(BytesBinaryDumper):
+    def dump(self, obj: Binary) -> Buffer:  # type: ignore
+        return super().dump(obj.obj)
 
-    format = Format.BINARY
-    _oid = postgres.types["bytea"].oid
 
-    def dump(self, obj: Binary) -> bytes:
-        wrapped = obj.obj
-        if isinstance(wrapped, bytes):
-            return wrapped
-        else:
-            return bytes(wrapped)
-
-
-class BinaryTextDumper(BinaryBinaryDumper):
-
-    format = Format.TEXT
-
-    def __init__(self, cls: type, context: Optional[AdaptContext] = None):
-        super().__init__(cls, context)
-        self._esc = Escaping(
-            self.connection.pgconn if self.connection else None
-        )
-
-    def dump(self, obj: Binary) -> bytes:
-        data = super().dump(obj)
-        return self._esc.escape_bytea(data)
+class BinaryTextDumper(BytesDumper):
+    def dump(self, obj: Binary) -> Buffer:  # type: ignore
+        return super().dump(obj.obj)
 
 
 def Date(year: int, month: int, day: int) -> dt.date:
index 3833f17e2f62719dde71504e9c0cdeeec1aea6a3..47eb4af073348b796745df2956d3d563a890871f 100644 (file)
@@ -129,10 +129,7 @@ class BytesBinaryDumper(Dumper):
     format = Format.BINARY
     _oid = postgres.types["bytea"].oid
 
-    def dump(
-        self, obj: Union[bytes, bytearray, memoryview]
-    ) -> Union[bytes, bytearray, memoryview]:
-        # TODO: mypy doesn't complain, but this function has the wrong signature
+    def dump(self, obj: Buffer) -> Buffer:
         return obj
 
 
index 4c8c1444949a6af06324fe3463b74cc77a10b733..37fc064c73152c377a4a9282507f6aab9d95fdb7 100644 (file)
@@ -224,7 +224,8 @@ def test_dump_1byte(conn, fmt_in, pytype):
 
 
 @pytest.mark.parametrize("scs", ["on", "off"])
-def test_quote_1byte(conn, scs):
+@pytest.mark.parametrize("pytype", [bytes, bytearray, memoryview, Binary])
+def test_quote_1byte(conn, scs, pytype):
     messages = []
     conn.add_notice_handler(lambda msg: messages.append(msg.message_primary))
     conn.execute(f"set standard_conforming_strings to {scs}")
@@ -233,7 +234,8 @@ def test_quote_1byte(conn, scs):
     cur = conn.cursor()
     query = sql.SQL("select {ch} = set_byte('x', 0, %s)")
     for i in range(0, 256):
-        cur.execute(query.format(ch=sql.Literal(bytes([i]))), (i,))
+        obj = pytype(bytes([i]))
+        cur.execute(query.format(ch=sql.Literal(obj)), (i,))
         assert cur.fetchone()[0] is True, i
 
     # No "nonstandard use of \\ in a string literal" warning