]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
perf(copy): avoid to call _write with empty buffer
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 25 May 2022 17:33:18 +0000 (18:33 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 27 May 2022 08:06:46 +0000 (09:06 +0100)
There is a call per row we can avoid if the formatter is buffering the
data produced.

psycopg/psycopg/copy.py

index abd7addaeb219d0bfed8525b82205ccc1aaf191b..1079f6256a95d2a0b95fc37d76ddd4f2467786a7 100644 (file)
@@ -260,12 +260,14 @@ class Copy(BaseCopy["Connection[Any]"]):
         text mode it can be either `!bytes` or `!str`.
         """
         data = self.formatter.write(buffer)
-        self._write(data)
+        if data:
+            self._write(data)
 
     def write_row(self, row: Sequence[Any]) -> None:
         """Write a record to a table after a :sql:`COPY FROM` operation."""
         data = self.formatter.write_row(row)
-        self._write(data)
+        if data:
+            self._write(data)
 
     def finish(self, exc: Optional[BaseException]) -> None:
         """Terminate the copy operation and free the resources allocated.
@@ -301,9 +303,6 @@ class Copy(BaseCopy["Connection[Any]"]):
             self._worker_error = ex
 
     def _write(self, data: Buffer) -> None:
-        if not data:
-            return
-
         if not self._worker:
             # warning: reference loop, broken by _write_end
             self._worker = threading.Thread(target=self.worker)
@@ -326,7 +325,8 @@ class Copy(BaseCopy["Connection[Any]"]):
 
     def _write_end(self) -> None:
         data = self.formatter.end()
-        self._write(data)
+        if data:
+            self._write(data)
         self._queue.put(b"")
 
         if self._worker:
@@ -382,11 +382,13 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
 
     async def write(self, buffer: Union[Buffer, str]) -> None:
         data = self.formatter.write(buffer)
-        await self._write(data)
+        if data:
+            await self._write(data)
 
     async def write_row(self, row: Sequence[Any]) -> None:
         data = self.formatter.write_row(row)
-        await self._write(data)
+        if data:
+            await self._write(data)
 
     async def finish(self, exc: Optional[BaseException]) -> None:
         if self._pgresult.status == ExecStatus.COPY_IN:
@@ -411,9 +413,6 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
             await self.connection.wait(copy_to(self._pgconn, data))
 
     async def _write(self, data: Buffer) -> None:
-        if not data:
-            return
-
         if not self._worker:
             self._worker = create_task(self.worker())
 
@@ -429,7 +428,8 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
 
     async def _write_end(self) -> None:
         data = self.formatter.end()
-        await self._write(data)
+        if data:
+            await self._write(data)
         await self._queue.put(b"")
 
         if self._worker: