]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
feat(copy): add FileWriter to write copy data to a file-like object
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 23 Jul 2022 22:13:41 +0000 (23:13 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 26 Jul 2022 12:01:42 +0000 (13:01 +0100)
psycopg/psycopg/copy.py
tests/test_copy.py

index f53c9a98a9b2826400c488729995e9ea831a034e..ee9973de687bc46587039ae75cadb38929c2eddb 100644 (file)
@@ -11,7 +11,7 @@ import asyncio
 import threading
 from abc import ABC, abstractmethod
 from types import TracebackType
-from typing import Any, AsyncIterator, Dict, Generic, Iterator, List, Match
+from typing import Any, AsyncIterator, Dict, Generic, Iterator, List, Match, IO
 from typing import Optional, Sequence, Tuple, Type, TypeVar, Union, TYPE_CHECKING
 
 from . import pq
@@ -416,6 +416,20 @@ class QueuedLibpqDriver(LibpqWriter):
         super().finish(exc)
 
 
+class FileWriter(Writer):
+    """
+    A `Writer` to write copy data to a file-like object.
+
+    The file must be open for writing in binary mode.
+    """
+
+    def __init__(self, file: IO[bytes]):
+        self.file = file
+
+    def write(self, data: Buffer) -> None:
+        self.file.write(data)
+
+
 class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
     """Manage an asynchronous :sql:`COPY` operation."""
 
index 547089f3806bc2d6acce8d714e5a86c4e55f515a..540b903dbaf414c5fec0cfbe34625b30ff87555c 100644 (file)
@@ -12,7 +12,7 @@ from psycopg import pq
 from psycopg import sql
 from psycopg import errors as e
 from psycopg.pq import Format
-from psycopg.copy import Copy, Writer, LibpqWriter, QueuedLibpqDriver
+from psycopg.copy import Copy, LibpqWriter, QueuedLibpqDriver, FileWriter
 from psycopg.adapt import PyFormat
 from psycopg.types import TypeInfo
 from psycopg.types.hstore import register_hstore
@@ -462,15 +462,15 @@ from copy_in group by 1, 2, 3
 
 
 def test_copy_in_format(conn):
-    writer = BytesWriter()
+    file = BytesIO()
     conn.execute("set client_encoding to utf8")
     cur = conn.cursor()
-    with Copy(cur, writer=writer) as copy:
+    with Copy(cur, writer=FileWriter(file)) as copy:
         for i in range(1, 256):
             copy.write_row((i, chr(i)))
 
-    writer.file.seek(0)
-    rows = writer.file.read().split(b"\n")
+    file.seek(0)
+    rows = file.read().split(b"\n")
     assert not rows[-1]
     del rows[-1]
 
@@ -860,11 +860,3 @@ class DataGenerator:
                 block = block.encode()
             m.update(block)
         return m.hexdigest()
-
-
-class BytesWriter(Writer):
-    def __init__(self):
-        self.file = BytesIO()
-
-    def write(self, data):
-        self.file.write(data)