from typing import Optional, Sequence, Tuple, Type, TypeVar, Union, TYPE_CHECKING
from . import pq
+from . import adapt
from . import errors as e
from .abc import Buffer, ConnectionType, PQGen, Transformer
-from .adapt import PyFormat
from ._compat import create_task
from ._cmodule import _psycopg
from ._encodings import pgconn_encoding
from .generators import copy_from, copy_to, copy_end
if TYPE_CHECKING:
- from .pq.abc import PGresult
from .cursor import BaseCursor, Cursor
from .cursor_async import AsyncCursor
from .connection import Connection # noqa: F401
from .connection_async import AsyncConnection # noqa: F401
-PY_TEXT = PyFormat.TEXT
-PY_BINARY = PyFormat.BINARY
+PY_TEXT = adapt.PyFormat.TEXT
+PY_BINARY = adapt.PyFormat.BINARY
TEXT = pq.Format.TEXT
BINARY = pq.Format.BINARY
COPY_IN = pq.ExecStatus.COPY_IN
+COPY_OUT = pq.ExecStatus.COPY_OUT
ACTIVE = pq.TransactionStatus.ACTIVE
self.connection = cursor.connection
self._pgconn = self.connection.pgconn
- tx = cursor._tx
- assert tx.pgresult, "The Transformer doesn't have a PGresult set"
- self._pgresult: "PGresult" = tx.pgresult
-
- if self._pgresult.binary_tuples == TEXT:
- self.formatter = TextFormatter(tx, encoding=pgconn_encoding(self._pgconn))
+ result = cursor.pgresult
+ if result:
+ self._direction = result.status
+ if self._direction != COPY_IN and self._direction != COPY_OUT:
+ raise e.ProgrammingError(
+ "the cursor should have performed a COPY operation;"
+ f" its status is {pq.ExecStatus(self._direction).name} instead"
+ )
else:
+ self._direction = COPY_IN
+
+ tx: Transformer = getattr(cursor, "_tx", None) or adapt.Transformer(cursor)
+ if result and result.binary_tuples:
self.formatter = BinaryFormatter(tx)
+ else:
+ self.formatter = TextFormatter(tx, encoding=pgconn_encoding(self._pgconn))
self._finished = False
registry = self.cursor.adapters.types
oids = [t if isinstance(t, int) else registry.get_oid(t) for t in types]
- if self._pgresult.status == COPY_IN:
+ if self._direction == COPY_IN:
self.formatter.transformer.set_dumper_types(oids, self.formatter.format)
else:
self.formatter.transformer.set_loader_types(oids, self.formatter.format)
return row
- def _end_copy_in_gen(self, exc: Optional[BaseException]) -> PQGen[None]:
- bmsg: Optional[bytes]
- if exc:
- msg = f"error from Python: {type(exc).__qualname__} - {exc}"
- bmsg = msg.encode(pgconn_encoding(self._pgconn), "replace")
- else:
- bmsg = None
-
- res = yield from copy_end(self._pgconn, bmsg)
-
- nrows = res.command_tuples
- self.cursor._rowcount = nrows if nrows is not None else -1
- self._finished = True
-
def _end_copy_out_gen(self, exc: Optional[BaseException]) -> PQGen[None]:
if not exc:
return
def __init__(self, cursor: "Cursor[Any]", *, writer: Optional["Writer"] = None):
super().__init__(cursor)
if not writer:
- writer = QueueWriter(cursor.connection)
+ writer = QueueWriter(cursor)
self.writer = writer
self._write = writer.write
by exit. It is available if, despite what is documented, you end up
using the `Copy` object outside a block.
"""
- if self._pgresult.status == COPY_IN:
+ if self._direction == COPY_IN:
data = self.formatter.end()
if data:
self._write(data)
- self.writer.finish()
- self.connection.wait(self._end_copy_in_gen(exc))
+ self.writer.finish(exc)
+ self._finished = True
else:
self.connection.wait(self._end_copy_out_gen(exc))
"""
...
- def finish(self) -> None:
+ def finish(self, exc: Optional[BaseException] = None) -> None:
"""
Called when write operations are finished.
"""
class ConnectionWriter(Writer):
- def __init__(self, connection: "Connection[Any]"):
- self.connection = connection
+ def __init__(self, cursor: "Cursor[Any]"):
+ self.cursor = cursor
+ self.connection = cursor.connection
self._pgconn = self.connection.pgconn
def write(self, data: Buffer) -> None:
copy_to(self._pgconn, data[i : i + MAX_BUFFER_SIZE])
)
+    def finish(self, exc: Optional[BaseException] = None) -> None:
+        """
+        Run `copy_end()` on the connection and store the resulting row count.
+
+        :param exc: if set, a message built from it is passed to `copy_end()`;
+            otherwise `None` is passed.
+        """
+        error: Optional[bytes] = None
+        if exc:
+            text = f"error from Python: {type(exc).__qualname__} - {exc}"
+            # "replace" keeps the encoding from failing on characters the
+            # connection encoding cannot represent.
+            error = text.encode(pgconn_encoding(self._pgconn), "replace")
+
+        result = self.connection.wait(copy_end(self._pgconn, error))
+        count = result.command_tuples
+        # -1 is stored when the result carries no row count.
+        self.cursor._rowcount = -1 if count is None else count
+
class QueueWriter(ConnectionWriter):
"""
on the connection.
"""
- def __init__(self, connection: "Connection[Any]"):
- super().__init__(connection)
+ def __init__(self, cursor: "Cursor[Any]"):
+ super().__init__(cursor)
self._queue: queue.Queue[bytes] = queue.Queue(maxsize=QUEUE_SIZE)
self._worker: Optional[threading.Thread] = None
for i in range(0, len(data), MAX_BUFFER_SIZE):
self._queue.put(data[i : i + MAX_BUFFER_SIZE])
- def finish(self) -> None:
+ def finish(self, exc: Optional[BaseException] = None) -> None:
self._queue.put(b"")
if self._worker:
if self._worker_error:
raise self._worker_error
+ super().finish(exc)
+
class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
"""Manage an asynchronous :sql:`COPY` operation."""
super().__init__(cursor)
if not writer:
- writer = AsyncQueueWriter(cursor.connection)
+ writer = AsyncQueueWriter(cursor)
self.writer = writer
self._write = writer.write
await self._write(data)
async def finish(self, exc: Optional[BaseException]) -> None:
- if self._pgresult.status == COPY_IN:
+ if self._direction == COPY_IN:
data = self.formatter.end()
if data:
await self._write(data)
- await self.writer.finish()
- await self.connection.wait(self._end_copy_in_gen(exc))
+ await self.writer.finish(exc)
+ self._finished = True
else:
await self.connection.wait(self._end_copy_out_gen(exc))
"""
...
- async def finish(self) -> None:
+ async def finish(self, exc: Optional[BaseException] = None) -> None:
"""
Called when write operations are finished.
"""
class AsyncConnectionWriter(AsyncWriter):
- def __init__(self, connection: "AsyncConnection[Any]"):
- self.connection = connection
+ def __init__(self, cursor: "AsyncCursor[Any]"):
+ self.cursor = cursor
+ self.connection = cursor.connection
self._pgconn = self.connection.pgconn
async def write(self, data: Buffer) -> None:
copy_to(self._pgconn, data[i : i + MAX_BUFFER_SIZE])
)
+    async def finish(self, exc: Optional[BaseException] = None) -> None:
+        """
+        Run `copy_end()` on the connection and store the resulting row count.
+
+        :param exc: if set, a message built from it is passed to `copy_end()`;
+            otherwise `None` is passed.
+        """
+        error: Optional[bytes] = None
+        if exc:
+            text = f"error from Python: {type(exc).__qualname__} - {exc}"
+            # "replace" keeps the encoding from failing on characters the
+            # connection encoding cannot represent.
+            error = text.encode(pgconn_encoding(self._pgconn), "replace")
+
+        result = await self.connection.wait(copy_end(self._pgconn, error))
+        count = result.command_tuples
+        # -1 is stored when the result carries no row count.
+        self.cursor._rowcount = -1 if count is None else count
+
class AsyncQueueWriter(AsyncConnectionWriter):
"""
on the connection.
"""
- def __init__(self, connection: "AsyncConnection[Any]"):
- super().__init__(connection)
+ def __init__(self, cursor: "AsyncCursor[Any]"):
+ super().__init__(cursor)
self._queue: asyncio.Queue[bytes] = asyncio.Queue(maxsize=QUEUE_SIZE)
self._worker: Optional[asyncio.Future[None]] = None
for i in range(0, len(data), MAX_BUFFER_SIZE):
await self._queue.put(data[i : i + MAX_BUFFER_SIZE])
- async def finish(self) -> None:
+ async def finish(self, exc: Optional[BaseException] = None) -> None:
await self._queue.put(b"")
if self._worker:
await asyncio.gather(self._worker)
self._worker = None # break reference loops if any
+ await super().finish(exc)
+
class Formatter(ABC):
"""
sample_binary = b"".join(sample_binary_rows)
+# Code point -> letter that test_copy_in_format expects to find after a
+# backslash in the text-format copy data written for that character.
+special_chars = {8: "b", 9: "t", 10: "n", 11: "v", 12: "f", 13: "r", ord("\\"): "\\"}
+
@pytest.mark.parametrize("format", Format)
def test_copy_out_read(conn, format):
assert data == [(True, True, 1, 256)]
+def test_copy_in_format(conn):
+    """Check the text-format bytes that `Copy` passes to a custom writer."""
+    sink = BytesWriter()
+    conn.execute("set client_encoding to utf8")
+    cur = conn.cursor()
+    with psycopg.copy.Copy(cur, writer=sink) as copy:
+        for code in range(1, 256):
+            copy.write_row((code, chr(code)))
+
+    sink.file.seek(0)
+    payload = sink.file.read()
+    # The data is expected to end with a newline, leaving an empty last chunk.
+    *lines, tail = payload.split(b"\n")
+    assert not tail
+
+    for code, line in enumerate(lines, start=1):
+        fields = line.split(b"\t")
+        assert len(fields) == 2
+        assert int(fields[0].decode()) == code
+        expected = (
+            f"\\{special_chars[code]}" if code in special_chars else chr(code)
+        )
+        assert fields[1].decode() == expected
+
+
@pytest.mark.slow
def test_copy_from_to(conn):
# Roundtrip from file to database to file blockwise
)
def test_connection_writer(conn, format, buffer):
cur = conn.cursor()
- writer = psycopg.copy.ConnectionWriter(conn)
+ writer = psycopg.copy.ConnectionWriter(cur)
ensure_table(cur, sample_tabledef)
with cur.copy(
block = block.encode()
m.update(block)
return m.hexdigest()
+
+
+class BytesWriter(psycopg.copy.Writer):
+ # Writer subclass that appends every chunk passed to write() to an
+ # in-memory BytesIO, exposed as `file` so the bytes can be read back.
+ def __init__(self):
+ self.file = BytesIO()
+
+ def write(self, data):
+ self.file.write(data)
from .utils import alist, eur, gc_collect
from .test_copy import sample_text, sample_binary, sample_binary_rows # noqa
from .test_copy import sample_values, sample_records, sample_tabledef
-from .test_copy import py_to_raw
+from .test_copy import py_to_raw, special_chars
pytestmark = [
pytest.mark.asyncio,
assert data == [(True, True, 1, 256)]
+async def test_copy_in_format(aconn):
+    """Check the text-format bytes that `AsyncCopy` passes to a custom writer."""
+    sink = AsyncBytesWriter()
+    await aconn.execute("set client_encoding to utf8")
+    cur = aconn.cursor()
+    async with psycopg.copy.AsyncCopy(cur, writer=sink) as copy:
+        for code in range(1, 256):
+            await copy.write_row((code, chr(code)))
+
+    sink.file.seek(0)
+    payload = sink.file.read()
+    # The data is expected to end with a newline, leaving an empty last chunk.
+    *lines, tail = payload.split(b"\n")
+    assert not tail
+
+    for code, line in enumerate(lines, start=1):
+        fields = line.split(b"\t")
+        assert len(fields) == 2
+        assert int(fields[0].decode()) == code
+        expected = (
+            f"\\{special_chars[code]}" if code in special_chars else chr(code)
+        )
+        assert fields[1].decode() == expected
+
+
@pytest.mark.slow
async def test_copy_from_to(aconn):
# Roundtrip from file to database to file blockwise
)
async def test_connection_writer(aconn, format, buffer):
cur = aconn.cursor()
- writer = psycopg.copy.AsyncConnectionWriter(aconn)
+ writer = psycopg.copy.AsyncConnectionWriter(cur)
await ensure_table(cur, sample_tabledef)
async with cur.copy(
block = block.encode()
m.update(block)
return m.hexdigest()
+
+
+class AsyncBytesWriter(psycopg.copy.AsyncWriter):
+ # AsyncWriter subclass whose async write() appends every chunk to an
+ # in-memory BytesIO, exposed as `file` so the bytes can be read back.
+ def __init__(self):
+ self.file = BytesIO()
+
+ async def write(self, data):
+ self.file.write(data)