]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor(copy): generate sync code from async
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 6 Oct 2023 10:58:49 +0000 (12:58 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 11 Oct 2023 21:45:39 +0000 (23:45 +0200)
psycopg/psycopg/_acompat.py [new file with mode: 0644]
psycopg/psycopg/_copy.py [new file with mode: 0644]
psycopg/psycopg/_copy_async.py [new file with mode: 0644]
psycopg/psycopg/_copy_base.py [new file with mode: 0644]
psycopg/psycopg/copy.py
tests/test_copy.py
tests/test_copy_async.py
tools/async_to_sync.py
tools/convert_async_to_sync.sh

diff --git a/psycopg/psycopg/_acompat.py b/psycopg/psycopg/_acompat.py
new file mode 100644 (file)
index 0000000..4915cab
--- /dev/null
@@ -0,0 +1,106 @@
+"""
+Utilities to ease the differences between async and sync code.
+
+These object offer a similar interface between sync and async versions; the
+script async_to_sync.py will replace the async names with the sync names
+when generating the sync version.
+"""
+
+# Copyright (C) 2023 The Psycopg Team
+
+from __future__ import annotations
+
+import queue
+import asyncio
+import threading
+from typing import Any, Callable, Coroutine, TypeVar, TYPE_CHECKING
+
+from typing_extensions import TypeAlias
+
+Worker: TypeAlias = threading.Thread
+AWorker: TypeAlias = "asyncio.Task[None]"
+T = TypeVar("T")
+
+# Hack required on Python 3.8 because subclassing Queue[T] fails at runtime.
+# https://stackoverflow.com/questions/45414066/mypy-how-to-define-a-generic-subclass
+if TYPE_CHECKING:
+    _GQueue: TypeAlias = queue.Queue
+    _AGQueue: TypeAlias = asyncio.Queue
+
+else:
+
+    class FakeGenericMeta(type):
+        def __getitem__(self, item):
+            return self
+
+    class _GQueue(queue.Queue, metaclass=FakeGenericMeta):
+        pass
+
+    class _AGQueue(asyncio.Queue, metaclass=FakeGenericMeta):
+        pass
+
+
+class Queue(_GQueue[T]):
+    """
+    A Queue subclass with an interruptible get() method.
+    """
+
+    def get(self, block: bool = True, timeout: float | None = None) -> T:
+        # Always specify a timeout to make the wait interruptible.
+        if timeout is None:
+            timeout = 24.0 * 60.0 * 60.0
+        return super().get(block=block, timeout=timeout)
+
+
+class AQueue(_AGQueue[T]):
+    pass
+
+
+def aspawn(
+    f: Callable[..., Coroutine[Any, Any, None]],
+    args: tuple[Any, ...] = (),
+    name: str | None = None,
+) -> asyncio.Task[None]:
+    """
+    Equivalent to asyncio.create_task.
+    """
+    return asyncio.create_task(f(*args), name=name)
+
+
+def spawn(
+    f: Callable[..., Any],
+    args: tuple[Any, ...] = (),
+    name: str | None = None,
+) -> threading.Thread:
+    """
+    Equivalent to creating and running a daemon thread.
+    """
+    t = threading.Thread(target=f, args=args, name=name, daemon=True)
+    t.start()
+    return t
+
+
+async def agather(*tasks: asyncio.Task[Any], timeout: float | None = None) -> None:
+    """
+    Equivalent to asyncio.gather or Thread.join()
+    """
+    wait = asyncio.gather(*tasks)
+    try:
+        if timeout is not None:
+            await asyncio.wait_for(asyncio.shield(wait), timeout=timeout)
+        else:
+            await wait
+    except asyncio.TimeoutError:
+        pass
+    else:
+        return
+
+
+def gather(*tasks: threading.Thread, timeout: float | None = None) -> None:
+    """
+    Equivalent to asyncio.gather or Thread.join()
+    """
+    for t in tasks:
+        if not t.is_alive():
+            continue
+        t.join(timeout)
diff --git a/psycopg/psycopg/_copy.py b/psycopg/psycopg/_copy.py
new file mode 100644 (file)
index 0000000..5713389
--- /dev/null
@@ -0,0 +1,283 @@
+# WARNING: this file is auto-generated by 'async_to_sync.py'
+# from the original file '_copy_async.py'
+# DO NOT CHANGE! Change the original file instead.
+"""
+Psycopg Copy and related objects.
+"""
+
+# Copyright (C) 2023 The Psycopg Team
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from types import TracebackType
+from typing import Any, Iterator, Type, Tuple, Sequence, TYPE_CHECKING
+
+from . import pq
+from . import errors as e
+from ._copy_base import BaseCopy, MAX_BUFFER_SIZE, QUEUE_SIZE
+from .generators import copy_to, copy_end
+from ._encodings import pgconn_encoding
+from ._acompat import spawn, gather, Queue, Worker
+
+if TYPE_CHECKING:
+    from .abc import Buffer
+    from .cursor import Cursor
+    from .connection import Connection  # noqa: F401
+
+COPY_IN = pq.ExecStatus.COPY_IN
+COPY_OUT = pq.ExecStatus.COPY_OUT
+
+
+class Copy(BaseCopy["Connection[Any]"]):
+    """Manage an asynchronous :sql:`COPY` operation.
+
+    :param cursor: the cursor where the operation is performed.
+    :param binary: if `!True`, write binary format.
+    :param writer: the object to write to destination. If not specified, write
+        to the `!cursor` connection.
+
+    Choosing `!binary` is not necessary if the cursor has executed a
+    :sql:`COPY` operation, because the operation result describes the format
+    too. The parameter is useful when a `!Copy` object is created manually and
+    no operation is performed on the cursor, such as when using ``writer=``\\
+    `~psycopg.copy.FileWriter`.
+    """
+
+    __module__ = "psycopg"
+
+    writer: Writer
+
+    def __init__(
+        self,
+        cursor: Cursor[Any],
+        *,
+        binary: bool | None = None,
+        writer: Writer | None = None,
+    ):
+        super().__init__(cursor, binary=binary)
+        if not writer:
+            writer = LibpqWriter(cursor)
+
+        self.writer = writer
+        self._write = writer.write
+
+    def __enter__(self: BaseCopy._Self) -> BaseCopy._Self:
+        self._enter()
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Type[BaseException] | None,
+        exc_val: BaseException | None,
+        exc_tb: TracebackType | None,
+    ) -> None:
+        self.finish(exc_val)
+
+    # End user sync interface
+
+    def __iter__(self) -> Iterator[Buffer]:
+        """Implement block-by-block iteration on :sql:`COPY TO`."""
+        while True:
+            data = self.read()
+            if not data:
+                break
+            yield data
+
+    def read(self) -> Buffer:
+        """
+        Read an unparsed row after a :sql:`COPY TO` operation.
+
+        Return an empty string when the data is finished.
+        """
+        return self.connection.wait(self._read_gen())
+
+    def rows(self) -> Iterator[Tuple[Any, ...]]:
+        """
+        Iterate on the result of a :sql:`COPY TO` operation record by record.
+
+        Note that the records returned will be tuples of unparsed strings or
+        bytes, unless data types are specified using `set_types()`.
+        """
+        while True:
+            record = self.read_row()
+            if record is None:
+                break
+            yield record
+
+    def read_row(self) -> Tuple[Any, ...] | None:
+        """
+        Read a parsed row of data from a table after a :sql:`COPY TO` operation.
+
+        Return `!None` when the data is finished.
+
+        Note that the records returned will be tuples of unparsed strings or
+        bytes, unless data types are specified using `set_types()`.
+        """
+        return self.connection.wait(self._read_row_gen())
+
+    def write(self, buffer: Buffer | str) -> None:
+        """
+        Write a block of data to a table after a :sql:`COPY FROM` operation.
+
+        If the :sql:`COPY` is in binary format `!buffer` must be `!bytes`. In
+        text mode it can be either `!bytes` or `!str`.
+        """
+        data = self.formatter.write(buffer)
+        if data:
+            self._write(data)
+
+    def write_row(self, row: Sequence[Any]) -> None:
+        """Write a record to a table after a :sql:`COPY FROM` operation."""
+        data = self.formatter.write_row(row)
+        if data:
+            self._write(data)
+
+    def finish(self, exc: BaseException | None) -> None:
+        """Terminate the copy operation and free the resources allocated.
+
+        You shouldn't need to call this function yourself: it is usually called
+        by exit. It is available if, despite what is documented, you end up
+        using the `Copy` object outside a block.
+        """
+        if self._direction == COPY_IN:
+            data = self.formatter.end()
+            if data:
+                self._write(data)
+            self.writer.finish(exc)
+            self._finished = True
+        else:
+            self.connection.wait(self._end_copy_out_gen(exc))
+
+
+class Writer(ABC):
+    """
+    A class to write copy data somewhere (for async connections).
+    """
+
+    @abstractmethod
+    def write(self, data: Buffer) -> None:
+        """Write some data to destination."""
+        ...
+
+    def finish(self, exc: BaseException | None = None) -> None:
+        """
+        Called when write operations are finished.
+
+        If operations finished with an error, it will be passed to ``exc``.
+        """
+        pass
+
+
+class LibpqWriter(Writer):
+    """
+    An `Writer` to write copy data to a Postgres database.
+    """
+
+    __module__ = "psycopg.copy"
+
+    def __init__(self, cursor: Cursor[Any]):
+        self.cursor = cursor
+        self.connection = cursor.connection
+        self._pgconn = self.connection.pgconn
+
+    def write(self, data: Buffer) -> None:
+        if len(data) <= MAX_BUFFER_SIZE:
+            # Most used path: we don't need to split the buffer in smaller
+            # bits, so don't make a copy.
+            self.connection.wait(copy_to(self._pgconn, data))
+        else:
+            # Copy a buffer too large in chunks to avoid causing a memory
+            # error in the libpq, which may cause an infinite loop (#255).
+            for i in range(0, len(data), MAX_BUFFER_SIZE):
+                self.connection.wait(
+                    copy_to(self._pgconn, data[i : i + MAX_BUFFER_SIZE])
+                )
+
+    def finish(self, exc: BaseException | None = None) -> None:
+        bmsg: bytes | None
+        if exc:
+            msg = f"error from Python: {type(exc).__qualname__} - {exc}"
+            bmsg = msg.encode(pgconn_encoding(self._pgconn), "replace")
+        else:
+            bmsg = None
+
+        try:
+            res = self.connection.wait(copy_end(self._pgconn, bmsg))
+        # The QueryCanceled is expected if we sent an exception message to
+        # pgconn.put_copy_end(). The Python exception that generated that
+        # cancelling is more important, so don't clobber it.
+        except e.QueryCanceled:
+            if not bmsg:
+                raise
+        else:
+            self.cursor._results = [res]
+
+
+class QueuedLibpqWriter(LibpqWriter):
+    """
+    `Writer` using a buffer to queue data to write.
+
+    `write()` returns immediately, so that the main thread can be CPU-bound
+    formatting messages, while a worker thread can be IO-bound waiting to write
+    on the connection.
+    """
+
+    __module__ = "psycopg.copy"
+
+    def __init__(self, cursor: Cursor[Any]):
+        super().__init__(cursor)
+
+        self._queue: Queue[Buffer] = Queue(maxsize=QUEUE_SIZE)
+        self._worker: Worker | None = None
+        self._worker_error: BaseException | None = None
+
+    def worker(self) -> None:
+        """Push data to the server when available from the copy queue.
+
+        Terminate reading when the queue receives a false-y value, or in case
+        of error.
+
+        The function is designed to be run in a separate task.
+        """
+        try:
+            while True:
+                data = self._queue.get()
+                if not data:
+                    break
+                self.connection.wait(copy_to(self._pgconn, data))
+        except BaseException as ex:
+            # Propagate the error to the main thread.
+            self._worker_error = ex
+
+    def write(self, data: Buffer) -> None:
+        if not self._worker:
+            # warning: reference loop, broken by _write_end
+            self._worker = spawn(self.worker)
+
+        # If the worker thread raies an exception, re-raise it to the caller.
+        if self._worker_error:
+            raise self._worker_error
+
+        if len(data) <= MAX_BUFFER_SIZE:
+            # Most used path: we don't need to split the buffer in smaller
+            # bits, so don't make a copy.
+            self._queue.put(data)
+        else:
+            # Copy a buffer too large in chunks to avoid causing a memory
+            # error in the libpq, which may cause an infinite loop (#255).
+            for i in range(0, len(data), MAX_BUFFER_SIZE):
+                self._queue.put(data[i : i + MAX_BUFFER_SIZE])
+
+    def finish(self, exc: BaseException | None = None) -> None:
+        self._queue.put(b"")
+
+        if self._worker:
+            gather(self._worker)
+            self._worker = None  # break reference loops if any
+
+        # Check if the worker thread raised any exception before terminating.
+        if self._worker_error:
+            raise self._worker_error
+
+        super().finish(exc)
diff --git a/psycopg/psycopg/_copy_async.py b/psycopg/psycopg/_copy_async.py
new file mode 100644 (file)
index 0000000..2483353
--- /dev/null
@@ -0,0 +1,280 @@
+"""
+Psycopg AsyncCopy and related objects.
+"""
+
+# Copyright (C) 2023 The Psycopg Team
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from types import TracebackType
+from typing import Any, AsyncIterator, Type, Tuple, Sequence, TYPE_CHECKING
+
+from . import pq
+from . import errors as e
+from ._copy_base import BaseCopy, MAX_BUFFER_SIZE, QUEUE_SIZE
+from .generators import copy_to, copy_end
+from ._encodings import pgconn_encoding
+from ._acompat import aspawn, agather, AQueue, AWorker
+
+if TYPE_CHECKING:
+    from .abc import Buffer
+    from .cursor_async import AsyncCursor
+    from .connection_async import AsyncConnection  # noqa: F401
+
+COPY_IN = pq.ExecStatus.COPY_IN
+COPY_OUT = pq.ExecStatus.COPY_OUT
+
+
+class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
+    """Manage an asynchronous :sql:`COPY` operation.
+
+    :param cursor: the cursor where the operation is performed.
+    :param binary: if `!True`, write binary format.
+    :param writer: the object to write to destination. If not specified, write
+        to the `!cursor` connection.
+
+    Choosing `!binary` is not necessary if the cursor has executed a
+    :sql:`COPY` operation, because the operation result describes the format
+    too. The parameter is useful when a `!Copy` object is created manually and
+    no operation is performed on the cursor, such as when using ``writer=``\\
+    `~psycopg.copy.FileWriter`.
+    """
+
+    __module__ = "psycopg"
+
+    writer: AsyncWriter
+
+    def __init__(
+        self,
+        cursor: AsyncCursor[Any],
+        *,
+        binary: bool | None = None,
+        writer: AsyncWriter | None = None,
+    ):
+        super().__init__(cursor, binary=binary)
+        if not writer:
+            writer = AsyncLibpqWriter(cursor)
+
+        self.writer = writer
+        self._write = writer.write
+
+    async def __aenter__(self: BaseCopy._Self) -> BaseCopy._Self:
+        self._enter()
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: Type[BaseException] | None,
+        exc_val: BaseException | None,
+        exc_tb: TracebackType | None,
+    ) -> None:
+        await self.finish(exc_val)
+
+    # End user sync interface
+
+    async def __aiter__(self) -> AsyncIterator[Buffer]:
+        """Implement block-by-block iteration on :sql:`COPY TO`."""
+        while True:
+            data = await self.read()
+            if not data:
+                break
+            yield data
+
+    async def read(self) -> Buffer:
+        """
+        Read an unparsed row after a :sql:`COPY TO` operation.
+
+        Return an empty string when the data is finished.
+        """
+        return await self.connection.wait(self._read_gen())
+
+    async def rows(self) -> AsyncIterator[Tuple[Any, ...]]:
+        """
+        Iterate on the result of a :sql:`COPY TO` operation record by record.
+
+        Note that the records returned will be tuples of unparsed strings or
+        bytes, unless data types are specified using `set_types()`.
+        """
+        while True:
+            record = await self.read_row()
+            if record is None:
+                break
+            yield record
+
+    async def read_row(self) -> Tuple[Any, ...] | None:
+        """
+        Read a parsed row of data from a table after a :sql:`COPY TO` operation.
+
+        Return `!None` when the data is finished.
+
+        Note that the records returned will be tuples of unparsed strings or
+        bytes, unless data types are specified using `set_types()`.
+        """
+        return await self.connection.wait(self._read_row_gen())
+
+    async def write(self, buffer: Buffer | str) -> None:
+        """
+        Write a block of data to a table after a :sql:`COPY FROM` operation.
+
+        If the :sql:`COPY` is in binary format `!buffer` must be `!bytes`. In
+        text mode it can be either `!bytes` or `!str`.
+        """
+        data = self.formatter.write(buffer)
+        if data:
+            await self._write(data)
+
+    async def write_row(self, row: Sequence[Any]) -> None:
+        """Write a record to a table after a :sql:`COPY FROM` operation."""
+        data = self.formatter.write_row(row)
+        if data:
+            await self._write(data)
+
+    async def finish(self, exc: BaseException | None) -> None:
+        """Terminate the copy operation and free the resources allocated.
+
+        You shouldn't need to call this function yourself: it is usually called
+        by exit. It is available if, despite what is documented, you end up
+        using the `Copy` object outside a block.
+        """
+        if self._direction == COPY_IN:
+            data = self.formatter.end()
+            if data:
+                await self._write(data)
+            await self.writer.finish(exc)
+            self._finished = True
+        else:
+            await self.connection.wait(self._end_copy_out_gen(exc))
+
+
+class AsyncWriter(ABC):
+    """
+    A class to write copy data somewhere (for async connections).
+    """
+
+    @abstractmethod
+    async def write(self, data: Buffer) -> None:
+        """Write some data to destination."""
+        ...
+
+    async def finish(self, exc: BaseException | None = None) -> None:
+        """
+        Called when write operations are finished.
+
+        If operations finished with an error, it will be passed to ``exc``.
+        """
+        pass
+
+
+class AsyncLibpqWriter(AsyncWriter):
+    """
+    An `AsyncWriter` to write copy data to a Postgres database.
+    """
+
+    __module__ = "psycopg.copy"
+
+    def __init__(self, cursor: AsyncCursor[Any]):
+        self.cursor = cursor
+        self.connection = cursor.connection
+        self._pgconn = self.connection.pgconn
+
+    async def write(self, data: Buffer) -> None:
+        if len(data) <= MAX_BUFFER_SIZE:
+            # Most used path: we don't need to split the buffer in smaller
+            # bits, so don't make a copy.
+            await self.connection.wait(copy_to(self._pgconn, data))
+        else:
+            # Copy a buffer too large in chunks to avoid causing a memory
+            # error in the libpq, which may cause an infinite loop (#255).
+            for i in range(0, len(data), MAX_BUFFER_SIZE):
+                await self.connection.wait(
+                    copy_to(self._pgconn, data[i : i + MAX_BUFFER_SIZE])
+                )
+
+    async def finish(self, exc: BaseException | None = None) -> None:
+        bmsg: bytes | None
+        if exc:
+            msg = f"error from Python: {type(exc).__qualname__} - {exc}"
+            bmsg = msg.encode(pgconn_encoding(self._pgconn), "replace")
+        else:
+            bmsg = None
+
+        try:
+            res = await self.connection.wait(copy_end(self._pgconn, bmsg))
+        # The QueryCanceled is expected if we sent an exception message to
+        # pgconn.put_copy_end(). The Python exception that generated that
+        # cancelling is more important, so don't clobber it.
+        except e.QueryCanceled:
+            if not bmsg:
+                raise
+        else:
+            self.cursor._results = [res]
+
+
+class AsyncQueuedLibpqWriter(AsyncLibpqWriter):
+    """
+    `AsyncWriter` using a buffer to queue data to write.
+
+    `write()` returns immediately, so that the main thread can be CPU-bound
+    formatting messages, while a worker thread can be IO-bound waiting to write
+    on the connection.
+    """
+
+    __module__ = "psycopg.copy"
+
+    def __init__(self, cursor: AsyncCursor[Any]):
+        super().__init__(cursor)
+
+        self._queue: AQueue[Buffer] = AQueue(maxsize=QUEUE_SIZE)
+        self._worker: AWorker | None = None
+        self._worker_error: BaseException | None = None
+
+    async def worker(self) -> None:
+        """Push data to the server when available from the copy queue.
+
+        Terminate reading when the queue receives a false-y value, or in case
+        of error.
+
+        The function is designed to be run in a separate task.
+        """
+        try:
+            while True:
+                data = await self._queue.get()
+                if not data:
+                    break
+                await self.connection.wait(copy_to(self._pgconn, data))
+        except BaseException as ex:
+            # Propagate the error to the main thread.
+            self._worker_error = ex
+
+    async def write(self, data: Buffer) -> None:
+        if not self._worker:
+            # warning: reference loop, broken by _write_end
+            self._worker = aspawn(self.worker)
+
+        # If the worker thread raies an exception, re-raise it to the caller.
+        if self._worker_error:
+            raise self._worker_error
+
+        if len(data) <= MAX_BUFFER_SIZE:
+            # Most used path: we don't need to split the buffer in smaller
+            # bits, so don't make a copy.
+            await self._queue.put(data)
+        else:
+            # Copy a buffer too large in chunks to avoid causing a memory
+            # error in the libpq, which may cause an infinite loop (#255).
+            for i in range(0, len(data), MAX_BUFFER_SIZE):
+                await self._queue.put(data[i : i + MAX_BUFFER_SIZE])
+
+    async def finish(self, exc: BaseException | None = None) -> None:
+        await self._queue.put(b"")
+
+        if self._worker:
+            await agather(self._worker)
+            self._worker = None  # break reference loops if any
+
+        # Check if the worker thread raised any exception before terminating.
+        if self._worker_error:
+            raise self._worker_error
+
+        await super().finish(exc)
diff --git a/psycopg/psycopg/_copy_base.py b/psycopg/psycopg/_copy_base.py
new file mode 100644 (file)
index 0000000..8f2a8f7
--- /dev/null
@@ -0,0 +1,458 @@
+"""
+psycopg copy support
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from __future__ import annotations
+
+import re
+import struct
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Generic, List, Match
+from typing import Optional, Sequence, Tuple, TypeVar, Union, TYPE_CHECKING
+
+from . import pq
+from . import adapt
+from . import errors as e
+from .abc import Buffer, ConnectionType, PQGen, Transformer
+from .pq.misc import connection_summary
+from ._cmodule import _psycopg
+from ._encodings import pgconn_encoding
+from .generators import copy_from
+
+if TYPE_CHECKING:
+    from ._cursor_base import BaseCursor
+    from .connection import Connection  # noqa: F401
+    from .connection_async import AsyncConnection  # noqa: F401
+
+PY_TEXT = adapt.PyFormat.TEXT
+PY_BINARY = adapt.PyFormat.BINARY
+
+TEXT = pq.Format.TEXT
+BINARY = pq.Format.BINARY
+
+COPY_IN = pq.ExecStatus.COPY_IN
+COPY_OUT = pq.ExecStatus.COPY_OUT
+
+ACTIVE = pq.TransactionStatus.ACTIVE
+
+# Size of data to accumulate before sending it down the network. We fill a
+# buffer this size field by field, and when it passes the threshold size
+# we ship it, so it may end up being bigger than this.
+BUFFER_SIZE = 32 * 1024
+
+# Maximum data size we want to queue to send to the libpq copy. Sending a
+# buffer too big to be handled can cause an infinite loop in the libpq
+# (#255) so we want to split it in more digestable chunks.
+MAX_BUFFER_SIZE = 4 * BUFFER_SIZE
+# Note: making this buffer too large, e.g.
+# MAX_BUFFER_SIZE = 1024 * 1024
+# makes operations *way* slower! Probably triggering some quadraticity
+# in the libpq memory management and data sending.
+
+# Max size of the write queue of buffers. More than that copy will block
+# Each buffer should be around BUFFER_SIZE size.
+QUEUE_SIZE = 1024
+
+
+class BaseCopy(Generic[ConnectionType]):
+    """
+    Base implementation for the copy user interface.
+
+    Two subclasses expose real methods with the sync/async differences.
+
+    The difference between the text and binary format is managed by two
+    different `Formatter` subclasses.
+
+    Writing (the I/O part) is implemented in the subclasses by a `Writer` or
+    `AsyncWriter` instance. Normally writing implies sending copy data to a
+    database, but a different writer might be chosen, e.g. to stream data into
+    a file for later use.
+    """
+
+    _Self = TypeVar("_Self", bound="BaseCopy[Any]")
+
+    formatter: Formatter
+
+    def __init__(
+        self,
+        cursor: "BaseCursor[ConnectionType, Any]",
+        *,
+        binary: Optional[bool] = None,
+    ):
+        self.cursor = cursor
+        self.connection = cursor.connection
+        self._pgconn = self.connection.pgconn
+
+        result = cursor.pgresult
+        if result:
+            self._direction = result.status
+            if self._direction != COPY_IN and self._direction != COPY_OUT:
+                raise e.ProgrammingError(
+                    "the cursor should have performed a COPY operation;"
+                    f" its status is {pq.ExecStatus(self._direction).name} instead"
+                )
+        else:
+            self._direction = COPY_IN
+
+        if binary is None:
+            binary = bool(result and result.binary_tuples)
+
+        tx: Transformer = getattr(cursor, "_tx", None) or adapt.Transformer(cursor)
+        if binary:
+            self.formatter = BinaryFormatter(tx)
+        else:
+            self.formatter = TextFormatter(tx, encoding=pgconn_encoding(self._pgconn))
+
+        self._finished = False
+
+    def __repr__(self) -> str:
+        cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
+        info = connection_summary(self._pgconn)
+        return f"<{cls} {info} at 0x{id(self):x}>"
+
+    def _enter(self) -> None:
+        if self._finished:
+            raise TypeError("copy blocks can be used only once")
+
+    def set_types(self, types: Sequence[Union[int, str]]) -> None:
+        """
+        Set the types expected in a COPY operation.
+
+        The types must be specified as a sequence of oid or PostgreSQL type
+        names (e.g. ``int4``, ``timestamptz[]``).
+
+        This operation overcomes the lack of metadata returned by PostgreSQL
+        when a COPY operation begins:
+
+        - On :sql:`COPY TO`, `!set_types()` allows to specify what types the
+          operation returns. If `!set_types()` is not used, the data will be
+          returned as unparsed strings or bytes instead of Python objects.
+
+        - On :sql:`COPY FROM`, `!set_types()` allows to choose what type the
+          database expects. This is especially useful in binary copy, because
+          PostgreSQL will apply no cast rule.
+
+        """
+        registry = self.cursor.adapters.types
+        oids = [t if isinstance(t, int) else registry.get_oid(t) for t in types]
+
+        if self._direction == COPY_IN:
+            self.formatter.transformer.set_dumper_types(oids, self.formatter.format)
+        else:
+            self.formatter.transformer.set_loader_types(oids, self.formatter.format)
+
+    # High level copy protocol generators (state change of the Copy object)
+
+    def _read_gen(self) -> PQGen[Buffer]:
+        if self._finished:
+            return memoryview(b"")
+
+        res = yield from copy_from(self._pgconn)
+        if isinstance(res, memoryview):
+            return res
+
+        # res is the final PGresult
+        self._finished = True
+
+        # This result is a COMMAND_OK which has info about the number of rows
+        # returned, but not about the columns, which is instead an information
+        # that was received on the COPY_OUT result at the beginning of COPY.
+        # So, don't replace the results in the cursor, just update the rowcount.
+        nrows = res.command_tuples
+        self.cursor._rowcount = nrows if nrows is not None else -1
+        return memoryview(b"")
+
+    def _read_row_gen(self) -> PQGen[Optional[Tuple[Any, ...]]]:
+        data = yield from self._read_gen()
+        if not data:
+            return None
+
+        row = self.formatter.parse_row(data)
+        if row is None:
+            # Get the final result to finish the copy operation
+            yield from self._read_gen()
+            self._finished = True
+            return None
+
+        return row
+
+    def _end_copy_out_gen(self, exc: Optional[BaseException]) -> PQGen[None]:
+        if not exc:
+            return
+
+        if self._pgconn.transaction_status != ACTIVE:
+            # The server has already finished to send copy data. The connection
+            # is already in a good state.
+            return
+
+        # Throw a cancel to the server, then consume the rest of the copy data
+        # (which might or might not have been already transferred entirely to
+        # the client, so we won't necessary see the exception associated with
+        # canceling).
+        self.connection.cancel()
+        try:
+            while (yield from self._read_gen()):
+                pass
+        except e.QueryCanceled:
+            pass
+
+
+class Formatter(ABC):
+    """
+    A class which understand a copy format (text, binary).
+    """
+
+    format: pq.Format
+
+    def __init__(self, transformer: Transformer):
+        self.transformer = transformer
+        self._write_buffer = bytearray()
+        self._row_mode = False  # true if the user is using write_row()
+
+    @abstractmethod
+    def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]:
+        ...
+
+    @abstractmethod
+    def write(self, buffer: Union[Buffer, str]) -> Buffer:
+        ...
+
+    @abstractmethod
+    def write_row(self, row: Sequence[Any]) -> Buffer:
+        ...
+
+    @abstractmethod
+    def end(self) -> Buffer:
+        ...
+
+
+class TextFormatter(Formatter):
+    format = TEXT
+
+    def __init__(self, transformer: Transformer, encoding: str = "utf-8"):
+        super().__init__(transformer)
+        self._encoding = encoding
+
+    def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]:
+        if data:
+            return parse_row_text(data, self.transformer)
+        else:
+            return None
+
+    def write(self, buffer: Union[Buffer, str]) -> Buffer:
+        data = self._ensure_bytes(buffer)
+        self._signature_sent = True
+        return data
+
+    def write_row(self, row: Sequence[Any]) -> Buffer:
+        # Note down that we are writing in row mode: it means we will have
+        # to take care of the end-of-copy marker too
+        self._row_mode = True
+
+        format_row_text(row, self.transformer, self._write_buffer)
+        if len(self._write_buffer) > BUFFER_SIZE:
+            buffer, self._write_buffer = self._write_buffer, bytearray()
+            return buffer
+        else:
+            return b""
+
+    def end(self) -> Buffer:
+        buffer, self._write_buffer = self._write_buffer, bytearray()
+        return buffer
+
+    def _ensure_bytes(self, data: Union[Buffer, str]) -> Buffer:
+        if isinstance(data, str):
+            return data.encode(self._encoding)
+        else:
+            # Assume, for simplicity, that the user is not passing stupid
+            # things to the write function. If that's the case, things
+            # will fail downstream.
+            return data
+
+
+class BinaryFormatter(Formatter):
+    format = BINARY
+
+    def __init__(self, transformer: Transformer):
+        super().__init__(transformer)
+        self._signature_sent = False
+
+    def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]:
+        if not self._signature_sent:
+            if data[: len(_binary_signature)] != _binary_signature:
+                raise e.DataError(
+                    "binary copy doesn't start with the expected signature"
+                )
+            self._signature_sent = True
+            data = data[len(_binary_signature) :]
+
+        elif data == _binary_trailer:
+            return None
+
+        return parse_row_binary(data, self.transformer)
+
+    def write(self, buffer: Union[Buffer, str]) -> Buffer:
+        data = self._ensure_bytes(buffer)
+        self._signature_sent = True
+        return data
+
+    def write_row(self, row: Sequence[Any]) -> Buffer:
+        # Note down that we are writing in row mode: it means we will have
+        # to take care of the end-of-copy marker too
+        self._row_mode = True
+
+        if not self._signature_sent:
+            self._write_buffer += _binary_signature
+            self._signature_sent = True
+
+        format_row_binary(row, self.transformer, self._write_buffer)
+        if len(self._write_buffer) > BUFFER_SIZE:
+            buffer, self._write_buffer = self._write_buffer, bytearray()
+            return buffer
+        else:
+            return b""
+
+    def end(self) -> Buffer:
+        # If we have sent no data we need to send the signature
+        # and the trailer
+        if not self._signature_sent:
+            self._write_buffer += _binary_signature
+            self._write_buffer += _binary_trailer
+
+        elif self._row_mode:
+            # if we have sent data already, we have sent the signature
+            # too (either with the first row, or we assume that in
+            # block mode the signature is included).
+            # Write the trailer only if we are sending rows (with the
+            # assumption that who is copying binary data is sending the
+            # whole format).
+            self._write_buffer += _binary_trailer
+
+        buffer, self._write_buffer = self._write_buffer, bytearray()
+        return buffer
+
+    def _ensure_bytes(self, data: Union[Buffer, str]) -> Buffer:
+        if isinstance(data, str):
+            raise TypeError("cannot copy str data in binary mode: use bytes instead")
+        else:
+            # Assume, for simplicity, that the user is not passing stupid
+            # things to the write function. If that's the case, things
+            # will fail downstream.
+            return data
+
+
+def _format_row_text(
+    row: Sequence[Any], tx: Transformer, out: Optional[bytearray] = None
+) -> bytearray:
+    """Convert a row of objects to the data to send for copy."""
+    if out is None:
+        out = bytearray()
+
+    if not row:
+        out += b"\n"
+        return out
+
+    adapted = tx.dump_sequence(row, [PY_TEXT] * len(row))
+    for b in adapted:
+        out += _dump_re.sub(_dump_sub, b) if b is not None else rb"\N"
+        out += b"\t"
+
+    out[-1:] = b"\n"
+    return out
+
+
+def _format_row_binary(
+    row: Sequence[Any], tx: Transformer, out: Optional[bytearray] = None
+) -> bytearray:
+    """Convert a row of objects to the data to send for binary copy."""
+    if out is None:
+        out = bytearray()
+
+    out += _pack_int2(len(row))
+    adapted = tx.dump_sequence(row, [PY_BINARY] * len(row))
+    for b in adapted:
+        if b is not None:
+            out += _pack_int4(len(b))
+            out += b
+        else:
+            out += _binary_null
+
+    return out
+
+
+def _parse_row_text(data: Buffer, tx: Transformer) -> Tuple[Any, ...]:
+    if not isinstance(data, bytes):
+        data = bytes(data)
+    fields = data.split(b"\t")
+    fields[-1] = fields[-1][:-1]  # drop \n
+    row = [None if f == b"\\N" else _load_re.sub(_load_sub, f) for f in fields]
+    return tx.load_sequence(row)
+
+
+def _parse_row_binary(data: Buffer, tx: Transformer) -> Tuple[Any, ...]:
+    row: List[Optional[Buffer]] = []
+    nfields = _unpack_int2(data, 0)[0]
+    pos = 2
+    for i in range(nfields):
+        length = _unpack_int4(data, pos)[0]
+        pos += 4
+        if length >= 0:
+            row.append(data[pos : pos + length])
+            pos += length
+        else:
+            row.append(None)
+
+    return tx.load_sequence(row)
+
+
+_pack_int2 = struct.Struct("!h").pack
+_pack_int4 = struct.Struct("!i").pack
+_unpack_int2 = struct.Struct("!h").unpack_from
+_unpack_int4 = struct.Struct("!i").unpack_from
+
+_binary_signature = (
+    b"PGCOPY\n\xff\r\n\0"  # Signature
+    b"\x00\x00\x00\x00"  # flags
+    b"\x00\x00\x00\x00"  # extra length
+)
+_binary_trailer = b"\xff\xff"
+_binary_null = b"\xff\xff\xff\xff"
+
+_dump_re = re.compile(b"[\b\t\n\v\f\r\\\\]")
+_dump_repl = {
+    b"\b": b"\\b",
+    b"\t": b"\\t",
+    b"\n": b"\\n",
+    b"\v": b"\\v",
+    b"\f": b"\\f",
+    b"\r": b"\\r",
+    b"\\": b"\\\\",
+}
+
+
+def _dump_sub(m: Match[bytes], __map: Dict[bytes, bytes] = _dump_repl) -> bytes:
+    return __map[m.group(0)]
+
+
+_load_re = re.compile(b"\\\\[btnvfr\\\\]")
+_load_repl = {v: k for k, v in _dump_repl.items()}
+
+
+def _load_sub(m: Match[bytes], __map: Dict[bytes, bytes] = _load_repl) -> bytes:
+    return __map[m.group(0)]
+
+
+# Override functions with fast versions if available
+if _psycopg:
+    format_row_text = _psycopg.format_row_text
+    format_row_binary = _psycopg.format_row_binary
+    parse_row_text = _psycopg.parse_row_text
+    parse_row_binary = _psycopg.parse_row_binary
+
+else:
+    format_row_text = _format_row_text
+    format_row_binary = _format_row_binary
+    parse_row_text = _parse_row_text
+    parse_row_binary = _parse_row_binary
index bf54e90be90274ca97437a9d3a630a567fd8de97..b43d25ca6befa4dea56323935c0da74b4a06c390 100644 (file)
 """
