From: Daniele Varrazzo Date: Thu, 31 Dec 2020 19:25:57 +0000 (+0100) Subject: Copy objects refactoring to isolate the row format function X-Git-Tag: 3.0.dev0~211 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=aa05cb067f2538766e70cc12dd414378a17336de;p=thirdparty%2Fpsycopg.git Copy objects refactoring to isolate the row format function --- diff --git a/psycopg3/psycopg3/copy.py b/psycopg3/psycopg3/copy.py index fe39b4656..999594471 100644 --- a/psycopg3/psycopg3/copy.py +++ b/psycopg3/psycopg3/copy.py @@ -26,6 +26,7 @@ class BaseCopy(Generic[ConnectionType]): self.cursor = cursor self.connection = cursor.connection self.transformer = cursor._transformer + self._pgconn = self.connection.pgconn assert ( self.transformer.pgresult @@ -39,13 +40,13 @@ class BaseCopy(Generic[ConnectionType]): self._finished = False if self.format == Format.TEXT: - self._format_copy_row = self._format_row_text + self._format_copy_row = format_row_text else: - self._format_copy_row = self._format_row_binary + self._format_copy_row = format_row_binary def __repr__(self) -> str: cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}" - info = pq.misc.connection_summary(self.connection.pgconn) + info = pq.misc.connection_summary(self._pgconn) return f"<{cls} {info} at 0x{id(self):x}>" # High level copy protocol generators (state change of the Copy object) @@ -54,8 +55,7 @@ class BaseCopy(Generic[ConnectionType]): if self._finished: return b"" - conn = self.connection - res = yield from copy_from(conn.pgconn) + res = yield from copy_from(self._pgconn) if isinstance(res, bytes): return res @@ -66,18 +66,30 @@ class BaseCopy(Generic[ConnectionType]): return b"" def _write_gen(self, buffer: Union[str, bytes]) -> PQGen[None]: - conn = self.connection # if write() was called, assume the header was sent together with the - # first block of data (either because we added it to the first row - # or, if the user is copying blocks, assume the blocks contain - # the header). + # first block of data. self._signature_sent = True - yield from copy_to(conn.pgconn, self._ensure_bytes(buffer)) + yield from copy_to(self._pgconn, self._ensure_bytes(buffer)) + + def _write_row_gen(self, row: Sequence[Any]) -> PQGen[None]: + # Note down that we are writing in row mode: it means we will have + # to take care of the end-of-copy marker too + self._row_mode = True + + data = self._format_row(row) + if self.format == Format.BINARY and not self._signature_sent: + yield from copy_to(self._pgconn, _binary_signature) + self._signature_sent = True + + yield from copy_to(self._pgconn, data) def _finish_gen(self, error: str = "") -> PQGen[None]: - conn = self.connection - berr = error.encode(conn.client_encoding, "replace") if error else None - res = yield from copy_end(conn.pgconn, berr) + berr = ( + error.encode(self.connection.client_encoding, "replace") + if error + else None + ) + res = yield from copy_end(self._pgconn, berr) nrows = res.command_tuples self.cursor._rowcount = nrows if nrows is not None else -1 self._finished = True @@ -102,8 +114,8 @@ class BaseCopy(Generic[ConnectionType]): # If we have sent no data we need to send the signature # and the trailer if not self._signature_sent: - yield from self._write_gen(self._binary_signature) - yield from self._write_gen(self._binary_trailer) + yield from copy_to(self._pgconn, _binary_signature) + yield from copy_to(self._pgconn, _binary_trailer) elif self._row_mode: # if we have sent data already, we have sent the signature too # (either with the first row, or we assume that in block mode @@ -111,7 +123,7 @@ class BaseCopy(Generic[ConnectionType]): # Write the trailer only if we are sending rows (with the # assumption that who is copying binary data is sending the # whole format). - yield from self._write_gen(self._binary_trailer) + yield from copy_to(self._pgconn, _binary_trailer) yield from self._finish_gen() @@ -129,53 +141,6 @@ class BaseCopy(Generic[ConnectionType]): return self._format_copy_row(out) - def _format_row_text(self, row: Sequence[Optional[bytes]]) -> bytes: - """Convert a row of adapted data to the data to send for text copy""" - return ( - b"\t".join( - _bsrepl_re.sub(_bsrepl_sub, item) - if item is not None - else br"\N" - for item in row - ) - + b"\n" - ) - - def _format_row_binary( - self, - row: Sequence[Optional[bytes]], - __int2_struct: struct.Struct = struct.Struct("!h"), - __int4_struct: struct.Struct = struct.Struct("!i"), - ) -> bytes: - """Convert a row of adapted data to the data to send for binary copy""" - out = [] - if not self._signature_sent: - out.append(self._binary_signature) - self._signature_sent = True - - # Note down that we are writing in row mode: it means we will have - # to take care of the end-of-copy marker too - self._row_mode = True - - out.append(__int2_struct.pack(len(row))) - for item in row: - if item is not None: - out.append(__int4_struct.pack(len(item))) - out.append(item) - else: - out.append(self._binary_null) - - return b"".join(out) - - _binary_signature = ( - # Signature, flags, extra length - b"PGCOPY\n\xff\r\n\0" - b"\x00\x00\x00\x00" - b"\x00\x00\x00\x00" - ) - _binary_trailer = b"\xff\xff" - _binary_null = b"\xff\xff\xff\xff" - def _ensure_bytes(self, data: Union[bytes, str]) -> bytes: if isinstance(data, bytes): return data @@ -195,24 +160,6 @@ class BaseCopy(Generic[ConnectionType]): raise TypeError("copy blocks can be used only once") -def _bsrepl_sub( - m: Match[bytes], - __map: Dict[bytes, bytes] = { - b"\b": b"\\b", - b"\t": b"\\t", - b"\n": b"\\n", - b"\v": b"\\v", - b"\f": b"\\f", - b"\r": b"\\r", - b"\\": b"\\\\", - }, -) -> bytes: - return __map[m.group(0)] - - -_bsrepl_re = re.compile(b"[\b\t\n\v\f\r\\\\]") - - class Copy(BaseCopy["Connection"]): """Manage a :sql:`COPY` operation.""" @@ -235,8 +182,7 @@ class Copy(BaseCopy["Connection"]): def write_row(self, row: Sequence[Any]) -> None: """Write a record after a :sql:`COPY FROM` operation.""" - data = self._format_row(row) - self.connection.wait(self._write_gen(data)) + self.connection.wait(self._write_row_gen(row)) def _finish(self, error: str = "") -> None: """Terminate a :sql:`COPY FROM` operation.""" @@ -274,8 +220,7 @@ class AsyncCopy(BaseCopy["AsyncConnection"]): await self.connection.wait(self._write_gen(buffer)) async def write_row(self, row: Sequence[Any]) -> None: - data = self._format_row(row) - await self.connection.wait(self._write_gen(data)) + await self.connection.wait(self._write_row_gen(row)) async def _finish(self, error: str = "") -> None: await self.connection.wait(self._finish_gen(error)) @@ -298,3 +243,60 @@ class AsyncCopy(BaseCopy["AsyncConnection"]): if not data: break yield data + + +def format_row_text(row: Sequence[Optional[bytes]]) -> bytes: + """Convert a row of adapted data to the data to send for text copy""" + return ( + b"\t".join( + _bsrepl_re.sub(_bsrepl_sub, item) if item is not None else br"\N" + for item in row + ) + + b"\n" + ) + + +def format_row_binary( + row: Sequence[Optional[bytes]], + __int2_struct: struct.Struct = struct.Struct("!h"), + __int4_struct: struct.Struct = struct.Struct("!i"), +) -> bytes: + """Convert a row of adapted data to the data to send for binary copy""" + out = [] + out.append(__int2_struct.pack(len(row))) + for item in row: + if item is not None: + out.append(__int4_struct.pack(len(item))) + out.append(item) + else: + out.append(_binary_null) + + return b"".join(out) + + +_binary_signature = ( + # Signature, flags, extra length + b"PGCOPY\n\xff\r\n\0" + b"\x00\x00\x00\x00" + b"\x00\x00\x00\x00" +) +_binary_trailer = b"\xff\xff" +_binary_null = b"\xff\xff\xff\xff" + + +def _bsrepl_sub( + m: Match[bytes], + __map: Dict[bytes, bytes] = { + b"\b": b"\\b", + b"\t": b"\\t", + b"\n": b"\\n", + b"\v": b"\\v", + b"\f": b"\\f", + b"\r": b"\\r", + b"\\": b"\\\\", + }, +) -> bytes: + return __map[m.group(0)] + + +_bsrepl_re = re.compile(b"[\b\t\n\v\f\r\\\\]")