from . import pq
from . import errors as e
from .pq import ExecStatus
-from .abc import ConnectionType, PQGen, Transformer
+from .abc import Buffer, ConnectionType, PQGen, Transformer
from .adapt import PyFormat
from ._compat import create_task
from ._cmodule import _psycopg
"""
return self.connection.wait(self._read_row_gen())
- def write(self, buffer: Union[str, bytes]) -> None:
+ def write(self, buffer: Union[Buffer, str]) -> None:
"""
Write a block of data to a table after a :sql:`COPY FROM` operation.
# Propagate the error to the main thread.
self._worker_error = ex
- def _write(self, data: bytes) -> None:
+ def _write(self, data: Buffer) -> None:
if not data:
return
async def read_row(self) -> Optional[Tuple[Any, ...]]:
return await self.connection.wait(self._read_row_gen())
- async def write(self, buffer: Union[str, bytes]) -> None:
+ async def write(self, buffer: Union[Buffer, str]) -> None:
data = self.formatter.write(buffer)
await self._write(data)
break
await self.connection.wait(copy_to(self._pgconn, data))
- async def _write(self, data: bytes) -> None:
+ async def _write(self, data: Buffer) -> None:
if not data:
return
...
@abstractmethod
- def write(self, buffer: Union[str, bytes]) -> bytes:
+ def write(self, buffer: Union[Buffer, str]) -> bytes:
...
@abstractmethod
else:
return None
- def write(self, buffer: Union[str, bytes]) -> bytes:
+ def write(self, buffer: Union[Buffer, str]) -> Buffer:
data = self._ensure_bytes(buffer)
self._signature_sent = True
return data
buffer, self._write_buffer = self._write_buffer, bytearray()
return buffer
- def _ensure_bytes(self, data: Union[bytes, str]) -> bytes:
- if isinstance(data, bytes):
- return data
-
- elif isinstance(data, str):
+ def _ensure_bytes(self, data: Union[Buffer, str]) -> Buffer:
+ if isinstance(data, str):
return data.encode(self._encoding)
-
else:
- raise TypeError(f"can't write {type(data).__name__}")
+ # Assume, for simplicity, that the user is not passing stupid
+ # things to the write function. If that's the case, things
+ # will fail downstream.
+ return data
class BinaryFormatter(Formatter):
return parse_row_binary(data, self.transformer)
- def write(self, buffer: Union[str, bytes]) -> bytes:
+ def write(self, buffer: Union[Buffer, str]) -> Buffer:
data = self._ensure_bytes(buffer)
self._signature_sent = True
return data
buffer, self._write_buffer = self._write_buffer, bytearray()
return buffer
- def _ensure_bytes(self, data: Union[bytes, str]) -> bytes:
- if isinstance(data, bytes):
- return data
-
- elif isinstance(data, str):
+ def _ensure_bytes(self, data: Union[Buffer, str]) -> Buffer:
+ if isinstance(data, str):
raise TypeError("cannot copy str data in binary mode: use bytes instead")
-
else:
- raise TypeError(f"can't write {type(data).__name__}")
+ # Assume, for simplicity, that the user is not passing stupid
+ # things to the write function. If that's the case, things
+ # will fail downstream.
+ return data
def _format_row_text(
@pytest.mark.slow
-def test_copy_big_size_block(conn):
+@pytest.mark.parametrize("pytype", [str, bytes, bytearray, memoryview])
+def test_copy_big_size_block(conn, pytype):
cur = conn.cursor()
ensure_table(cur, sample_tabledef)
data = "".join(choice(string.ascii_letters) for i in range(10 * 1024 * 1024))
+ copy_data = data + "\n" if pytype is str else pytype(data.encode() + b"\n")
with cur.copy("copy copy_in (data) from stdin") as copy:
- copy.write(data + "\n")
+ copy.write(copy_data)
cur.execute("select data from copy_in limit 1")
assert cur.fetchone()[0] == data
@pytest.mark.slow
-def test_copy_from_to_bytes(conn):
+@pytest.mark.parametrize("pytype", [bytes, bytearray, memoryview])
+def test_copy_from_to_bytes(conn, pytype):
# Roundtrip from file to database to file blockwise
gen = DataGenerator(conn, nrecs=1024, srec=10 * 1024)
gen.ensure_table()
cur = conn.cursor()
with cur.copy("copy copy_in from stdin") as copy:
for block in gen.blocks():
- copy.write(block.encode())
+ copy.write(pytype(block.encode()))
gen.assert_data()
@pytest.mark.slow
-async def test_copy_big_size_block(aconn):
+@pytest.mark.parametrize("pytype", [str, bytes, bytearray, memoryview])
+async def test_copy_big_size_block(aconn, pytype):
cur = aconn.cursor()
await ensure_table(cur, sample_tabledef)
data = "".join(choice(string.ascii_letters) for i in range(10 * 1024 * 1024))
+ copy_data = data + "\n" if pytype is str else pytype(data.encode() + b"\n")
async with cur.copy("copy copy_in (data) from stdin") as copy:
- await copy.write(data + "\n")
+ await copy.write(copy_data)
await cur.execute("select data from copy_in limit 1")
assert await cur.fetchone() == (data,)
@pytest.mark.slow
-async def test_copy_from_to_bytes(aconn):
+@pytest.mark.parametrize("pytype", [bytes, bytearray, memoryview])
+async def test_copy_from_to_bytes(aconn, pytype):
# Roundtrip from file to database to file blockwise
gen = DataGenerator(aconn, nrecs=1024, srec=10 * 1024)
await gen.ensure_table()
cur = aconn.cursor()
async with cur.copy("copy copy_in from stdin") as copy:
for block in gen.blocks():
- await copy.write(block.encode())
+ await copy.write(pytype(block.encode()))
await gen.assert_data()