# Copyright (C) 2020 The Psycopg Team
import re
-from typing import cast, TYPE_CHECKING, AsyncGenerator, Generator
-from typing import Any, Deque, Dict, List, Match, Optional, Tuple, Type
+import codecs
+from typing import TYPE_CHECKING, AsyncGenerator, Generator
+from typing import Any, Deque, Dict, List, Match, Optional, Tuple, Type, Union
from types import TracebackType
from collections import deque
from .generators import copy_from, copy_to, copy_end
if TYPE_CHECKING:
- from .connection import Connection, AsyncConnection
+ from .connection import BaseConnection, Connection, AsyncConnection
class BaseCopy:
+ _connection: Optional["BaseConnection"]
+
def __init__(
self,
context: AdaptContext,
):
from .adapt import Transformer
+ self._connection = None
self._transformer = Transformer(context)
self.format = format
self.pgresult = result
self._partial: Deque[bytes] = deque()
self._header_seen = False
+ self._codec: Optional[codecs.CodecInfo] = None
@property
def finished(self) -> bool:
return self._finished
+ @property
+ def connection(self) -> "BaseConnection":
+ if self._connection is not None:
+ return self._connection
+
+ self._connection = conn = self._transformer.connection
+ if conn is not None:
+ return conn
+
+ raise ValueError("no connection available")
+
@property
def pgresult(self) -> Optional[pq.proto.PGresult]:
return self._pgresult
def _load_binary(self, buffer: bytes) -> List[Tuple[Any, ...]]:
raise NotImplementedError
+ def _ensure_bytes(self, data: Union[bytes, str]) -> bytes:
+ if isinstance(data, bytes):
+ return data
+
+ elif isinstance(data, str):
+ if self._codec is not None:
+ return self._codec.encode(data)[0]
+
+ if (
+ self.pgresult is None
+ or self.pgresult.binary_tuples == pq.Format.BINARY
+ ):
+ raise TypeError(
+ "cannot copy str data in binary mode: use bytes instead"
+ )
+ self._codec = self.connection.codec
+ return self._codec.encode(data)[0]
+
def _bsrepl_sub(
m: Match[bytes],
class Copy(BaseCopy):
- def __init__(
- self,
- context: AdaptContext,
- result: Optional[pq.proto.PGresult],
- format: pq.Format = pq.Format.TEXT,
- ):
- super().__init__(context=context, result=result, format=format)
- self._connection: Optional["Connection"] = None
+ _connection: Optional["Connection"]
@property
def connection(self) -> "Connection":
- if self._connection is None:
- conn = self._transformer.connection
- if conn is None:
- raise ValueError("no connection available")
- self._connection = cast("Connection", conn)
-
- return self._connection
+ # TODO: mypy error: "Callable[[BaseCopy], BaseConnection]" has no
+ # attribute "fget"
+ return BaseCopy.connection.fget(self) # type: ignore
def read(self) -> Optional[bytes]:
if self._finished:
return rv
- def write(self, buffer: bytes) -> None:
+ def write(self, buffer: Union[str, bytes]) -> None:
conn = self.connection
- conn.wait(copy_to(conn.pgconn, buffer))
+ conn.wait(copy_to(conn.pgconn, self._ensure_bytes(buffer)))
def finish(self, error: Optional[str] = None) -> None:
conn = self.connection
class AsyncCopy(BaseCopy):
- def __init__(
- self,
- context: AdaptContext,
- result: Optional[pq.proto.PGresult],
- format: pq.Format = pq.Format.TEXT,
- ):
- super().__init__(context=context, result=result, format=format)
- self._connection: Optional["AsyncConnection"] = None
+ _connection: Optional["AsyncConnection"]
@property
def connection(self) -> "AsyncConnection":
- if self._connection is None:
- conn = self._transformer.connection
- if conn is None:
- raise ValueError("no connection available")
- self._connection = cast("AsyncConnection", conn)
-
- return self._connection
+ return BaseCopy.connection.fget(self) # type: ignore
async def read(self) -> Optional[bytes]:
if self._finished:
return rv
- async def write(self, buffer: bytes) -> None:
+ async def write(self, buffer: Union[str, bytes]) -> None:
conn = self.connection
- await conn.wait(copy_to(conn.pgconn, buffer))
+ await conn.wait(copy_to(conn.pgconn, self._ensure_bytes(buffer)))
async def finish(self, error: Optional[str] = None) -> None:
conn = self.connection