import re
import codecs
from typing import TYPE_CHECKING, AsyncGenerator, Generator
-from typing import Any, Deque, Dict, List, Match, Optional, Tuple, Type, Union
+from typing import Dict, Match, Optional, Type, Union
from types import TracebackType
-from collections import deque
from . import pq
-from . import errors as e
from .proto import AdaptContext
from .generators import copy_from, copy_to, copy_end
class BaseCopy:
- _connection: Optional["BaseConnection"]
-
def __init__(
self,
context: AdaptContext,
):
from .adapt import Transformer
- self._connection = None
+ self._connection: Optional["BaseConnection"] = None
self._transformer = Transformer(context)
self.format = format
self.pgresult = result
self._finished = False
-
- self._partial: Deque[bytes] = deque()
- self._header_seen = False
self._codec: Optional[codecs.CodecInfo] = None
@property
self._pgresult = result
self._transformer.pgresult = result
- def load(self, buffer: bytes) -> List[Tuple[Any, ...]]:
- if self._finished:
- raise e.ProgrammingError("copy already finished")
-
- if self.format == pq.Format.TEXT:
- return self._load_text(buffer)
- else:
- return self._load_binary(buffer)
-
- def _load_text(self, buffer: bytes) -> List[Tuple[Any, ...]]:
- rows = buffer.split(b"\n")
- last_row = rows.pop(-1)
-
- if self._partial and rows:
- self._partial.append(rows[0])
- rows[0] = b"".join(self._partial)
- self._partial.clear()
-
- if last_row:
- self._partial.append(last_row)
-
- # If there is no result then the transformer has no info about types
- load_sequence = (
- self._transformer.load_sequence
- if self.pgresult is not None
- else None
- )
-
- rv = []
- for row in rows:
- if row == b"\\.":
- self._finished = True
- break
-
- values = row.split(b"\t")
- prow = tuple(
- _bsrepl_re.sub(_bsrepl_sub, v) if v != b"\\N" else None
- for v in values
- )
- rv.append(
- load_sequence(prow) if load_sequence is not None else prow
- )
-
- return rv
-
- def _load_binary(self, buffer: bytes) -> List[Tuple[Any, ...]]:
- raise NotImplementedError
-
def _ensure_bytes(self, data: Union[bytes, str]) -> bytes:
if isinstance(data, bytes):
return data
from psycopg3 import pq
from psycopg3 import errors as e
from psycopg3.adapt import Format
-from psycopg3.types import builtins
sample_records = [(10, 20, "hello"), (40, None, "world")]
sample_binary = b"".join(sample_binary_rows)
-def set_sample_attributes(res, format):
- attrs = [
- pq.PGresAttDesc(b"col1", 0, 0, format, builtins["int4"].oid, 0, 0),
- pq.PGresAttDesc(b"col2", 0, 0, format, builtins["int4"].oid, 0, 0),
- pq.PGresAttDesc(b"data", 0, 0, format, builtins["text"].oid, 0, 0),
- ]
- res.set_attributes(attrs)
-
-
-@pytest.mark.xfail
-@pytest.mark.parametrize(
- "format, buffer",
- [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
-)
-def test_load_noinfo(conn, format, buffer):
- from psycopg3.copy import Copy
-
- copy = Copy(context=None, result=None, format=format)
- records = copy.load(globals()[buffer])
- assert records == as_bytes(sample_records)
-
-
-@pytest.mark.xfail
-@pytest.mark.parametrize(
- "format, buffer",
- [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
-)
-def test_load(conn, format, buffer):
- from psycopg3.copy import Copy
-
- res = conn.pgconn.make_empty_result(pq.ExecStatus.COPY_OUT)
- set_sample_attributes(res, format)
-
- copy = Copy(context=None, result=res, format=format)
- records = copy.load(globals()[buffer])
- assert records == sample_records
-
-
-@pytest.mark.xfail
-@pytest.mark.parametrize(
- "format, buffer",
- [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
-)
-def test_dump(conn, format, buffer):
- from psycopg3.copy import Copy
-
- res = conn.pgconn.make_empty_result(pq.ExecStatus.COPY_OUT)
- set_sample_attributes(res, format)
-
- copy = Copy(context=None, result=res, format=format)
- assert copy.get_buffer() is None
- for row in sample_records:
- copy.dump(row)
- assert copy.get_buffer() == globals()[buffer]
- assert copy.get_buffer() is None
-
-
-@pytest.mark.xfail
-@pytest.mark.parametrize(
- "format, buffer",
- [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
-)
-def test_buffers(format, buffer):
- from psycopg3.copy import Copy
-
- copy = Copy(format=format)
- assert list(copy.buffers(sample_records)) == [globals()[buffer]]
-
-
-@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
-def test_copy_out_read(conn, format):
- cur = conn.cursor()
- copy = cur.copy(f"copy ({sample_values}) to stdout (format {format.name})")
-
- if format == pq.Format.TEXT:
- want = [row + b"\n" for row in sample_text.splitlines()]
- else:
- want = sample_binary_rows
-
- for row in want:
- got = copy.read()
- assert got == row
-
- assert copy.read() is None
- assert copy.read() is None
-
-
@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
def test_copy_out_iter(conn, format):
cur = conn.cursor()