]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
perf(copy): avoid to call _write with empty buffer
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 25 May 2022 17:33:18 +0000 (18:33 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 28 May 2022 01:02:30 +0000 (02:02 +0100)
There is a call per row we can avoid if the formatter is buffering the
data produced.

psycopg/psycopg/copy.py

index 778a850be9a1ef3d2040f7790ed0caa2de5ef289..b1fe6dc7d603446671055e4112166b7ffc5315d4 100644 (file)
@@ -266,12 +266,14 @@ class Copy(BaseCopy["Connection[Any]"]):
         text mode it can be either `!bytes` or `!str`.
         """
         data = self.formatter.write(buffer)
-        self._write(data)
+        if data:
+            self._write(data)
 
     def write_row(self, row: Sequence[Any]) -> None:
         """Write a record to a table after a :sql:`COPY FROM` operation."""
         data = self.formatter.write_row(row)
-        self._write(data)
+        if data:
+            self._write(data)
 
     def finish(self, exc: Optional[BaseException]) -> None:
         """Terminate the copy operation and free the resources allocated.
@@ -307,9 +309,6 @@ class Copy(BaseCopy["Connection[Any]"]):
             self._worker_error = ex
 
     def _write(self, data: Buffer) -> None:
-        if not data:
-            return
-
         if not self._worker:
             # warning: reference loop, broken by _write_end
             self._worker = threading.Thread(target=self.worker)
@@ -332,7 +331,8 @@ class Copy(BaseCopy["Connection[Any]"]):
 
     def _write_end(self) -> None:
         data = self.formatter.end()
-        self._write(data)
+        if data:
+            self._write(data)
         self._queue.put(b"")
 
         if self._worker:
@@ -388,11 +388,13 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
 
     async def write(self, buffer: Union[Buffer, str]) -> None:
         data = self.formatter.write(buffer)
-        await self._write(data)
+        if data:
+            await self._write(data)
 
     async def write_row(self, row: Sequence[Any]) -> None:
         data = self.formatter.write_row(row)
-        await self._write(data)
+        if data:
+            await self._write(data)
 
     async def finish(self, exc: Optional[BaseException]) -> None:
         if self._pgresult.status == COPY_IN:
@@ -417,9 +419,6 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
             await self.connection.wait(copy_to(self._pgconn, data))
 
     async def _write(self, data: Buffer) -> None:
-        if not data:
-            return
-
         if not self._worker:
             self._worker = create_task(self.worker())
 
@@ -435,7 +434,8 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
 
     async def _write_end(self) -> None:
         data = self.formatter.end()
-        await self._write(data)
+        if data:
+            await self._write(data)
         await self._queue.put(b"")
 
         if self._worker: