import time
import datetime as dt
from math import floor
-from typing import Any, Optional, Sequence
+from typing import Any, Sequence
from . import postgres
-from .pq import Format, Escaping
-from .abc import AdaptContext
-from .adapt import Dumper
+from .abc import Buffer
+from .types.string import BytesDumper, BytesBinaryDumper
class DBAPITypeObject:
return f"{self.__class__.__name__}({sobj})"
-class BinaryBinaryDumper(Dumper):
+class BinaryBinaryDumper(BytesBinaryDumper):
+ def dump(self, obj: Binary) -> Buffer: # type: ignore
+ return super().dump(obj.obj)
- format = Format.BINARY
- _oid = postgres.types["bytea"].oid
- def dump(self, obj: Binary) -> bytes:
- wrapped = obj.obj
- if isinstance(wrapped, bytes):
- return wrapped
- else:
- return bytes(wrapped)
-
-
-class BinaryTextDumper(BinaryBinaryDumper):
-
- format = Format.TEXT
-
- def __init__(self, cls: type, context: Optional[AdaptContext] = None):
- super().__init__(cls, context)
- self._esc = Escaping(
- self.connection.pgconn if self.connection else None
- )
-
- def dump(self, obj: Binary) -> bytes:
- data = super().dump(obj)
- return self._esc.escape_bytea(data)
+class BinaryTextDumper(BytesDumper):
+ def dump(self, obj: Binary) -> Buffer: # type: ignore
+ return super().dump(obj.obj)
def Date(year: int, month: int, day: int) -> dt.date:
format = Format.BINARY
_oid = postgres.types["bytea"].oid
- def dump(
- self, obj: Union[bytes, bytearray, memoryview]
- ) -> Union[bytes, bytearray, memoryview]:
- # TODO: mypy doesn't complain, but this function has the wrong signature
+ def dump(self, obj: Buffer) -> Buffer:
return obj
@pytest.mark.parametrize("scs", ["on", "off"])
-def test_quote_1byte(conn, scs):
+@pytest.mark.parametrize("pytype", [bytes, bytearray, memoryview, Binary])
+def test_quote_1byte(conn, scs, pytype):
messages = []
conn.add_notice_handler(lambda msg: messages.append(msg.message_primary))
conn.execute(f"set standard_conforming_strings to {scs}")
cur = conn.cursor()
query = sql.SQL("select {ch} = set_byte('x', 0, %s)")
for i in range(0, 256):
- cur.execute(query.format(ch=sql.Literal(bytes([i]))), (i,))
+ obj = pytype(bytes([i]))
+ cur.execute(query.format(ch=sql.Literal(obj)), (i,))
assert cur.fetchone()[0] is True, i
# No "nonstandard use of \\ in a string literal" warning