-psycopg copy support
+Module gathering the various parts of the copy subsystem.
 """
 
-# Copyright (C) 2020 The Psycopg Team
+from typing import IO
 
-import re
-import queue
-import struct
-import asyncio
-import threading
-from abc import ABC, abstractmethod
-from types import TracebackType
-from typing import Any, AsyncIterator, Dict, Generic, Iterator, List, Match, IO
-from typing import Optional, Sequence, Tuple, Type, TypeVar, Union, TYPE_CHECKING
+from .abc import Buffer
+from . import _copy, _copy_async
 
-from . import pq
-from . import adapt
-from . import errors as e
-from .abc import Buffer, ConnectionType, PQGen, Transformer
-from .pq.misc import connection_summary
-from ._cmodule import _psycopg
-from ._encodings import pgconn_encoding
-from .generators import copy_from, copy_to, copy_end
+# re-exports
 
-if TYPE_CHECKING:
-    from .cursor import Cursor
-    from ._cursor_base import BaseCursor
-    from .cursor_async import AsyncCursor
-    from .connection import Connection  # noqa: F401
-    from .connection_async import AsyncConnection  # noqa: F401
+AsyncCopy = _copy_async.AsyncCopy
+AsyncWriter = _copy_async.AsyncWriter
+AsyncLibpqWriter = _copy_async.AsyncLibpqWriter
+AsyncQueuedLibpqWriter = _copy_async.AsyncQueuedLibpqWriter
 
-PY_TEXT = adapt.PyFormat.TEXT
-PY_BINARY = adapt.PyFormat.BINARY
-
-TEXT = pq.Format.TEXT
-BINARY = pq.Format.BINARY
-
-COPY_IN = pq.ExecStatus.COPY_IN
-COPY_OUT = pq.ExecStatus.COPY_OUT
-
-ACTIVE = pq.TransactionStatus.ACTIVE
-
-# Size of data to accumulate before sending it down the network. We fill a
-# buffer this size field by field, and when it passes the threshold size
-# we ship it, so it may end up being bigger than this.
-BUFFER_SIZE = 32 * 1024
-
-# Maximum data size we want to queue to send to the libpq copy. Sending a
-# buffer too big to be handled can cause an infinite loop in the libpq
-# (#255) so we want to split it in more digestable chunks.
-MAX_BUFFER_SIZE = 4 * BUFFER_SIZE
-# Note: making this buffer too large, e.g.
-# MAX_BUFFER_SIZE = 1024 * 1024
-# makes operations *way* slower! Probably triggering some quadraticity
-# in the libpq memory management and data sending.
-
-# Max size of the write queue of buffers. More than that copy will block
-# Each buffer should be around BUFFER_SIZE size.
-QUEUE_SIZE = 1024
-
-
-class BaseCopy(Generic[ConnectionType]):
-    """
-    Base implementation for the copy user interface.
-
-    Two subclasses expose real methods with the sync/async differences.
-
-    The difference between the text and binary format is managed by two
-    different `Formatter` subclasses.
-
-    Writing (the I/O part) is implemented in the subclasses by a `Writer` or
-    `AsyncWriter` instance. Normally writing implies sending copy data to a
-    database, but a different writer might be chosen, e.g. to stream data into
-    a file for later use.
-    """
-
-    _Self = TypeVar("_Self", bound="BaseCopy[Any]")
-
-    formatter: "Formatter"
-
-    def __init__(
-        self,
-        cursor: "BaseCursor[ConnectionType, Any]",
-        *,
-        binary: Optional[bool] = None,
-    ):
-        self.cursor = cursor
-        self.connection = cursor.connection
-        self._pgconn = self.connection.pgconn
-
-        result = cursor.pgresult
-        if result:
-            self._direction = result.status
-            if self._direction != COPY_IN and self._direction != COPY_OUT:
-                raise e.ProgrammingError(
-                    "the cursor should have performed a COPY operation;"
-                    f" its status is {pq.ExecStatus(self._direction).name} instead"
-                )
-        else:
-            self._direction = COPY_IN
-
-        if binary is None:
-            binary = bool(result and result.binary_tuples)
-
-        tx: Transformer = getattr(cursor, "_tx", None) or adapt.Transformer(cursor)
-        if binary:
-            self.formatter = BinaryFormatter(tx)
-        else:
-            self.formatter = TextFormatter(tx, encoding=pgconn_encoding(self._pgconn))
-
-        self._finished = False
-
-    def __repr__(self) -> str:
-        cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
-        info = connection_summary(self._pgconn)
-        return f"<{cls} {info} at 0x{id(self):x}>"
-
-    def _enter(self) -> None:
-        if self._finished:
-            raise TypeError("copy blocks can be used only once")
-
-    def set_types(self, types: Sequence[Union[int, str]]) -> None:
-        """
-        Set the types expected in a COPY operation.
-
-        The types must be specified as a sequence of oid or PostgreSQL type
-        names (e.g. ``int4``, ``timestamptz[]``).
-
-        This operation overcomes the lack of metadata returned by PostgreSQL
-        when a COPY operation begins:
-
-        - On :sql:`COPY TO`, `!set_types()` allows to specify what types the
-          operation returns. If `!set_types()` is not used, the data will be
-          returned as unparsed strings or bytes instead of Python objects.
-
-        - On :sql:`COPY FROM`, `!set_types()` allows to choose what type the
-          database expects. This is especially useful in binary copy, because
-          PostgreSQL will apply no cast rule.
-
-        """
-        registry = self.cursor.adapters.types
-        oids = [t if isinstance(t, int) else registry.get_oid(t) for t in types]
-
-        if self._direction == COPY_IN:
-            self.formatter.transformer.set_dumper_types(oids, self.formatter.format)
-        else:
-            self.formatter.transformer.set_loader_types(oids, self.formatter.format)
-
-    # High level copy protocol generators (state change of the Copy object)
-
-    def _read_gen(self) -> PQGen[Buffer]:
-        if self._finished:
-            return memoryview(b"")
-
-        res = yield from copy_from(self._pgconn)
-        if isinstance(res, memoryview):
-            return res
-
-        # res is the final PGresult
-        self._finished = True
-
-        # This result is a COMMAND_OK which has info about the number of rows
-        # returned, but not about the columns, which is instead an information
-        # that was received on the COPY_OUT result at the beginning of COPY.
-        # So, don't replace the results in the cursor, just update the rowcount.
-        nrows = res.command_tuples
-        self.cursor._rowcount = nrows if nrows is not None else -1
-        return memoryview(b"")
-
-    def _read_row_gen(self) -> PQGen[Optional[Tuple[Any, ...]]]:
-        data = yield from self._read_gen()
-        if not data:
-            return None
-
-        row = self.formatter.parse_row(data)
-        if row is None:
-            # Get the final result to finish the copy operation
-            yield from self._read_gen()
-            self._finished = True
-            return None
-
-        return row
-
-    def _end_copy_out_gen(self, exc: Optional[BaseException]) -> PQGen[None]:
-        if not exc:
-            return
-
-        if self._pgconn.transaction_status != ACTIVE:
-            # The server has already finished to send copy data. The connection
-            # is already in a good state.
-            return
-
-        # Throw a cancel to the server, then consume the rest of the copy data
-        # (which might or might not have been already transferred entirely to
-        # the client, so we won't necessary see the exception associated with
-        # canceling).
-        self.connection.cancel()
-        try:
-            while (yield from self._read_gen()):
-                pass
-        except e.QueryCanceled:
-            pass
-
-
-class Copy(BaseCopy["Connection[Any]"]):
-    """Manage a :sql:`COPY` operation.
-
-    :param cursor: the cursor where the operation is performed.
-    :param binary: if `!True`, write binary format.
-    :param writer: the object to write to destination. If not specified, write
-        to the `!cursor` connection.
-
-    Choosing `!binary` is not necessary if the cursor has executed a
-    :sql:`COPY` operation, because the operation result describes the format
-    too. The parameter is useful when a `!Copy` object is created manually and
-    no operation is performed on the cursor, such as when using ``writer=``\\
-    `~psycopg.copy.FileWriter`.
-
-    """
-
-    __module__ = "psycopg"
-
-    writer: "Writer"
-
-    def __init__(
-        self,
-        cursor: "Cursor[Any]",
-        *,
-        binary: Optional[bool] = None,
-        writer: Optional["Writer"] = None,
-    ):
-        super().__init__(cursor, binary=binary)
-        if not writer:
-            writer = LibpqWriter(cursor)
-
-        self.writer = writer
-        self._write = writer.write
-
-    def __enter__(self: BaseCopy._Self) -> BaseCopy._Self:
-        self._enter()
-        return self
-
-    def __exit__(
-        self,
-        exc_type: Optional[Type[BaseException]],
-        exc_val: Optional[BaseException],
-        exc_tb: Optional[TracebackType],
-    ) -> None:
-        self.finish(exc_val)
-
-    # End user sync interface
-
-    def __iter__(self) -> Iterator[Buffer]:
-        """Implement block-by-block iteration on :sql:`COPY TO`."""
-        while True:
-            data = self.read()
-            if not data:
-                break
-            yield data
-
-    def read(self) -> Buffer:
-        """
-        Read an unparsed row after a :sql:`COPY TO` operation.
-
-        Return an empty string when the data is finished.
-        """
-        return self.connection.wait(self._read_gen())
-
-    def rows(self) -> Iterator[Tuple[Any, ...]]:
-        """
-        Iterate on the result of a :sql:`COPY TO` operation record by record.
-
-        Note that the records returned will be tuples of unparsed strings or
-        bytes, unless data types are specified using `set_types()`.
-        """
-        while True:
-            record = self.read_row()
-            if record is None:
-                break
-            yield record
-
-    def read_row(self) -> Optional[Tuple[Any, ...]]:
-        """
-        Read a parsed row of data from a table after a :sql:`COPY TO` operation.
-
-        Return `!None` when the data is finished.
-
-        Note that the records returned will be tuples of unparsed strings or
-        bytes, unless data types are specified using `set_types()`.
-        """
-        return self.connection.wait(self._read_row_gen())
-
-    def write(self, buffer: Union[Buffer, str]) -> None:
-        """
-        Write a block of data to a table after a :sql:`COPY FROM` operation.
-
-        If the :sql:`COPY` is in binary format `!buffer` must be `!bytes`. In
-        text mode it can be either `!bytes` or `!str`.
-        """
-        data = self.formatter.write(buffer)
-        if data:
-            self._write(data)
-
-    def write_row(self, row: Sequence[Any]) -> None:
-        """Write a record to a table after a :sql:`COPY FROM` operation."""
-        data = self.formatter.write_row(row)
-        if data:
-            self._write(data)
-
-    def finish(self, exc: Optional[BaseException]) -> None:
-        """Terminate the copy operation and free the resources allocated.
-
-        You shouldn't need to call this function yourself: it is usually called
-        by exit. It is available if, despite what is documented, you end up
-        using the `Copy` object outside a block.
-        """
-        if self._direction == COPY_IN:
-            data = self.formatter.end()
-            if data:
-                self._write(data)
-            self.writer.finish(exc)
-            self._finished = True
-        else:
-            self.connection.wait(self._end_copy_out_gen(exc))
-
-
-class Writer(ABC):
-    """
-    A class to write copy data somewhere.
-    """
-
-    @abstractmethod
-    def write(self, data: Buffer) -> None:
-        """
-        Write some data to destination.
-        """
-        ...
-
-    def finish(self, exc: Optional[BaseException] = None) -> None:
-        """
-        Called when write operations are finished.
-
-        If operations finished with an error, it will be passed to ``exc``.
-        """
-        pass
-
-
-class LibpqWriter(Writer):
-    """
-    A `Writer` to write copy data to a Postgres database.
-    """
-
-    def __init__(self, cursor: "Cursor[Any]"):
-        self.cursor = cursor
-        self.connection = cursor.connection
-        self._pgconn = self.connection.pgconn
-
-    def write(self, data: Buffer) -> None:
-        if len(data) <= MAX_BUFFER_SIZE:
-            # Most used path: we don't need to split the buffer in smaller
-            # bits, so don't make a copy.
-            self.connection.wait(copy_to(self._pgconn, data))
-        else:
-            # Copy a buffer too large in chunks to avoid causing a memory
-            # error in the libpq, which may cause an infinite loop (#255).
-            for i in range(0, len(data), MAX_BUFFER_SIZE):
-                self.connection.wait(
-                    copy_to(self._pgconn, data[i : i + MAX_BUFFER_SIZE])
-                )
-
-    def finish(self, exc: Optional[BaseException] = None) -> None:
-        bmsg: Optional[bytes]
-        if exc:
-            msg = f"error from Python: {type(exc).__qualname__} - {exc}"
-            bmsg = msg.encode(pgconn_encoding(self._pgconn), "replace")
-        else:
-            bmsg = None
-
-        try:
-            res = self.connection.wait(copy_end(self._pgconn, bmsg))
-        # The QueryCanceled is expected if we sent an exception message to
-        # pgconn.put_copy_end(). The Python exception that generated that
-        # cancelling is more important, so don't clobber it.
-        except e.QueryCanceled:
-            if not bmsg:
-                raise
-        else:
-            self.cursor._results = [res]
-
-
-class QueuedLibpqWriter(LibpqWriter):
-    """
-    A writer using a buffer to queue data to write to a Postgres database.
-
-    `write()` returns immediately, so that the main thread can be CPU-bound
-    formatting messages, while a worker thread can be IO-bound waiting to write
-    on the connection.
-    """
-
-    def __init__(self, cursor: "Cursor[Any]"):
-        super().__init__(cursor)
-
-        self._queue: queue.Queue[Buffer] = queue.Queue(maxsize=QUEUE_SIZE)
-        self._worker: Optional[threading.Thread] = None
-        self._worker_error: Optional[BaseException] = None
-
-    def worker(self) -> None:
-        """Push data to the server when available from the copy queue.
-
-        Terminate reading when the queue receives a false-y value, or in case
-        of error.
-
-        The function is designed to be run in a separate thread.
-        """
-        try:
-            while True:
-                data = self._queue.get(block=True, timeout=24 * 60 * 60)
-                if not data:
-                    break
-                self.connection.wait(copy_to(self._pgconn, data))
-        except BaseException as ex:
-            # Propagate the error to the main thread.
-            self._worker_error = ex
-
-    def write(self, data: Buffer) -> None:
-        if not self._worker:
-            # warning: reference loop, broken by _write_end
-            self._worker = threading.Thread(target=self.worker)
-            self._worker.daemon = True
-            self._worker.start()
-
-        # If the worker thread raies an exception, re-raise it to the caller.
-        if self._worker_error:
-            raise self._worker_error
-
-        if len(data) <= MAX_BUFFER_SIZE:
-            # Most used path: we don't need to split the buffer in smaller
-            # bits, so don't make a copy.
-            self._queue.put(data)
-        else:
-            # Copy a buffer too large in chunks to avoid causing a memory
-            # error in the libpq, which may cause an infinite loop (#255).
-            for i in range(0, len(data), MAX_BUFFER_SIZE):
-                self._queue.put(data[i : i + MAX_BUFFER_SIZE])
-
-    def finish(self, exc: Optional[BaseException] = None) -> None:
-        self._queue.put(b"")
-
-        if self._worker:
-            self._worker.join()
-            self._worker = None  # break the loop
-
-        # Check if the worker thread raised any exception before terminating.
-        if self._worker_error:
-            raise self._worker_error
-
-        super().finish(exc)
+Copy = _copy.Copy
+Writer = _copy.Writer
+LibpqWriter = _copy.LibpqWriter
+QueuedLibpqWriter = _copy.QueuedLibpqWriter
 
 
 class FileWriter(Writer):
@@ -471,445 +33,3 @@ class FileWriter(Writer):
 
     def write(self, data: Buffer) -> None:
         self.file.write(data)
-
-
-class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
-    """Manage an asynchronous :sql:`COPY` operation."""
-
-    __module__ = "psycopg"
-
-    writer: "AsyncWriter"
-
-    def __init__(
-        self,
-        cursor: "AsyncCursor[Any]",
-        *,
-        binary: Optional[bool] = None,
-        writer: Optional["AsyncWriter"] = None,
-    ):
-        super().__init__(cursor, binary=binary)
-
-        if not writer:
-            writer = AsyncLibpqWriter(cursor)
-
-        self.writer = writer
-        self._write = writer.write
-
-    async def __aenter__(self: BaseCopy._Self) -> BaseCopy._Self:
-        self._enter()
-        return self
-
-    async def __aexit__(
-        self,
-        exc_type: Optional[Type[BaseException]],
-        exc_val: Optional[BaseException],
-        exc_tb: Optional[TracebackType],
-    ) -> None:
-        await self.finish(exc_val)
-
-    async def __aiter__(self) -> AsyncIterator[Buffer]:
-        while True:
-            data = await self.read()
-            if not data:
-                break
-            yield data
-
-    async def read(self) -> Buffer:
-        return await self.connection.wait(self._read_gen())
-
-    async def rows(self) -> AsyncIterator[Tuple[Any, ...]]:
-        while True:
-            record = await self.read_row()
-            if record is None:
-                break
-            yield record
-
-    async def read_row(self) -> Optional[Tuple[Any, ...]]:
-        return await self.connection.wait(self._read_row_gen())
-
-    async def write(self, buffer: Union[Buffer, str]) -> None:
-        data = self.formatter.write(buffer)
-        if data:
-            await self._write(data)
-
-    async def write_row(self, row: Sequence[Any]) -> None:
-        data = self.formatter.write_row(row)
-        if data:
-            await self._write(data)
-
-    async def finish(self, exc: Optional[BaseException]) -> None:
-        if self._direction == COPY_IN:
-            data = self.formatter.end()
-            if data:
-                await self._write(data)
-            await self.writer.finish(exc)
-            self._finished = True
-        else:
-            await self.connection.wait(self._end_copy_out_gen(exc))
-
-
-class AsyncWriter(ABC):
-    """
-    A class to write copy data somewhere (for async connections).
-    """
-
-    @abstractmethod
-    async def write(self, data: Buffer) -> None:
-        ...
-
-    async def finish(self, exc: Optional[BaseException] = None) -> None:
-        pass
-
-
-class AsyncLibpqWriter(AsyncWriter):
-    """
-    An `AsyncWriter` to write copy data to a Postgres database.
-    """
-
-    def __init__(self, cursor: "AsyncCursor[Any]"):
-        self.cursor = cursor
-        self.connection = cursor.connection
-        self._pgconn = self.connection.pgconn
-
-    async def write(self, data: Buffer) -> None:
-        if len(data) <= MAX_BUFFER_SIZE:
-            # Most used path: we don't need to split the buffer in smaller
-            # bits, so don't make a copy.
-            await self.connection.wait(copy_to(self._pgconn, data))
-        else:
-            # Copy a buffer too large in chunks to avoid causing a memory
-            # error in the libpq, which may cause an infinite loop (#255).
-            for i in range(0, len(data), MAX_BUFFER_SIZE):
-                await self.connection.wait(
-                    copy_to(self._pgconn, data[i : i + MAX_BUFFER_SIZE])
-                )
-
-    async def finish(self, exc: Optional[BaseException] = None) -> None:
-        bmsg: Optional[bytes]
-        if exc:
-            msg = f"error from Python: {type(exc).__qualname__} - {exc}"
-            bmsg = msg.encode(pgconn_encoding(self._pgconn), "replace")
-        else:
-            bmsg = None
-
-        try:
-            res = await self.connection.wait(copy_end(self._pgconn, bmsg))
-        # The QueryCanceled is expected if we sent an exception message to
-        # pgconn.put_copy_end(). The Python exception that generated that
-        # cancelling is more important, so don't clobber it.
-        except e.QueryCanceled:
-            if not bmsg:
-                raise
-        else:
-            self.cursor._results = [res]
-
-
-class AsyncQueuedLibpqWriter(AsyncLibpqWriter):
-    """
-    An `AsyncWriter` using a buffer to queue data to write.
-
-    `write()` returns immediately, so that the main thread can be CPU-bound
-    formatting messages, while a worker thread can be IO-bound waiting to write
-    on the connection.
-    """
-
-    def __init__(self, cursor: "AsyncCursor[Any]"):
-        super().__init__(cursor)
-
-        self._queue: asyncio.Queue[Buffer] = asyncio.Queue(maxsize=QUEUE_SIZE)
-        self._worker: Optional[asyncio.Future[None]] = None
-
-    async def worker(self) -> None:
-        """Push data to the server when available from the copy queue.
-
-        Terminate reading when the queue receives a false-y value.
-
-        The function is designed to be run in a separate task.
-        """
-        while True:
-            data = await self._queue.get()
-            if not data:
-                break
-            await self.connection.wait(copy_to(self._pgconn, data))
-
-    async def write(self, data: Buffer) -> None:
-        if not self._worker:
-            self._worker = asyncio.create_task(self.worker())
-
-        if len(data) <= MAX_BUFFER_SIZE:
-            # Most used path: we don't need to split the buffer in smaller
-            # bits, so don't make a copy.
-            await self._queue.put(data)
-        else:
-            # Copy a buffer too large in chunks to avoid causing a memory
-            # error in the libpq, which may cause an infinite loop (#255).
-            for i in range(0, len(data), MAX_BUFFER_SIZE):
-                await self._queue.put(data[i : i + MAX_BUFFER_SIZE])
-
-    async def finish(self, exc: Optional[BaseException] = None) -> None:
-        await self._queue.put(b"")
-
-        if self._worker:
-            await asyncio.gather(self._worker)
-            self._worker = None  # break reference loops if any
-
-        await super().finish(exc)
-
-
-class Formatter(ABC):
-    """
-    A class which understand a copy format (text, binary).
-    """
-
-    format: pq.Format
-
-    def __init__(self, transformer: Transformer):
-        self.transformer = transformer
-        self._write_buffer = bytearray()
-        self._row_mode = False  # true if the user is using write_row()
-
-    @abstractmethod
-    def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]:
-        ...
-
-    @abstractmethod
-    def write(self, buffer: Union[Buffer, str]) -> Buffer:
-        ...
-
-    @abstractmethod
-    def write_row(self, row: Sequence[Any]) -> Buffer:
-        ...
-
-    @abstractmethod
-    def end(self) -> Buffer:
-        ...
-
-
-class TextFormatter(Formatter):
-    format = TEXT
-
-    def __init__(self, transformer: Transformer, encoding: str = "utf-8"):
-        super().__init__(transformer)
-        self._encoding = encoding
-
-    def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]:
-        if data:
-            return parse_row_text(data, self.transformer)
-        else:
-            return None
-
-    def write(self, buffer: Union[Buffer, str]) -> Buffer:
-        data = self._ensure_bytes(buffer)
-        self._signature_sent = True
-        return data
-
-    def write_row(self, row: Sequence[Any]) -> Buffer:
-        # Note down that we are writing in row mode: it means we will have
-        # to take care of the end-of-copy marker too
-        self._row_mode = True
-
-        format_row_text(row, self.transformer, self._write_buffer)
-        if len(self._write_buffer) > BUFFER_SIZE:
-            buffer, self._write_buffer = self._write_buffer, bytearray()
-            return buffer
-        else:
-            return b""
-
-    def end(self) -> Buffer:
-        buffer, self._write_buffer = self._write_buffer, bytearray()
-        return buffer
-
-    def _ensure_bytes(self, data: Union[Buffer, str]) -> Buffer:
-        if isinstance(data, str):
-            return data.encode(self._encoding)
-        else:
-            # Assume, for simplicity, that the user is not passing stupid
-            # things to the write function. If that's the case, things
-            # will fail downstream.
-            return data
-
-
-class BinaryFormatter(Formatter):
-    format = BINARY
-
-    def __init__(self, transformer: Transformer):
-        super().__init__(transformer)
-        self._signature_sent = False
-
-    def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]:
-        if not self._signature_sent:
-            if data[: len(_binary_signature)] != _binary_signature:
-                raise e.DataError(
-                    "binary copy doesn't start with the expected signature"
-                )
-            self._signature_sent = True
-            data = data[len(_binary_signature) :]
-
-        elif data == _binary_trailer:
-            return None
-
-        return parse_row_binary(data, self.transformer)
-
-    def write(self, buffer: Union[Buffer, str]) -> Buffer:
-        data = self._ensure_bytes(buffer)
-        self._signature_sent = True
-        return data
-
-    def write_row(self, row: Sequence[Any]) -> Buffer:
-        # Note down that we are writing in row mode: it means we will have
-        # to take care of the end-of-copy marker too
-        self._row_mode = True
-
-        if not self._signature_sent:
-            self._write_buffer += _binary_signature
-            self._signature_sent = True
-
-        format_row_binary(row, self.transformer, self._write_buffer)
-        if len(self._write_buffer) > BUFFER_SIZE:
-            buffer, self._write_buffer = self._write_buffer, bytearray()
-            return buffer
-        else:
-            return b""
-
-    def end(self) -> Buffer:
-        # If we have sent no data we need to send the signature
-        # and the trailer
-        if not self._signature_sent:
-            self._write_buffer += _binary_signature
-            self._write_buffer += _binary_trailer
-
-        elif self._row_mode:
-            # if we have sent data already, we have sent the signature
-            # too (either with the first row, or we assume that in
-            # block mode the signature is included).
-            # Write the trailer only if we are sending rows (with the
-            # assumption that who is copying binary data is sending the
-            # whole format).
-            self._write_buffer += _binary_trailer
-
-        buffer, self._write_buffer = self._write_buffer, bytearray()
-        return buffer
-
-    def _ensure_bytes(self, data: Union[Buffer, str]) -> Buffer:
-        if isinstance(data, str):
-            raise TypeError("cannot copy str data in binary mode: use bytes instead")
-        else:
-            # Assume, for simplicity, that the user is not passing stupid
-            # things to the write function. If that's the case, things
-            # will fail downstream.
-            return data
-
-
-def _format_row_text(
-    row: Sequence[Any], tx: Transformer, out: Optional[bytearray] = None
-) -> bytearray:
-    """Convert a row of objects to the data to send for copy."""
-    if out is None:
-        out = bytearray()
-
-    if not row:
-        out += b"\n"
-        return out
-
-    adapted = tx.dump_sequence(row, [PY_TEXT] * len(row))
-    for b in adapted:
-        out += _dump_re.sub(_dump_sub, b) if b is not None else rb"\N"
-        out += b"\t"
-
-    out[-1:] = b"\n"
-    return out
-
-
-def _format_row_binary(
-    row: Sequence[Any], tx: Transformer, out: Optional[bytearray] = None
-) -> bytearray:
-    """Convert a row of objects to the data to send for binary copy."""
-    if out is None:
-        out = bytearray()
-
-    out += _pack_int2(len(row))
-    adapted = tx.dump_sequence(row, [PY_BINARY] * len(row))
-    for b in adapted:
-        if b is not None:
-            out += _pack_int4(len(b))
-            out += b
-        else:
-            out += _binary_null
-
-    return out
-
-
-def _parse_row_text(data: Buffer, tx: Transformer) -> Tuple[Any, ...]:
-    if not isinstance(data, bytes):
-        data = bytes(data)
-    fields = data.split(b"\t")
-    fields[-1] = fields[-1][:-1]  # drop \n
-    row = [None if f == b"\\N" else _load_re.sub(_load_sub, f) for f in fields]
-    return tx.load_sequence(row)
-
-
-def _parse_row_binary(data: Buffer, tx: Transformer) -> Tuple[Any, ...]:
-    row: List[Optional[Buffer]] = []
-    nfields = _unpack_int2(data, 0)[0]
-    pos = 2
-    for i in range(nfields):
-        length = _unpack_int4(data, pos)[0]
-        pos += 4
-        if length >= 0:
-            row.append(data[pos : pos + length])
-            pos += length
-        else:
-            row.append(None)
-
-    return tx.load_sequence(row)
-
-
-_pack_int2 = struct.Struct("!h").pack
-_pack_int4 = struct.Struct("!i").pack
-_unpack_int2 = struct.Struct("!h").unpack_from
-_unpack_int4 = struct.Struct("!i").unpack_from
-
-_binary_signature = (
-    b"PGCOPY\n\xff\r\n\0"  # Signature
-    b"\x00\x00\x00\x00"  # flags
-    b"\x00\x00\x00\x00"  # extra length
-)
-_binary_trailer = b"\xff\xff"
-_binary_null = b"\xff\xff\xff\xff"
-
-_dump_re = re.compile(b"[\b\t\n\v\f\r\\\\]")
-_dump_repl = {
-    b"\b": b"\\b",
-    b"\t": b"\\t",
-    b"\n": b"\\n",
-    b"\v": b"\\v",
-    b"\f": b"\\f",
-    b"\r": b"\\r",
-    b"\\": b"\\\\",
-}
-
-
-def _dump_sub(m: Match[bytes], __map: Dict[bytes, bytes] = _dump_repl) -> bytes:
-    return __map[m.group(0)]
-
-
-_load_re = re.compile(b"\\\\[btnvfr\\\\]")
-_load_repl = {v: k for k, v in _dump_repl.items()}
-
-
-def _load_sub(m: Match[bytes], __map: Dict[bytes, bytes] = _load_repl) -> bytes:
-    return __map[m.group(0)]
-
-
-# Override functions with fast versions if available
-if _psycopg:
-    format_row_text = _psycopg.format_row_text
-    format_row_binary = _psycopg.format_row_binary
-    parse_row_text = _psycopg.parse_row_text
-    parse_row_binary = _psycopg.parse_row_binary
-
-else:
-    format_row_text = _format_row_text
-    format_row_binary = _format_row_binary
-    parse_row_text = _parse_row_text
-    parse_row_binary = _parse_row_binary
index 5ab25a86285b133bade3537d4ddb9ee7458cea55..4b3e182e65dd524f308d44af310225bbc933db3f 100644 (file)
@@ -645,7 +645,7 @@ def test_worker_error_propagated(conn, monkeypatch):
         raise ZeroDivisionError
         yield
 
-    monkeypatch.setattr(psycopg.copy, "copy_to", copy_to_broken)
+    monkeypatch.setattr(psycopg._copy, "copy_to", copy_to_broken)
     cur = conn.cursor()
     cur.execute("create temp table wat (a text, b text)")
     with pytest.raises(ZeroDivisionError):
index c62d9f2d9c763b91692c7d8d251479f435b07c53..68c21d8ef87adf3193011f1b9f70f66255e5fbb6 100644 (file)
@@ -660,7 +660,7 @@ async def test_worker_error_propagated(aconn, monkeypatch):
         raise ZeroDivisionError
         yield
 
-    monkeypatch.setattr(psycopg.copy, "copy_to", copy_to_broken)
+    monkeypatch.setattr(psycopg._copy_async, "copy_to", copy_to_broken)
     cur = aconn.cursor()
     await cur.execute("create temp table wat (a text, b text)")
     with pytest.raises(ZeroDivisionError):
index 6b028d53e5f76ce2497b882f25b96a7248263535..28f98a7831443b222a399ff54eb452f8dc471542 100755 (executable)
@@ -182,6 +182,7 @@ class RenameAsyncToSync(ast.NodeTransformer):
         "__aenter__": "__enter__",
         "__aexit__": "__exit__",
         "__aiter__": "__iter__",
+        "_copy_async": "_copy",
         "aclose": "close",
         "aclosing": "closing",
         "acommands": "commands",
@@ -289,12 +290,18 @@ class RenameAsyncToSync(ast.NodeTransformer):
         for base in node.bases:
             if not isinstance(base, ast.Subscript):
                 continue
-            if not isinstance(base.slice, ast.Tuple):
-                continue
-            for elt in base.slice.elts:
-                if not (isinstance(elt, ast.Constant) and isinstance(elt.value, str)):
+
+            if isinstance(base.slice, ast.Constant):
+                if not isinstance(base.slice.value, str):
                     continue
-                elt.value = self._visit_type_string(elt.value)
+                base.slice.value = self._visit_type_string(base.slice.value)
+            elif isinstance(base.slice, ast.Tuple):
+                for elt in base.slice.elts:
+                    if not (
+                        isinstance(elt, ast.Constant) and isinstance(elt.value, str)
+                    ):
+                        continue
+                    elt.value = self._visit_type_string(elt.value)
 
         return node
 
index 68732131f4e5f9d7e9b28f8ff3d787d620a07ad0..5932dd7cbb00679968d842a22478d63532403109 100755 (executable)
@@ -27,6 +27,7 @@ if [[ ${1:-} == '--check' ]]; then
 fi
 
 all_inputs="
+    psycopg/psycopg/_copy_async.py
     psycopg/psycopg/connection_async.py
     psycopg/psycopg/cursor_async.py
     psycopg_pool/psycopg_pool/null_pool_async.py