formatter: "Formatter"
- def __init__(self, cursor: "BaseCursor[ConnectionType, Any]"):
+ def __init__(
+ self,
+ cursor: "BaseCursor[ConnectionType, Any]",
+ *,
+ binary: Optional[bool] = None,
+ ):
self.cursor = cursor
self.connection = cursor.connection
self._pgconn = self.connection.pgconn
else:
self._direction = COPY_IN
+ if binary is None:
+ binary = bool(result and result.binary_tuples)
+
tx: Transformer = getattr(cursor, "_tx", None) or adapt.Transformer(cursor)
- if result and result.binary_tuples:
+ if binary:
self.formatter = BinaryFormatter(tx)
else:
self.formatter = TextFormatter(tx, encoding=pgconn_encoding(self._pgconn))
writer: "Writer"
- def __init__(self, cursor: "Cursor[Any]", *, writer: Optional["Writer"] = None):
- super().__init__(cursor)
+ def __init__(
+ self,
+ cursor: "Cursor[Any]",
+ *,
+ binary: Optional[bool] = None,
+ writer: Optional["Writer"] = None,
+ ):
+ super().__init__(cursor, binary=binary)
if not writer:
writer = LibpqWriter(cursor)
writer: "AsyncWriter"
def __init__(
- self, cursor: "AsyncCursor[Any]", *, writer: Optional["AsyncWriter"] = None
+ self,
+ cursor: "AsyncCursor[Any]",
+ *,
+ binary: Optional[bool] = None,
+ writer: Optional["AsyncWriter"] = None,
):
- super().__init__(cursor)
+ super().__init__(cursor, binary=binary)
if not writer:
writer = AsyncLibpqWriter(cursor)
import gc
import string
+import struct
import hashlib
from io import BytesIO, StringIO
from random import choice, randrange
pytestmark = pytest.mark.crdb_skip("copy")
-sample_records = [(10, 20, "hello"), (40, None, "world")]
-sample_values = "values (10::int, 20::int, 'hello'::text), (40, NULL, 'world')"
+sample_records = [(40010, 40020, "hello"), (40040, None, "world")]
+sample_values = "values (40010::int, 40020::int, 'hello'::text), (40040, NULL, 'world')"
sample_tabledef = "col1 serial primary key, col2 int, data text"
sample_text = b"""\
-10\t20\thello
-40\t\\N\tworld
+40010\t40020\thello
+40040\t\\N\tworld
"""
sample_binary_str = """
5047 434f 5059 0aff 0d0a 00
00 0000 0000 0000 00
-00 0300 0000 0400 0000 0a00 0000 0400 0000 1400 0000 0568 656c 6c6f
+00 0300 0000 0400 009c 4a00 0000 0400 009c 5400 0000 0568 656c 6c6f
-0003 0000 0004 0000 0028 ffff ffff 0000 0005 776f 726c 64
+0003 0000 0004 0000 9c68 ffff ffff 0000 0005 776f 726c 64
ff ff
"""
sample_binary_rows = [
bytes.fromhex("".join(row.split())) for row in sample_binary_str.split("\n\n")
]
-
sample_binary = b"".join(sample_binary_rows)
special_chars = {8: "b", 9: "t", 10: "n", 11: "v", 12: "f", 13: "r", ord("\\"): "\\"}
assert fields[1].decode() == chr(i)
+@pytest.mark.parametrize(
+ "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")]
+)
+def test_file_writer(conn, format, buffer):
+ file = BytesIO()
+ conn.execute("set client_encoding to utf8")
+ cur = conn.cursor()
+ with Copy(cur, binary=format, writer=FileWriter(file)) as copy:
+ for record in sample_records:
+ copy.write_row(record)
+
+ file.seek(0)
+ want = globals()[buffer]
+ got = file.read()
+ assert got == want
+
+
@pytest.mark.slow
def test_copy_from_to(conn):
# Roundtrip from file to database to file blockwise
@pytest.mark.parametrize(
- "format, buffer",
- [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
+ "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")]
)
def test_worker_life(conn, format, buffer):
cur = conn.cursor()
@pytest.mark.parametrize(
- "format, buffer",
- [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
+ "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")]
)
def test_connection_writer(conn, format, buffer):
cur = conn.cursor()
return str(item)
else:
if isinstance(item, int):
- return bytes([0, 0, 0, item])
+ # Assume int4
+ return struct.pack("!i", item)
elif isinstance(item, str):
return item.encode()
return item
@pytest.mark.parametrize(
- "format, buffer",
- [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
+ "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")]
)
async def test_copy_in_buffers(aconn, format, buffer):
cur = aconn.cursor()
async def test_copy_in_format(aconn):
- writer = AsyncBytesWriter()
+ file = BytesIO()
await aconn.execute("set client_encoding to utf8")
cur = aconn.cursor()
- async with AsyncCopy(cur, writer=writer) as copy:
+ async with AsyncCopy(cur, writer=AsyncFileWriter(file)) as copy:
for i in range(1, 256):
await copy.write_row((i, chr(i)))
- writer.file.seek(0)
- rows = writer.file.read().split(b"\n")
+ file.seek(0)
+ rows = file.read().split(b"\n")
assert not rows[-1]
del rows[-1]
assert fields[1].decode() == chr(i)
+@pytest.mark.parametrize(
+ "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")]
+)
+async def test_file_writer(aconn, format, buffer):
+ file = BytesIO()
+ await aconn.execute("set client_encoding to utf8")
+ cur = aconn.cursor()
+ async with AsyncCopy(cur, binary=format, writer=AsyncFileWriter(file)) as copy:
+ for record in sample_records:
+ await copy.write_row(record)
+
+ file.seek(0)
+ want = globals()[buffer]
+ got = file.read()
+ assert got == want
+
+
@pytest.mark.slow
async def test_copy_from_to(aconn):
# Roundtrip from file to database to file blockwise
@pytest.mark.parametrize(
- "format, buffer",
- [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
+ "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")]
)
async def test_worker_life(aconn, format, buffer):
cur = aconn.cursor()
@pytest.mark.parametrize(
- "format, buffer",
- [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
+ "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")]
)
async def test_connection_writer(aconn, format, buffer):
cur = aconn.cursor()
return m.hexdigest()
-class AsyncBytesWriter(AsyncWriter):
- def __init__(self):
- self.file = BytesIO()
+class AsyncFileWriter(AsyncWriter):
+ def __init__(self, file):
+ self.file = file
async def write(self, data):
self.file.write(data)