From: Daniele Varrazzo Date: Sat, 23 Jul 2022 22:13:41 +0000 (+0100) Subject: feat(copy): add FileWriter to write copy data to a file-like object X-Git-Tag: 3.1~44^2~4 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=5544b99c55cf37801af1549fff8f945934489444;p=thirdparty%2Fpsycopg.git feat(copy): add FileWriter to write copy data to a file-like object --- diff --git a/psycopg/psycopg/copy.py b/psycopg/psycopg/copy.py index f53c9a98a..ee9973de6 100644 --- a/psycopg/psycopg/copy.py +++ b/psycopg/psycopg/copy.py @@ -11,7 +11,7 @@ import asyncio import threading from abc import ABC, abstractmethod from types import TracebackType -from typing import Any, AsyncIterator, Dict, Generic, Iterator, List, Match +from typing import Any, AsyncIterator, Dict, Generic, Iterator, List, Match, IO from typing import Optional, Sequence, Tuple, Type, TypeVar, Union, TYPE_CHECKING from . import pq @@ -416,6 +416,20 @@ class QueuedLibpqDriver(LibpqWriter): super().finish(exc) +class FileWriter(Writer): + """ + A `Writer` to write copy data to a file-like object. + + The file must be open for writing in binary mode. + """ + + def __init__(self, file: IO[bytes]): + self.file = file + + def write(self, data: Buffer) -> None: + self.file.write(data) + + class AsyncCopy(BaseCopy["AsyncConnection[Any]"]): """Manage an asynchronous :sql:`COPY` operation.""" diff --git a/tests/test_copy.py b/tests/test_copy.py index 547089f38..540b903db 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -12,7 +12,7 @@ from psycopg import pq from psycopg import sql from psycopg import errors as e from psycopg.pq import Format -from psycopg.copy import Copy, Writer, LibpqWriter, QueuedLibpqDriver +from psycopg.copy import Copy, LibpqWriter, QueuedLibpqDriver, FileWriter from psycopg.adapt import PyFormat from psycopg.types import TypeInfo from psycopg.types.hstore import register_hstore @@ -462,15 +462,15 @@ from copy_in group by 1, 2, 3 def test_copy_in_format(conn): - writer = BytesWriter() + file = BytesIO() conn.execute("set client_encoding to utf8") cur = conn.cursor() - with Copy(cur, writer=writer) as copy: + with Copy(cur, writer=FileWriter(file)) as copy: for i in range(1, 256): copy.write_row((i, chr(i))) - writer.file.seek(0) - rows = writer.file.read().split(b"\n") + file.seek(0) + rows = file.read().split(b"\n") assert not rows[-1] del rows[-1] @@ -860,11 +860,3 @@ class DataGenerator: block = block.encode() m.update(block) return m.hexdigest() - - -class BytesWriter(Writer): - def __init__(self): - self.file = BytesIO() - - def write(self, data): - self.file.write(data)