From: Daniele Varrazzo
Date: Fri, 6 Oct 2023 10:58:49 +0000 (+0200)
Subject: refactor(copy): generate sync code from async
X-Git-Tag: pool-3.2.0~12^2~13
X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=5a1db7bb8f582e02388d2d8ca5d0c04ac58f17f5;p=thirdparty%2Fpsycopg.git

refactor(copy): generate sync code from async
---

diff --git a/psycopg/psycopg/_acompat.py b/psycopg/psycopg/_acompat.py
new file mode 100644
index 000000000..4915cab91
--- /dev/null
+++ b/psycopg/psycopg/_acompat.py
@@ -0,0 +1,106 @@
+"""
+Utilities to ease the differences between async and sync code.
+
+These objects offer a similar interface between sync and async versions; the
+script async_to_sync.py will replace the async names with the sync names
+when generating the sync version.
+"""
+
+# Copyright (C) 2023 The Psycopg Team
+
+from __future__ import annotations
+
+import queue
+import asyncio
+import threading
+from typing import Any, Callable, Coroutine, TypeVar, TYPE_CHECKING
+
+from typing_extensions import TypeAlias
+
+Worker: TypeAlias = threading.Thread
+AWorker: TypeAlias = "asyncio.Task[None]"
+T = TypeVar("T")
+
+# Hack required on Python 3.8 because subclassing Queue[T] fails at runtime.
+# https://stackoverflow.com/questions/45414066/mypy-how-to-define-a-generic-subclass
+if TYPE_CHECKING:
+    _GQueue: TypeAlias = queue.Queue
+    _AGQueue: TypeAlias = asyncio.Queue
+
+else:
+
+    class FakeGenericMeta(type):
+        def __getitem__(self, item):
+            return self
+
+    class _GQueue(queue.Queue, metaclass=FakeGenericMeta):
+        pass
+
+    class _AGQueue(asyncio.Queue, metaclass=FakeGenericMeta):
+        pass
+
+
+class Queue(_GQueue[T]):
+    """
+    A Queue subclass with an interruptible get() method.
+    """
+
+    def get(self, block: bool = True, timeout: float | None = None) -> T:
+        # Always specify a timeout to make the wait interruptible.
+        if timeout is None:
+            timeout = 24.0 * 60.0 * 60.0
+        return super().get(block=block, timeout=timeout)
+
+
+class AQueue(_AGQueue[T]):
+    pass
+
+
+def aspawn(
+    f: Callable[..., Coroutine[Any, Any, None]],
+    args: tuple[Any, ...] = (),
+    name: str | None = None,
+) -> asyncio.Task[None]:
+    """
+    Equivalent to asyncio.create_task.
+    """
+    return asyncio.create_task(f(*args), name=name)
+
+
+def spawn(
+    f: Callable[..., Any],
+    args: tuple[Any, ...] = (),
+    name: str | None = None,
+) -> threading.Thread:
+    """
+    Equivalent to creating and running a daemon thread.
+    """
+    t = threading.Thread(target=f, args=args, name=name, daemon=True)
+    t.start()
+    return t
+
+
+async def agather(*tasks: asyncio.Task[Any], timeout: float | None = None) -> None:
+    """
+    Equivalent to asyncio.gather or Thread.join()
+    """
+    wait = asyncio.gather(*tasks)
+    try:
+        if timeout is not None:
+            await asyncio.wait_for(asyncio.shield(wait), timeout=timeout)
+        else:
+            await wait
+    except asyncio.TimeoutError:
+        pass
+    else:
+        return
+
+
+def gather(*tasks: threading.Thread, timeout: float | None = None) -> None:
+    """
+    Equivalent to asyncio.gather or Thread.join()
+    """
+    for t in tasks:
+        if not t.is_alive():
+            continue
+        t.join(timeout)
diff --git a/psycopg/psycopg/_copy.py b/psycopg/psycopg/_copy.py
new file mode 100644
index 000000000..57133891d
--- /dev/null
+++ b/psycopg/psycopg/_copy.py
@@ -0,0 +1,283 @@
+# WARNING: this file is auto-generated by 'async_to_sync.py'
+# from the original file '_copy_async.py'
+# DO NOT CHANGE! Change the original file instead.
+"""
+Psycopg Copy and related objects.
+""" + +# Copyright (C) 2023 The Psycopg Team + +from __future__ import annotations + +from abc import ABC, abstractmethod +from types import TracebackType +from typing import Any, Iterator, Type, Tuple, Sequence, TYPE_CHECKING + +from . import pq +from . import errors as e +from ._copy_base import BaseCopy, MAX_BUFFER_SIZE, QUEUE_SIZE +from .generators import copy_to, copy_end +from ._encodings import pgconn_encoding +from ._acompat import spawn, gather, Queue, Worker + +if TYPE_CHECKING: + from .abc import Buffer + from .cursor import Cursor + from .connection import Connection # noqa: F401 + +COPY_IN = pq.ExecStatus.COPY_IN +COPY_OUT = pq.ExecStatus.COPY_OUT + + +class Copy(BaseCopy["Connection[Any]"]): + """Manage an asynchronous :sql:`COPY` operation. + + :param cursor: the cursor where the operation is performed. + :param binary: if `!True`, write binary format. + :param writer: the object to write to destination. If not specified, write + to the `!cursor` connection. + + Choosing `!binary` is not necessary if the cursor has executed a + :sql:`COPY` operation, because the operation result describes the format + too. The parameter is useful when a `!Copy` object is created manually and + no operation is performed on the cursor, such as when using ``writer=``\\ + `~psycopg.copy.FileWriter`. + """ + + __module__ = "psycopg" + + writer: Writer + + def __init__( + self, + cursor: Cursor[Any], + *, + binary: bool | None = None, + writer: Writer | None = None, + ): + super().__init__(cursor, binary=binary) + if not writer: + writer = LibpqWriter(cursor) + + self.writer = writer + self._write = writer.write + + def __enter__(self: BaseCopy._Self) -> BaseCopy._Self: + self._enter() + return self + + def __exit__( + self, + exc_type: Type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + self.finish(exc_val) + + # End user sync interface + + def __iter__(self) -> Iterator[Buffer]: + """Implement block-by-block iteration on :sql:`COPY TO`.""" + while True: + data = self.read() + if not data: + break + yield data + + def read(self) -> Buffer: + """ + Read an unparsed row after a :sql:`COPY TO` operation. + + Return an empty string when the data is finished. + """ + return self.connection.wait(self._read_gen()) + + def rows(self) -> Iterator[Tuple[Any, ...]]: + """ + Iterate on the result of a :sql:`COPY TO` operation record by record. + + Note that the records returned will be tuples of unparsed strings or + bytes, unless data types are specified using `set_types()`. + """ + while True: + record = self.read_row() + if record is None: + break + yield record + + def read_row(self) -> Tuple[Any, ...] | None: + """ + Read a parsed row of data from a table after a :sql:`COPY TO` operation. + + Return `!None` when the data is finished. + + Note that the records returned will be tuples of unparsed strings or + bytes, unless data types are specified using `set_types()`. + """ + return self.connection.wait(self._read_row_gen()) + + def write(self, buffer: Buffer | str) -> None: + """ + Write a block of data to a table after a :sql:`COPY FROM` operation. + + If the :sql:`COPY` is in binary format `!buffer` must be `!bytes`. In + text mode it can be either `!bytes` or `!str`. 
+ """ + data = self.formatter.write(buffer) + if data: + self._write(data) + + def write_row(self, row: Sequence[Any]) -> None: + """Write a record to a table after a :sql:`COPY FROM` operation.""" + data = self.formatter.write_row(row) + if data: + self._write(data) + + def finish(self, exc: BaseException | None) -> None: + """Terminate the copy operation and free the resources allocated. + + You shouldn't need to call this function yourself: it is usually called + by exit. It is available if, despite what is documented, you end up + using the `Copy` object outside a block. + """ + if self._direction == COPY_IN: + data = self.formatter.end() + if data: + self._write(data) + self.writer.finish(exc) + self._finished = True + else: + self.connection.wait(self._end_copy_out_gen(exc)) + + +class Writer(ABC): + """ + A class to write copy data somewhere (for async connections). + """ + + @abstractmethod + def write(self, data: Buffer) -> None: + """Write some data to destination.""" + ... + + def finish(self, exc: BaseException | None = None) -> None: + """ + Called when write operations are finished. + + If operations finished with an error, it will be passed to ``exc``. + """ + pass + + +class LibpqWriter(Writer): + """ + An `Writer` to write copy data to a Postgres database. + """ + + __module__ = "psycopg.copy" + + def __init__(self, cursor: Cursor[Any]): + self.cursor = cursor + self.connection = cursor.connection + self._pgconn = self.connection.pgconn + + def write(self, data: Buffer) -> None: + if len(data) <= MAX_BUFFER_SIZE: + # Most used path: we don't need to split the buffer in smaller + # bits, so don't make a copy. + self.connection.wait(copy_to(self._pgconn, data)) + else: + # Copy a buffer too large in chunks to avoid causing a memory + # error in the libpq, which may cause an infinite loop (#255). + for i in range(0, len(data), MAX_BUFFER_SIZE): + self.connection.wait( + copy_to(self._pgconn, data[i : i + MAX_BUFFER_SIZE]) + ) + + def finish(self, exc: BaseException | None = None) -> None: + bmsg: bytes | None + if exc: + msg = f"error from Python: {type(exc).__qualname__} - {exc}" + bmsg = msg.encode(pgconn_encoding(self._pgconn), "replace") + else: + bmsg = None + + try: + res = self.connection.wait(copy_end(self._pgconn, bmsg)) + # The QueryCanceled is expected if we sent an exception message to + # pgconn.put_copy_end(). The Python exception that generated that + # cancelling is more important, so don't clobber it. + except e.QueryCanceled: + if not bmsg: + raise + else: + self.cursor._results = [res] + + +class QueuedLibpqWriter(LibpqWriter): + """ + `Writer` using a buffer to queue data to write. + + `write()` returns immediately, so that the main thread can be CPU-bound + formatting messages, while a worker thread can be IO-bound waiting to write + on the connection. + """ + + __module__ = "psycopg.copy" + + def __init__(self, cursor: Cursor[Any]): + super().__init__(cursor) + + self._queue: Queue[Buffer] = Queue(maxsize=QUEUE_SIZE) + self._worker: Worker | None = None + self._worker_error: BaseException | None = None + + def worker(self) -> None: + """Push data to the server when available from the copy queue. + + Terminate reading when the queue receives a false-y value, or in case + of error. + + The function is designed to be run in a separate task. + """ + try: + while True: + data = self._queue.get() + if not data: + break + self.connection.wait(copy_to(self._pgconn, data)) + except BaseException as ex: + # Propagate the error to the main thread. 
+            self._worker_error = ex
+
+    def write(self, data: Buffer) -> None:
+        if not self._worker:
+            # warning: reference loop, broken by _write_end
+            self._worker = spawn(self.worker)
+
+        # If the worker thread raised an exception, re-raise it to the caller.
+        if self._worker_error:
+            raise self._worker_error
+
+        if len(data) <= MAX_BUFFER_SIZE:
+            # Most used path: we don't need to split the buffer in smaller
+            # bits, so don't make a copy.
+            self._queue.put(data)
+        else:
+            # Copy a buffer too large in chunks to avoid causing a memory
+            # error in the libpq, which may cause an infinite loop (#255).
+            for i in range(0, len(data), MAX_BUFFER_SIZE):
+                self._queue.put(data[i : i + MAX_BUFFER_SIZE])
+
+    def finish(self, exc: BaseException | None = None) -> None:
+        self._queue.put(b"")
+
+        if self._worker:
+            gather(self._worker)
+            self._worker = None  # break reference loops if any
+
+        # Check if the worker thread raised any exception before terminating.
+        if self._worker_error:
+            raise self._worker_error
+
+        super().finish(exc)
diff --git a/psycopg/psycopg/_copy_async.py b/psycopg/psycopg/_copy_async.py
new file mode 100644
index 000000000..248335300
--- /dev/null
+++ b/psycopg/psycopg/_copy_async.py
@@ -0,0 +1,280 @@
+"""
+Psycopg AsyncCopy and related objects.
+"""
+
+# Copyright (C) 2023 The Psycopg Team
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from types import TracebackType
+from typing import Any, AsyncIterator, Type, Tuple, Sequence, TYPE_CHECKING
+
+from . import pq
+from . import errors as e
+from ._copy_base import BaseCopy, MAX_BUFFER_SIZE, QUEUE_SIZE
+from .generators import copy_to, copy_end
+from ._encodings import pgconn_encoding
+from ._acompat import aspawn, agather, AQueue, AWorker
+
+if TYPE_CHECKING:
+    from .abc import Buffer
+    from .cursor_async import AsyncCursor
+    from .connection_async import AsyncConnection  # noqa: F401
+
+COPY_IN = pq.ExecStatus.COPY_IN
+COPY_OUT = pq.ExecStatus.COPY_OUT
+
+
+class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
+    """Manage an asynchronous :sql:`COPY` operation.
+
+    :param cursor: the cursor where the operation is performed.
+    :param binary: if `!True`, write binary format.
+    :param writer: the object to write to destination. If not specified, write
+        to the `!cursor` connection.
+
+    Choosing `!binary` is not necessary if the cursor has executed a
+    :sql:`COPY` operation, because the operation result describes the format
+    too. The parameter is useful when a `!Copy` object is created manually and
+    no operation is performed on the cursor, such as when using ``writer=``\\
+    `~psycopg.copy.FileWriter`.
+ """ + + __module__ = "psycopg" + + writer: AsyncWriter + + def __init__( + self, + cursor: AsyncCursor[Any], + *, + binary: bool | None = None, + writer: AsyncWriter | None = None, + ): + super().__init__(cursor, binary=binary) + if not writer: + writer = AsyncLibpqWriter(cursor) + + self.writer = writer + self._write = writer.write + + async def __aenter__(self: BaseCopy._Self) -> BaseCopy._Self: + self._enter() + return self + + async def __aexit__( + self, + exc_type: Type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self.finish(exc_val) + + # End user sync interface + + async def __aiter__(self) -> AsyncIterator[Buffer]: + """Implement block-by-block iteration on :sql:`COPY TO`.""" + while True: + data = await self.read() + if not data: + break + yield data + + async def read(self) -> Buffer: + """ + Read an unparsed row after a :sql:`COPY TO` operation. + + Return an empty string when the data is finished. + """ + return await self.connection.wait(self._read_gen()) + + async def rows(self) -> AsyncIterator[Tuple[Any, ...]]: + """ + Iterate on the result of a :sql:`COPY TO` operation record by record. + + Note that the records returned will be tuples of unparsed strings or + bytes, unless data types are specified using `set_types()`. + """ + while True: + record = await self.read_row() + if record is None: + break + yield record + + async def read_row(self) -> Tuple[Any, ...] | None: + """ + Read a parsed row of data from a table after a :sql:`COPY TO` operation. + + Return `!None` when the data is finished. + + Note that the records returned will be tuples of unparsed strings or + bytes, unless data types are specified using `set_types()`. + """ + return await self.connection.wait(self._read_row_gen()) + + async def write(self, buffer: Buffer | str) -> None: + """ + Write a block of data to a table after a :sql:`COPY FROM` operation. + + If the :sql:`COPY` is in binary format `!buffer` must be `!bytes`. In + text mode it can be either `!bytes` or `!str`. + """ + data = self.formatter.write(buffer) + if data: + await self._write(data) + + async def write_row(self, row: Sequence[Any]) -> None: + """Write a record to a table after a :sql:`COPY FROM` operation.""" + data = self.formatter.write_row(row) + if data: + await self._write(data) + + async def finish(self, exc: BaseException | None) -> None: + """Terminate the copy operation and free the resources allocated. + + You shouldn't need to call this function yourself: it is usually called + by exit. It is available if, despite what is documented, you end up + using the `Copy` object outside a block. + """ + if self._direction == COPY_IN: + data = self.formatter.end() + if data: + await self._write(data) + await self.writer.finish(exc) + self._finished = True + else: + await self.connection.wait(self._end_copy_out_gen(exc)) + + +class AsyncWriter(ABC): + """ + A class to write copy data somewhere (for async connections). + """ + + @abstractmethod + async def write(self, data: Buffer) -> None: + """Write some data to destination.""" + ... + + async def finish(self, exc: BaseException | None = None) -> None: + """ + Called when write operations are finished. + + If operations finished with an error, it will be passed to ``exc``. + """ + pass + + +class AsyncLibpqWriter(AsyncWriter): + """ + An `AsyncWriter` to write copy data to a Postgres database. 
+ """ + + __module__ = "psycopg.copy" + + def __init__(self, cursor: AsyncCursor[Any]): + self.cursor = cursor + self.connection = cursor.connection + self._pgconn = self.connection.pgconn + + async def write(self, data: Buffer) -> None: + if len(data) <= MAX_BUFFER_SIZE: + # Most used path: we don't need to split the buffer in smaller + # bits, so don't make a copy. + await self.connection.wait(copy_to(self._pgconn, data)) + else: + # Copy a buffer too large in chunks to avoid causing a memory + # error in the libpq, which may cause an infinite loop (#255). + for i in range(0, len(data), MAX_BUFFER_SIZE): + await self.connection.wait( + copy_to(self._pgconn, data[i : i + MAX_BUFFER_SIZE]) + ) + + async def finish(self, exc: BaseException | None = None) -> None: + bmsg: bytes | None + if exc: + msg = f"error from Python: {type(exc).__qualname__} - {exc}" + bmsg = msg.encode(pgconn_encoding(self._pgconn), "replace") + else: + bmsg = None + + try: + res = await self.connection.wait(copy_end(self._pgconn, bmsg)) + # The QueryCanceled is expected if we sent an exception message to + # pgconn.put_copy_end(). The Python exception that generated that + # cancelling is more important, so don't clobber it. + except e.QueryCanceled: + if not bmsg: + raise + else: + self.cursor._results = [res] + + +class AsyncQueuedLibpqWriter(AsyncLibpqWriter): + """ + `AsyncWriter` using a buffer to queue data to write. + + `write()` returns immediately, so that the main thread can be CPU-bound + formatting messages, while a worker thread can be IO-bound waiting to write + on the connection. + """ + + __module__ = "psycopg.copy" + + def __init__(self, cursor: AsyncCursor[Any]): + super().__init__(cursor) + + self._queue: AQueue[Buffer] = AQueue(maxsize=QUEUE_SIZE) + self._worker: AWorker | None = None + self._worker_error: BaseException | None = None + + async def worker(self) -> None: + """Push data to the server when available from the copy queue. + + Terminate reading when the queue receives a false-y value, or in case + of error. + + The function is designed to be run in a separate task. + """ + try: + while True: + data = await self._queue.get() + if not data: + break + await self.connection.wait(copy_to(self._pgconn, data)) + except BaseException as ex: + # Propagate the error to the main thread. + self._worker_error = ex + + async def write(self, data: Buffer) -> None: + if not self._worker: + # warning: reference loop, broken by _write_end + self._worker = aspawn(self.worker) + + # If the worker thread raies an exception, re-raise it to the caller. + if self._worker_error: + raise self._worker_error + + if len(data) <= MAX_BUFFER_SIZE: + # Most used path: we don't need to split the buffer in smaller + # bits, so don't make a copy. + await self._queue.put(data) + else: + # Copy a buffer too large in chunks to avoid causing a memory + # error in the libpq, which may cause an infinite loop (#255). + for i in range(0, len(data), MAX_BUFFER_SIZE): + await self._queue.put(data[i : i + MAX_BUFFER_SIZE]) + + async def finish(self, exc: BaseException | None = None) -> None: + await self._queue.put(b"") + + if self._worker: + await agather(self._worker) + self._worker = None # break reference loops if any + + # Check if the worker thread raised any exception before terminating. 
+        if self._worker_error:
+            raise self._worker_error
+
+        await super().finish(exc)
diff --git a/psycopg/psycopg/_copy_base.py b/psycopg/psycopg/_copy_base.py
new file mode 100644
index 000000000..8f2a8f72b
--- /dev/null
+++ b/psycopg/psycopg/_copy_base.py
@@ -0,0 +1,458 @@
+"""
+psycopg copy support
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from __future__ import annotations
+
+import re
+import struct
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Generic, List, Match
+from typing import Optional, Sequence, Tuple, TypeVar, Union, TYPE_CHECKING
+
+from . import pq
+from . import adapt
+from . import errors as e
+from .abc import Buffer, ConnectionType, PQGen, Transformer
+from .pq.misc import connection_summary
+from ._cmodule import _psycopg
+from ._encodings import pgconn_encoding
+from .generators import copy_from
+
+if TYPE_CHECKING:
+    from ._cursor_base import BaseCursor
+    from .connection import Connection  # noqa: F401
+    from .connection_async import AsyncConnection  # noqa: F401
+
+PY_TEXT = adapt.PyFormat.TEXT
+PY_BINARY = adapt.PyFormat.BINARY
+
+TEXT = pq.Format.TEXT
+BINARY = pq.Format.BINARY
+
+COPY_IN = pq.ExecStatus.COPY_IN
+COPY_OUT = pq.ExecStatus.COPY_OUT
+
+ACTIVE = pq.TransactionStatus.ACTIVE
+
+# Size of data to accumulate before sending it down the network. We fill a
+# buffer this size field by field, and when it passes the threshold size
+# we ship it, so it may end up being bigger than this.
+BUFFER_SIZE = 32 * 1024
+
+# Maximum data size we want to queue to send to the libpq copy. Sending a
+# buffer too big to be handled can cause an infinite loop in the libpq
+# (#255) so we want to split it in more digestible chunks.
+MAX_BUFFER_SIZE = 4 * BUFFER_SIZE
+# Note: making this buffer too large, e.g.
+# MAX_BUFFER_SIZE = 1024 * 1024
+# makes operations *way* slower! Probably triggering some quadraticity
+# in the libpq memory management and data sending.
+
+# Max size of the write queue of buffers. More than that, the copy will block.
+# Each buffer should be around BUFFER_SIZE size.
+QUEUE_SIZE = 1024
+
+
+class BaseCopy(Generic[ConnectionType]):
+    """
+    Base implementation for the copy user interface.
+
+    Two subclasses expose real methods with the sync/async differences.
+
+    The difference between the text and binary format is managed by two
+    different `Formatter` subclasses.
+
+    Writing (the I/O part) is implemented in the subclasses by a `Writer` or
+    `AsyncWriter` instance. Normally writing implies sending copy data to a
+    database, but a different writer might be chosen, e.g. to stream data into
+    a file for later use.
+ """ + + _Self = TypeVar("_Self", bound="BaseCopy[Any]") + + formatter: Formatter + + def __init__( + self, + cursor: "BaseCursor[ConnectionType, Any]", + *, + binary: Optional[bool] = None, + ): + self.cursor = cursor + self.connection = cursor.connection + self._pgconn = self.connection.pgconn + + result = cursor.pgresult + if result: + self._direction = result.status + if self._direction != COPY_IN and self._direction != COPY_OUT: + raise e.ProgrammingError( + "the cursor should have performed a COPY operation;" + f" its status is {pq.ExecStatus(self._direction).name} instead" + ) + else: + self._direction = COPY_IN + + if binary is None: + binary = bool(result and result.binary_tuples) + + tx: Transformer = getattr(cursor, "_tx", None) or adapt.Transformer(cursor) + if binary: + self.formatter = BinaryFormatter(tx) + else: + self.formatter = TextFormatter(tx, encoding=pgconn_encoding(self._pgconn)) + + self._finished = False + + def __repr__(self) -> str: + cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}" + info = connection_summary(self._pgconn) + return f"<{cls} {info} at 0x{id(self):x}>" + + def _enter(self) -> None: + if self._finished: + raise TypeError("copy blocks can be used only once") + + def set_types(self, types: Sequence[Union[int, str]]) -> None: + """ + Set the types expected in a COPY operation. + + The types must be specified as a sequence of oid or PostgreSQL type + names (e.g. ``int4``, ``timestamptz[]``). + + This operation overcomes the lack of metadata returned by PostgreSQL + when a COPY operation begins: + + - On :sql:`COPY TO`, `!set_types()` allows to specify what types the + operation returns. If `!set_types()` is not used, the data will be + returned as unparsed strings or bytes instead of Python objects. + + - On :sql:`COPY FROM`, `!set_types()` allows to choose what type the + database expects. This is especially useful in binary copy, because + PostgreSQL will apply no cast rule. + + """ + registry = self.cursor.adapters.types + oids = [t if isinstance(t, int) else registry.get_oid(t) for t in types] + + if self._direction == COPY_IN: + self.formatter.transformer.set_dumper_types(oids, self.formatter.format) + else: + self.formatter.transformer.set_loader_types(oids, self.formatter.format) + + # High level copy protocol generators (state change of the Copy object) + + def _read_gen(self) -> PQGen[Buffer]: + if self._finished: + return memoryview(b"") + + res = yield from copy_from(self._pgconn) + if isinstance(res, memoryview): + return res + + # res is the final PGresult + self._finished = True + + # This result is a COMMAND_OK which has info about the number of rows + # returned, but not about the columns, which is instead an information + # that was received on the COPY_OUT result at the beginning of COPY. + # So, don't replace the results in the cursor, just update the rowcount. + nrows = res.command_tuples + self.cursor._rowcount = nrows if nrows is not None else -1 + return memoryview(b"") + + def _read_row_gen(self) -> PQGen[Optional[Tuple[Any, ...]]]: + data = yield from self._read_gen() + if not data: + return None + + row = self.formatter.parse_row(data) + if row is None: + # Get the final result to finish the copy operation + yield from self._read_gen() + self._finished = True + return None + + return row + + def _end_copy_out_gen(self, exc: Optional[BaseException]) -> PQGen[None]: + if not exc: + return + + if self._pgconn.transaction_status != ACTIVE: + # The server has already finished to send copy data. 
+            return
+
+        # Throw a cancel to the server, then consume the rest of the copy data
+        # (which might or might not have been already transferred entirely to
+        # the client, so we won't necessarily see the exception associated with
+        # canceling).
+        self.connection.cancel()
+        try:
+            while (yield from self._read_gen()):
+                pass
+        except e.QueryCanceled:
+            pass
+
+
+class Formatter(ABC):
+    """
+    A class which understands a copy format (text, binary).
+    """
+
+    format: pq.Format
+
+    def __init__(self, transformer: Transformer):
+        self.transformer = transformer
+        self._write_buffer = bytearray()
+        self._row_mode = False  # true if the user is using write_row()
+
+    @abstractmethod
+    def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]:
+        ...
+
+    @abstractmethod
+    def write(self, buffer: Union[Buffer, str]) -> Buffer:
+        ...
+
+    @abstractmethod
+    def write_row(self, row: Sequence[Any]) -> Buffer:
+        ...
+
+    @abstractmethod
+    def end(self) -> Buffer:
+        ...
+
+
+class TextFormatter(Formatter):
+    format = TEXT
+
+    def __init__(self, transformer: Transformer, encoding: str = "utf-8"):
+        super().__init__(transformer)
+        self._encoding = encoding
+
+    def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]:
+        if data:
+            return parse_row_text(data, self.transformer)
+        else:
+            return None
+
+    def write(self, buffer: Union[Buffer, str]) -> Buffer:
+        data = self._ensure_bytes(buffer)
+        self._signature_sent = True
+        return data
+
+    def write_row(self, row: Sequence[Any]) -> Buffer:
+        # Note down that we are writing in row mode: it means we will have
+        # to take care of the end-of-copy marker too
+        self._row_mode = True
+
+        format_row_text(row, self.transformer, self._write_buffer)
+        if len(self._write_buffer) > BUFFER_SIZE:
+            buffer, self._write_buffer = self._write_buffer, bytearray()
+            return buffer
+        else:
+            return b""
+
+    def end(self) -> Buffer:
+        buffer, self._write_buffer = self._write_buffer, bytearray()
+        return buffer
+
+    def _ensure_bytes(self, data: Union[Buffer, str]) -> Buffer:
+        if isinstance(data, str):
+            return data.encode(self._encoding)
+        else:
+            # Assume, for simplicity, that the user is not passing stupid
+            # things to the write function. If that's the case, things
+            # will fail downstream.
+            return data
+
+
+class BinaryFormatter(Formatter):
+    format = BINARY
+
+    def __init__(self, transformer: Transformer):
+        super().__init__(transformer)
+        self._signature_sent = False
+
+    def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]:
+        if not self._signature_sent:
+            if data[: len(_binary_signature)] != _binary_signature:
+                raise e.DataError(
+                    "binary copy doesn't start with the expected signature"
+                )
+            self._signature_sent = True
+            data = data[len(_binary_signature) :]
+
+        elif data == _binary_trailer:
+            return None
+
+        return parse_row_binary(data, self.transformer)
+
+    def write(self, buffer: Union[Buffer, str]) -> Buffer:
+        data = self._ensure_bytes(buffer)
+        self._signature_sent = True
+        return data
+
+    def write_row(self, row: Sequence[Any]) -> Buffer:
+        # Note down that we are writing in row mode: it means we will have
+        # to take care of the end-of-copy marker too
+        self._row_mode = True
+
+        if not self._signature_sent:
+            self._write_buffer += _binary_signature
+            self._signature_sent = True
+
+        format_row_binary(row, self.transformer, self._write_buffer)
+        if len(self._write_buffer) > BUFFER_SIZE:
+            buffer, self._write_buffer = self._write_buffer, bytearray()
+            return buffer
+        else:
+            return b""
+
+    def end(self) -> Buffer:
+        # If we have sent no data we need to send the signature
+        # and the trailer
+        if not self._signature_sent:
+            self._write_buffer += _binary_signature
+            self._write_buffer += _binary_trailer
+
+        elif self._row_mode:
+            # if we have sent data already, we have sent the signature
+            # too (either with the first row, or we assume that in
+            # block mode the signature is included).
+            # Write the trailer only if we are sending rows (with the
+            # assumption that whoever is copying binary data is sending the
+            # whole format).
+            self._write_buffer += _binary_trailer
+
+        buffer, self._write_buffer = self._write_buffer, bytearray()
+        return buffer
+
+    def _ensure_bytes(self, data: Union[Buffer, str]) -> Buffer:
+        if isinstance(data, str):
+            raise TypeError("cannot copy str data in binary mode: use bytes instead")
+        else:
+            # Assume, for simplicity, that the user is not passing stupid
+            # things to the write function. If that's the case, things
+            # will fail downstream.
+            return data
+
+
+def _format_row_text(
+    row: Sequence[Any], tx: Transformer, out: Optional[bytearray] = None
+) -> bytearray:
+    """Convert a row of objects to the data to send for copy."""
+    if out is None:
+        out = bytearray()
+
+    if not row:
+        out += b"\n"
+        return out
+
+    adapted = tx.dump_sequence(row, [PY_TEXT] * len(row))
+    for b in adapted:
+        out += _dump_re.sub(_dump_sub, b) if b is not None else rb"\N"
+        out += b"\t"
+
+    out[-1:] = b"\n"
+    return out
+
+
+def _format_row_binary(
+    row: Sequence[Any], tx: Transformer, out: Optional[bytearray] = None
+) -> bytearray:
+    """Convert a row of objects to the data to send for binary copy."""
+    if out is None:
+        out = bytearray()
+
+    out += _pack_int2(len(row))
+    adapted = tx.dump_sequence(row, [PY_BINARY] * len(row))
+    for b in adapted:
+        if b is not None:
+            out += _pack_int4(len(b))
+            out += b
+        else:
+            out += _binary_null
+
+    return out
+
+
+def _parse_row_text(data: Buffer, tx: Transformer) -> Tuple[Any, ...]:
+    if not isinstance(data, bytes):
+        data = bytes(data)
+    fields = data.split(b"\t")
+    fields[-1] = fields[-1][:-1]  # drop \n
+    row = [None if f == b"\\N" else _load_re.sub(_load_sub, f) for f in fields]
+    return tx.load_sequence(row)
+
+
+def _parse_row_binary(data: Buffer, tx: Transformer) -> Tuple[Any, ...]:
+    row: List[Optional[Buffer]] = []
+    nfields = _unpack_int2(data, 0)[0]
+    pos = 2
+    for i in range(nfields):
+        length = _unpack_int4(data, pos)[0]
+        pos += 4
+        if length >= 0:
+            row.append(data[pos : pos + length])
+            pos += length
+        else:
+            row.append(None)
+
+    return tx.load_sequence(row)
+
+
+_pack_int2 = struct.Struct("!h").pack
+_pack_int4 = struct.Struct("!i").pack
+_unpack_int2 = struct.Struct("!h").unpack_from
+_unpack_int4 = struct.Struct("!i").unpack_from
+
+_binary_signature = (
+    b"PGCOPY\n\xff\r\n\0"  # Signature
+    b"\x00\x00\x00\x00"  # flags
+    b"\x00\x00\x00\x00"  # extra length
+)
+_binary_trailer = b"\xff\xff"
+_binary_null = b"\xff\xff\xff\xff"
+
+_dump_re = re.compile(b"[\b\t\n\v\f\r\\\\]")
+_dump_repl = {
+    b"\b": b"\\b",
+    b"\t": b"\\t",
+    b"\n": b"\\n",
+    b"\v": b"\\v",
+    b"\f": b"\\f",
+    b"\r": b"\\r",
+    b"\\": b"\\\\",
+}
+
+
+def _dump_sub(m: Match[bytes], __map: Dict[bytes, bytes] = _dump_repl) -> bytes:
+    return __map[m.group(0)]
+
+
+_load_re = re.compile(b"\\\\[btnvfr\\\\]")
+_load_repl = {v: k for k, v in _dump_repl.items()}
+
+
+def _load_sub(m: Match[bytes], __map: Dict[bytes, bytes] = _load_repl) -> bytes:
+    return __map[m.group(0)]
+
+
+# Override functions with fast versions if available
+if _psycopg:
+    format_row_text = _psycopg.format_row_text
+    format_row_binary = _psycopg.format_row_binary
+    parse_row_text = _psycopg.parse_row_text
+    parse_row_binary = _psycopg.parse_row_binary
+
+else:
+    format_row_text = _format_row_text
+    format_row_binary = _format_row_binary
+    parse_row_text = _parse_row_text
+    parse_row_binary = _parse_row_binary
diff --git a/psycopg/psycopg/copy.py b/psycopg/psycopg/copy.py
index bf54e90be..b43d25ca6 100644
--- a/psycopg/psycopg/copy.py
+++ b/psycopg/psycopg/copy.py
@@ -1,461 +1,23 @@
 """
-psycopg copy support
+Module gathering the various parts of the copy subsystem.
""" -# Copyright (C) 2020 The Psycopg Team +from typing import IO -import re -import queue -import struct -import asyncio -import threading -from abc import ABC, abstractmethod -from types import TracebackType -from typing import Any, AsyncIterator, Dict, Generic, Iterator, List, Match, IO -from typing import Optional, Sequence, Tuple, Type, TypeVar, Union, TYPE_CHECKING +from .abc import Buffer +from . import _copy, _copy_async -from . import pq -from . import adapt -from . import errors as e -from .abc import Buffer, ConnectionType, PQGen, Transformer -from .pq.misc import connection_summary -from ._cmodule import _psycopg -from ._encodings import pgconn_encoding -from .generators import copy_from, copy_to, copy_end +# re-exports -if TYPE_CHECKING: - from .cursor import Cursor - from ._cursor_base import BaseCursor - from .cursor_async import AsyncCursor - from .connection import Connection # noqa: F401 - from .connection_async import AsyncConnection # noqa: F401 +AsyncCopy = _copy_async.AsyncCopy +AsyncWriter = _copy_async.AsyncWriter +AsyncLibpqWriter = _copy_async.AsyncLibpqWriter +AsyncQueuedLibpqWriter = _copy_async.AsyncQueuedLibpqWriter -PY_TEXT = adapt.PyFormat.TEXT -PY_BINARY = adapt.PyFormat.BINARY - -TEXT = pq.Format.TEXT -BINARY = pq.Format.BINARY - -COPY_IN = pq.ExecStatus.COPY_IN -COPY_OUT = pq.ExecStatus.COPY_OUT - -ACTIVE = pq.TransactionStatus.ACTIVE - -# Size of data to accumulate before sending it down the network. We fill a -# buffer this size field by field, and when it passes the threshold size -# we ship it, so it may end up being bigger than this. -BUFFER_SIZE = 32 * 1024 - -# Maximum data size we want to queue to send to the libpq copy. Sending a -# buffer too big to be handled can cause an infinite loop in the libpq -# (#255) so we want to split it in more digestable chunks. -MAX_BUFFER_SIZE = 4 * BUFFER_SIZE -# Note: making this buffer too large, e.g. -# MAX_BUFFER_SIZE = 1024 * 1024 -# makes operations *way* slower! Probably triggering some quadraticity -# in the libpq memory management and data sending. - -# Max size of the write queue of buffers. More than that copy will block -# Each buffer should be around BUFFER_SIZE size. -QUEUE_SIZE = 1024 - - -class BaseCopy(Generic[ConnectionType]): - """ - Base implementation for the copy user interface. - - Two subclasses expose real methods with the sync/async differences. - - The difference between the text and binary format is managed by two - different `Formatter` subclasses. - - Writing (the I/O part) is implemented in the subclasses by a `Writer` or - `AsyncWriter` instance. Normally writing implies sending copy data to a - database, but a different writer might be chosen, e.g. to stream data into - a file for later use. 
- """ - - _Self = TypeVar("_Self", bound="BaseCopy[Any]") - - formatter: "Formatter" - - def __init__( - self, - cursor: "BaseCursor[ConnectionType, Any]", - *, - binary: Optional[bool] = None, - ): - self.cursor = cursor - self.connection = cursor.connection - self._pgconn = self.connection.pgconn - - result = cursor.pgresult - if result: - self._direction = result.status - if self._direction != COPY_IN and self._direction != COPY_OUT: - raise e.ProgrammingError( - "the cursor should have performed a COPY operation;" - f" its status is {pq.ExecStatus(self._direction).name} instead" - ) - else: - self._direction = COPY_IN - - if binary is None: - binary = bool(result and result.binary_tuples) - - tx: Transformer = getattr(cursor, "_tx", None) or adapt.Transformer(cursor) - if binary: - self.formatter = BinaryFormatter(tx) - else: - self.formatter = TextFormatter(tx, encoding=pgconn_encoding(self._pgconn)) - - self._finished = False - - def __repr__(self) -> str: - cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}" - info = connection_summary(self._pgconn) - return f"<{cls} {info} at 0x{id(self):x}>" - - def _enter(self) -> None: - if self._finished: - raise TypeError("copy blocks can be used only once") - - def set_types(self, types: Sequence[Union[int, str]]) -> None: - """ - Set the types expected in a COPY operation. - - The types must be specified as a sequence of oid or PostgreSQL type - names (e.g. ``int4``, ``timestamptz[]``). - - This operation overcomes the lack of metadata returned by PostgreSQL - when a COPY operation begins: - - - On :sql:`COPY TO`, `!set_types()` allows to specify what types the - operation returns. If `!set_types()` is not used, the data will be - returned as unparsed strings or bytes instead of Python objects. - - - On :sql:`COPY FROM`, `!set_types()` allows to choose what type the - database expects. This is especially useful in binary copy, because - PostgreSQL will apply no cast rule. - - """ - registry = self.cursor.adapters.types - oids = [t if isinstance(t, int) else registry.get_oid(t) for t in types] - - if self._direction == COPY_IN: - self.formatter.transformer.set_dumper_types(oids, self.formatter.format) - else: - self.formatter.transformer.set_loader_types(oids, self.formatter.format) - - # High level copy protocol generators (state change of the Copy object) - - def _read_gen(self) -> PQGen[Buffer]: - if self._finished: - return memoryview(b"") - - res = yield from copy_from(self._pgconn) - if isinstance(res, memoryview): - return res - - # res is the final PGresult - self._finished = True - - # This result is a COMMAND_OK which has info about the number of rows - # returned, but not about the columns, which is instead an information - # that was received on the COPY_OUT result at the beginning of COPY. - # So, don't replace the results in the cursor, just update the rowcount. - nrows = res.command_tuples - self.cursor._rowcount = nrows if nrows is not None else -1 - return memoryview(b"") - - def _read_row_gen(self) -> PQGen[Optional[Tuple[Any, ...]]]: - data = yield from self._read_gen() - if not data: - return None - - row = self.formatter.parse_row(data) - if row is None: - # Get the final result to finish the copy operation - yield from self._read_gen() - self._finished = True - return None - - return row - - def _end_copy_out_gen(self, exc: Optional[BaseException]) -> PQGen[None]: - if not exc: - return - - if self._pgconn.transaction_status != ACTIVE: - # The server has already finished to send copy data. 
The connection - # is already in a good state. - return - - # Throw a cancel to the server, then consume the rest of the copy data - # (which might or might not have been already transferred entirely to - # the client, so we won't necessary see the exception associated with - # canceling). - self.connection.cancel() - try: - while (yield from self._read_gen()): - pass - except e.QueryCanceled: - pass - - -class Copy(BaseCopy["Connection[Any]"]): - """Manage a :sql:`COPY` operation. - - :param cursor: the cursor where the operation is performed. - :param binary: if `!True`, write binary format. - :param writer: the object to write to destination. If not specified, write - to the `!cursor` connection. - - Choosing `!binary` is not necessary if the cursor has executed a - :sql:`COPY` operation, because the operation result describes the format - too. The parameter is useful when a `!Copy` object is created manually and - no operation is performed on the cursor, such as when using ``writer=``\\ - `~psycopg.copy.FileWriter`. - - """ - - __module__ = "psycopg" - - writer: "Writer" - - def __init__( - self, - cursor: "Cursor[Any]", - *, - binary: Optional[bool] = None, - writer: Optional["Writer"] = None, - ): - super().__init__(cursor, binary=binary) - if not writer: - writer = LibpqWriter(cursor) - - self.writer = writer - self._write = writer.write - - def __enter__(self: BaseCopy._Self) -> BaseCopy._Self: - self._enter() - return self - - def __exit__( - self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], - ) -> None: - self.finish(exc_val) - - # End user sync interface - - def __iter__(self) -> Iterator[Buffer]: - """Implement block-by-block iteration on :sql:`COPY TO`.""" - while True: - data = self.read() - if not data: - break - yield data - - def read(self) -> Buffer: - """ - Read an unparsed row after a :sql:`COPY TO` operation. - - Return an empty string when the data is finished. - """ - return self.connection.wait(self._read_gen()) - - def rows(self) -> Iterator[Tuple[Any, ...]]: - """ - Iterate on the result of a :sql:`COPY TO` operation record by record. - - Note that the records returned will be tuples of unparsed strings or - bytes, unless data types are specified using `set_types()`. - """ - while True: - record = self.read_row() - if record is None: - break - yield record - - def read_row(self) -> Optional[Tuple[Any, ...]]: - """ - Read a parsed row of data from a table after a :sql:`COPY TO` operation. - - Return `!None` when the data is finished. - - Note that the records returned will be tuples of unparsed strings or - bytes, unless data types are specified using `set_types()`. - """ - return self.connection.wait(self._read_row_gen()) - - def write(self, buffer: Union[Buffer, str]) -> None: - """ - Write a block of data to a table after a :sql:`COPY FROM` operation. - - If the :sql:`COPY` is in binary format `!buffer` must be `!bytes`. In - text mode it can be either `!bytes` or `!str`. - """ - data = self.formatter.write(buffer) - if data: - self._write(data) - - def write_row(self, row: Sequence[Any]) -> None: - """Write a record to a table after a :sql:`COPY FROM` operation.""" - data = self.formatter.write_row(row) - if data: - self._write(data) - - def finish(self, exc: Optional[BaseException]) -> None: - """Terminate the copy operation and free the resources allocated. - - You shouldn't need to call this function yourself: it is usually called - by exit. 
It is available if, despite what is documented, you end up - using the `Copy` object outside a block. - """ - if self._direction == COPY_IN: - data = self.formatter.end() - if data: - self._write(data) - self.writer.finish(exc) - self._finished = True - else: - self.connection.wait(self._end_copy_out_gen(exc)) - - -class Writer(ABC): - """ - A class to write copy data somewhere. - """ - - @abstractmethod - def write(self, data: Buffer) -> None: - """ - Write some data to destination. - """ - ... - - def finish(self, exc: Optional[BaseException] = None) -> None: - """ - Called when write operations are finished. - - If operations finished with an error, it will be passed to ``exc``. - """ - pass - - -class LibpqWriter(Writer): - """ - A `Writer` to write copy data to a Postgres database. - """ - - def __init__(self, cursor: "Cursor[Any]"): - self.cursor = cursor - self.connection = cursor.connection - self._pgconn = self.connection.pgconn - - def write(self, data: Buffer) -> None: - if len(data) <= MAX_BUFFER_SIZE: - # Most used path: we don't need to split the buffer in smaller - # bits, so don't make a copy. - self.connection.wait(copy_to(self._pgconn, data)) - else: - # Copy a buffer too large in chunks to avoid causing a memory - # error in the libpq, which may cause an infinite loop (#255). - for i in range(0, len(data), MAX_BUFFER_SIZE): - self.connection.wait( - copy_to(self._pgconn, data[i : i + MAX_BUFFER_SIZE]) - ) - - def finish(self, exc: Optional[BaseException] = None) -> None: - bmsg: Optional[bytes] - if exc: - msg = f"error from Python: {type(exc).__qualname__} - {exc}" - bmsg = msg.encode(pgconn_encoding(self._pgconn), "replace") - else: - bmsg = None - - try: - res = self.connection.wait(copy_end(self._pgconn, bmsg)) - # The QueryCanceled is expected if we sent an exception message to - # pgconn.put_copy_end(). The Python exception that generated that - # cancelling is more important, so don't clobber it. - except e.QueryCanceled: - if not bmsg: - raise - else: - self.cursor._results = [res] - - -class QueuedLibpqWriter(LibpqWriter): - """ - A writer using a buffer to queue data to write to a Postgres database. - - `write()` returns immediately, so that the main thread can be CPU-bound - formatting messages, while a worker thread can be IO-bound waiting to write - on the connection. - """ - - def __init__(self, cursor: "Cursor[Any]"): - super().__init__(cursor) - - self._queue: queue.Queue[Buffer] = queue.Queue(maxsize=QUEUE_SIZE) - self._worker: Optional[threading.Thread] = None - self._worker_error: Optional[BaseException] = None - - def worker(self) -> None: - """Push data to the server when available from the copy queue. - - Terminate reading when the queue receives a false-y value, or in case - of error. - - The function is designed to be run in a separate thread. - """ - try: - while True: - data = self._queue.get(block=True, timeout=24 * 60 * 60) - if not data: - break - self.connection.wait(copy_to(self._pgconn, data)) - except BaseException as ex: - # Propagate the error to the main thread. - self._worker_error = ex - - def write(self, data: Buffer) -> None: - if not self._worker: - # warning: reference loop, broken by _write_end - self._worker = threading.Thread(target=self.worker) - self._worker.daemon = True - self._worker.start() - - # If the worker thread raies an exception, re-raise it to the caller. 
- if self._worker_error: - raise self._worker_error - - if len(data) <= MAX_BUFFER_SIZE: - # Most used path: we don't need to split the buffer in smaller - # bits, so don't make a copy. - self._queue.put(data) - else: - # Copy a buffer too large in chunks to avoid causing a memory - # error in the libpq, which may cause an infinite loop (#255). - for i in range(0, len(data), MAX_BUFFER_SIZE): - self._queue.put(data[i : i + MAX_BUFFER_SIZE]) - - def finish(self, exc: Optional[BaseException] = None) -> None: - self._queue.put(b"") - - if self._worker: - self._worker.join() - self._worker = None # break the loop - - # Check if the worker thread raised any exception before terminating. - if self._worker_error: - raise self._worker_error - - super().finish(exc) +Copy = _copy.Copy +Writer = _copy.Writer +LibpqWriter = _copy.LibpqWriter +QueuedLibpqWriter = _copy.QueuedLibpqWriter class FileWriter(Writer): @@ -471,445 +33,3 @@ class FileWriter(Writer): def write(self, data: Buffer) -> None: self.file.write(data) - - -class AsyncCopy(BaseCopy["AsyncConnection[Any]"]): - """Manage an asynchronous :sql:`COPY` operation.""" - - __module__ = "psycopg" - - writer: "AsyncWriter" - - def __init__( - self, - cursor: "AsyncCursor[Any]", - *, - binary: Optional[bool] = None, - writer: Optional["AsyncWriter"] = None, - ): - super().__init__(cursor, binary=binary) - - if not writer: - writer = AsyncLibpqWriter(cursor) - - self.writer = writer - self._write = writer.write - - async def __aenter__(self: BaseCopy._Self) -> BaseCopy._Self: - self._enter() - return self - - async def __aexit__( - self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], - ) -> None: - await self.finish(exc_val) - - async def __aiter__(self) -> AsyncIterator[Buffer]: - while True: - data = await self.read() - if not data: - break - yield data - - async def read(self) -> Buffer: - return await self.connection.wait(self._read_gen()) - - async def rows(self) -> AsyncIterator[Tuple[Any, ...]]: - while True: - record = await self.read_row() - if record is None: - break - yield record - - async def read_row(self) -> Optional[Tuple[Any, ...]]: - return await self.connection.wait(self._read_row_gen()) - - async def write(self, buffer: Union[Buffer, str]) -> None: - data = self.formatter.write(buffer) - if data: - await self._write(data) - - async def write_row(self, row: Sequence[Any]) -> None: - data = self.formatter.write_row(row) - if data: - await self._write(data) - - async def finish(self, exc: Optional[BaseException]) -> None: - if self._direction == COPY_IN: - data = self.formatter.end() - if data: - await self._write(data) - await self.writer.finish(exc) - self._finished = True - else: - await self.connection.wait(self._end_copy_out_gen(exc)) - - -class AsyncWriter(ABC): - """ - A class to write copy data somewhere (for async connections). - """ - - @abstractmethod - async def write(self, data: Buffer) -> None: - ... - - async def finish(self, exc: Optional[BaseException] = None) -> None: - pass - - -class AsyncLibpqWriter(AsyncWriter): - """ - An `AsyncWriter` to write copy data to a Postgres database. - """ - - def __init__(self, cursor: "AsyncCursor[Any]"): - self.cursor = cursor - self.connection = cursor.connection - self._pgconn = self.connection.pgconn - - async def write(self, data: Buffer) -> None: - if len(data) <= MAX_BUFFER_SIZE: - # Most used path: we don't need to split the buffer in smaller - # bits, so don't make a copy. 
- await self.connection.wait(copy_to(self._pgconn, data)) - else: - # Copy a buffer too large in chunks to avoid causing a memory - # error in the libpq, which may cause an infinite loop (#255). - for i in range(0, len(data), MAX_BUFFER_SIZE): - await self.connection.wait( - copy_to(self._pgconn, data[i : i + MAX_BUFFER_SIZE]) - ) - - async def finish(self, exc: Optional[BaseException] = None) -> None: - bmsg: Optional[bytes] - if exc: - msg = f"error from Python: {type(exc).__qualname__} - {exc}" - bmsg = msg.encode(pgconn_encoding(self._pgconn), "replace") - else: - bmsg = None - - try: - res = await self.connection.wait(copy_end(self._pgconn, bmsg)) - # The QueryCanceled is expected if we sent an exception message to - # pgconn.put_copy_end(). The Python exception that generated that - # cancelling is more important, so don't clobber it. - except e.QueryCanceled: - if not bmsg: - raise - else: - self.cursor._results = [res] - - -class AsyncQueuedLibpqWriter(AsyncLibpqWriter): - """ - An `AsyncWriter` using a buffer to queue data to write. - - `write()` returns immediately, so that the main thread can be CPU-bound - formatting messages, while a worker thread can be IO-bound waiting to write - on the connection. - """ - - def __init__(self, cursor: "AsyncCursor[Any]"): - super().__init__(cursor) - - self._queue: asyncio.Queue[Buffer] = asyncio.Queue(maxsize=QUEUE_SIZE) - self._worker: Optional[asyncio.Future[None]] = None - - async def worker(self) -> None: - """Push data to the server when available from the copy queue. - - Terminate reading when the queue receives a false-y value. - - The function is designed to be run in a separate task. - """ - while True: - data = await self._queue.get() - if not data: - break - await self.connection.wait(copy_to(self._pgconn, data)) - - async def write(self, data: Buffer) -> None: - if not self._worker: - self._worker = asyncio.create_task(self.worker()) - - if len(data) <= MAX_BUFFER_SIZE: - # Most used path: we don't need to split the buffer in smaller - # bits, so don't make a copy. - await self._queue.put(data) - else: - # Copy a buffer too large in chunks to avoid causing a memory - # error in the libpq, which may cause an infinite loop (#255). - for i in range(0, len(data), MAX_BUFFER_SIZE): - await self._queue.put(data[i : i + MAX_BUFFER_SIZE]) - - async def finish(self, exc: Optional[BaseException] = None) -> None: - await self._queue.put(b"") - - if self._worker: - await asyncio.gather(self._worker) - self._worker = None # break reference loops if any - - await super().finish(exc) - - -class Formatter(ABC): - """ - A class which understand a copy format (text, binary). - """ - - format: pq.Format - - def __init__(self, transformer: Transformer): - self.transformer = transformer - self._write_buffer = bytearray() - self._row_mode = False # true if the user is using write_row() - - @abstractmethod - def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]: - ... - - @abstractmethod - def write(self, buffer: Union[Buffer, str]) -> Buffer: - ... - - @abstractmethod - def write_row(self, row: Sequence[Any]) -> Buffer: - ... - - @abstractmethod - def end(self) -> Buffer: - ... 
- - -class TextFormatter(Formatter): - format = TEXT - - def __init__(self, transformer: Transformer, encoding: str = "utf-8"): - super().__init__(transformer) - self._encoding = encoding - - def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]: - if data: - return parse_row_text(data, self.transformer) - else: - return None - - def write(self, buffer: Union[Buffer, str]) -> Buffer: - data = self._ensure_bytes(buffer) - self._signature_sent = True - return data - - def write_row(self, row: Sequence[Any]) -> Buffer: - # Note down that we are writing in row mode: it means we will have - # to take care of the end-of-copy marker too - self._row_mode = True - - format_row_text(row, self.transformer, self._write_buffer) - if len(self._write_buffer) > BUFFER_SIZE: - buffer, self._write_buffer = self._write_buffer, bytearray() - return buffer - else: - return b"" - - def end(self) -> Buffer: - buffer, self._write_buffer = self._write_buffer, bytearray() - return buffer - - def _ensure_bytes(self, data: Union[Buffer, str]) -> Buffer: - if isinstance(data, str): - return data.encode(self._encoding) - else: - # Assume, for simplicity, that the user is not passing stupid - # things to the write function. If that's the case, things - # will fail downstream. - return data - - -class BinaryFormatter(Formatter): - format = BINARY - - def __init__(self, transformer: Transformer): - super().__init__(transformer) - self._signature_sent = False - - def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]: - if not self._signature_sent: - if data[: len(_binary_signature)] != _binary_signature: - raise e.DataError( - "binary copy doesn't start with the expected signature" - ) - self._signature_sent = True - data = data[len(_binary_signature) :] - - elif data == _binary_trailer: - return None - - return parse_row_binary(data, self.transformer) - - def write(self, buffer: Union[Buffer, str]) -> Buffer: - data = self._ensure_bytes(buffer) - self._signature_sent = True - return data - - def write_row(self, row: Sequence[Any]) -> Buffer: - # Note down that we are writing in row mode: it means we will have - # to take care of the end-of-copy marker too - self._row_mode = True - - if not self._signature_sent: - self._write_buffer += _binary_signature - self._signature_sent = True - - format_row_binary(row, self.transformer, self._write_buffer) - if len(self._write_buffer) > BUFFER_SIZE: - buffer, self._write_buffer = self._write_buffer, bytearray() - return buffer - else: - return b"" - - def end(self) -> Buffer: - # If we have sent no data we need to send the signature - # and the trailer - if not self._signature_sent: - self._write_buffer += _binary_signature - self._write_buffer += _binary_trailer - - elif self._row_mode: - # if we have sent data already, we have sent the signature - # too (either with the first row, or we assume that in - # block mode the signature is included). - # Write the trailer only if we are sending rows (with the - # assumption that who is copying binary data is sending the - # whole format). - self._write_buffer += _binary_trailer - - buffer, self._write_buffer = self._write_buffer, bytearray() - return buffer - - def _ensure_bytes(self, data: Union[Buffer, str]) -> Buffer: - if isinstance(data, str): - raise TypeError("cannot copy str data in binary mode: use bytes instead") - else: - # Assume, for simplicity, that the user is not passing stupid - # things to the write function. If that's the case, things - # will fail downstream. 
- return data - - -def _format_row_text( - row: Sequence[Any], tx: Transformer, out: Optional[bytearray] = None -) -> bytearray: - """Convert a row of objects to the data to send for copy.""" - if out is None: - out = bytearray() - - if not row: - out += b"\n" - return out - - adapted = tx.dump_sequence(row, [PY_TEXT] * len(row)) - for b in adapted: - out += _dump_re.sub(_dump_sub, b) if b is not None else rb"\N" - out += b"\t" - - out[-1:] = b"\n" - return out - - -def _format_row_binary( - row: Sequence[Any], tx: Transformer, out: Optional[bytearray] = None -) -> bytearray: - """Convert a row of objects to the data to send for binary copy.""" - if out is None: - out = bytearray() - - out += _pack_int2(len(row)) - adapted = tx.dump_sequence(row, [PY_BINARY] * len(row)) - for b in adapted: - if b is not None: - out += _pack_int4(len(b)) - out += b - else: - out += _binary_null - - return out - - -def _parse_row_text(data: Buffer, tx: Transformer) -> Tuple[Any, ...]: - if not isinstance(data, bytes): - data = bytes(data) - fields = data.split(b"\t") - fields[-1] = fields[-1][:-1] # drop \n - row = [None if f == b"\\N" else _load_re.sub(_load_sub, f) for f in fields] - return tx.load_sequence(row) - - -def _parse_row_binary(data: Buffer, tx: Transformer) -> Tuple[Any, ...]: - row: List[Optional[Buffer]] = [] - nfields = _unpack_int2(data, 0)[0] - pos = 2 - for i in range(nfields): - length = _unpack_int4(data, pos)[0] - pos += 4 - if length >= 0: - row.append(data[pos : pos + length]) - pos += length - else: - row.append(None) - - return tx.load_sequence(row) - - -_pack_int2 = struct.Struct("!h").pack -_pack_int4 = struct.Struct("!i").pack -_unpack_int2 = struct.Struct("!h").unpack_from -_unpack_int4 = struct.Struct("!i").unpack_from - -_binary_signature = ( - b"PGCOPY\n\xff\r\n\0" # Signature - b"\x00\x00\x00\x00" # flags - b"\x00\x00\x00\x00" # extra length -) -_binary_trailer = b"\xff\xff" -_binary_null = b"\xff\xff\xff\xff" - -_dump_re = re.compile(b"[\b\t\n\v\f\r\\\\]") -_dump_repl = { - b"\b": b"\\b", - b"\t": b"\\t", - b"\n": b"\\n", - b"\v": b"\\v", - b"\f": b"\\f", - b"\r": b"\\r", - b"\\": b"\\\\", -} - - -def _dump_sub(m: Match[bytes], __map: Dict[bytes, bytes] = _dump_repl) -> bytes: - return __map[m.group(0)] - - -_load_re = re.compile(b"\\\\[btnvfr\\\\]") -_load_repl = {v: k for k, v in _dump_repl.items()} - - -def _load_sub(m: Match[bytes], __map: Dict[bytes, bytes] = _load_repl) -> bytes: - return __map[m.group(0)] - - -# Override functions with fast versions if available -if _psycopg: - format_row_text = _psycopg.format_row_text - format_row_binary = _psycopg.format_row_binary - parse_row_text = _psycopg.parse_row_text - parse_row_binary = _psycopg.parse_row_binary - -else: - format_row_text = _format_row_text - format_row_binary = _format_row_binary - parse_row_text = _parse_row_text - parse_row_binary = _parse_row_binary diff --git a/tests/test_copy.py b/tests/test_copy.py index 5ab25a862..4b3e182e6 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -645,7 +645,7 @@ def test_worker_error_propagated(conn, monkeypatch): raise ZeroDivisionError yield - monkeypatch.setattr(psycopg.copy, "copy_to", copy_to_broken) + monkeypatch.setattr(psycopg._copy, "copy_to", copy_to_broken) cur = conn.cursor() cur.execute("create temp table wat (a text, b text)") with pytest.raises(ZeroDivisionError): diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py index c62d9f2d9..68c21d8ef 100644 --- a/tests/test_copy_async.py +++ b/tests/test_copy_async.py @@ -660,7 +660,7 @@ 
async def test_worker_error_propagated(aconn, monkeypatch):
         raise ZeroDivisionError
         yield
 
-    monkeypatch.setattr(psycopg.copy, "copy_to", copy_to_broken)
+    monkeypatch.setattr(psycopg._copy_async, "copy_to", copy_to_broken)
     cur = aconn.cursor()
     await cur.execute("create temp table wat (a text, b text)")
     with pytest.raises(ZeroDivisionError):
diff --git a/tools/async_to_sync.py b/tools/async_to_sync.py
index 6b028d53e..28f98a783 100755
--- a/tools/async_to_sync.py
+++ b/tools/async_to_sync.py
@@ -182,6 +182,7 @@ class RenameAsyncToSync(ast.NodeTransformer):
         "__aenter__": "__enter__",
         "__aexit__": "__exit__",
         "__aiter__": "__iter__",
+        "_copy_async": "_copy",
         "aclose": "close",
         "aclosing": "closing",
         "acommands": "commands",
@@ -289,12 +290,18 @@ class RenameAsyncToSync(ast.NodeTransformer):
         for base in node.bases:
             if not isinstance(base, ast.Subscript):
                 continue
-            if not isinstance(base.slice, ast.Tuple):
-                continue
-            for elt in base.slice.elts:
-                if not (isinstance(elt, ast.Constant) and isinstance(elt.value, str)):
+
+            if isinstance(base.slice, ast.Constant):
+                if not isinstance(base.slice.value, str):
                     continue
-                elt.value = self._visit_type_string(elt.value)
+                base.slice.value = self._visit_type_string(base.slice.value)
+            elif isinstance(base.slice, ast.Tuple):
+                for elt in base.slice.elts:
+                    if not (
+                        isinstance(elt, ast.Constant) and isinstance(elt.value, str)
+                    ):
+                        continue
+                    elt.value = self._visit_type_string(elt.value)
 
         return node
diff --git a/tools/convert_async_to_sync.sh b/tools/convert_async_to_sync.sh
index 68732131f..5932dd7cb 100755
--- a/tools/convert_async_to_sync.sh
+++ b/tools/convert_async_to_sync.sh
@@ -27,6 +27,7 @@ if [[ ${1:-} == '--check' ]]; then
 fi
 
 all_inputs="
+    psycopg/psycopg/_copy_async.py
     psycopg/psycopg/connection_async.py
     psycopg/psycopg/cursor_async.py
     psycopg_pool/psycopg_pool/null_pool_async.py
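A note on the conversion this commit relies on: code is written once against the
async primitives in psycopg/psycopg/_acompat.py, and tools/async_to_sync.py is
meant to rewrite the async names to their sync counterparts (aspawn -> spawn,
agather -> gather, AQueue -> Queue, AWorker -> Worker) while dropping the
async/await keywords. What follows is a minimal, hypothetical sketch of the
pattern, not part of the patch; the function names are made up for the example:

import asyncio

from psycopg._acompat import AQueue, agather, aspawn


async def aconsume(q: "AQueue[bytes]") -> None:
    # Terminate on a false-y value, like the queued copy writers above.
    while True:
        data = await q.get()
        if not data:
            break
        print(f"would write {len(data)} bytes")


async def amain() -> None:
    q: "AQueue[bytes]" = AQueue()
    worker = aspawn(aconsume, (q,))
    await q.put(b"some copy data")
    await q.put(b"")  # wake up and terminate the worker
    await agather(worker)


asyncio.run(amain())

The generated sync version reads the same line by line, with Queue, spawn and
gather and no await: spawn starts a daemon thread and gather joins it, so the
calling code keeps a single shape across the two implementations.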
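The user-facing interface re-exported from psycopg.copy is unchanged by the
refactoring. For reference, a sketch of the typical usage it serves, assuming a
reachable database (the connection string is hypothetical):

import psycopg

with psycopg.connect("dbname=test") as conn, conn.cursor() as cur:
    cur.execute("create temp table data (id int, name text)")

    # COPY FROM: write rows through a Copy object (a LibpqWriter by default)
    with cur.copy("copy data (id, name) from stdin") as copy:
        copy.write_row((1, "foo"))
        copy.write_row((2, "bar"))

    # COPY TO: read rows back; set_types() makes them parsed Python objects
    with cur.copy("copy data (id, name) to stdout") as copy:
        copy.set_types(["int4", "text"])
        for row in copy.rows():
            print(row)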