git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor: use assignment expressions in loops with assignment and break 1035/head
author Daniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 27 Mar 2025 16:36:29 +0000 (17:36 +0100)
committer Daniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 3 Apr 2025 10:56:14 +0000 (11:56 +0100)
16 files changed:
psycopg/psycopg/_copy.py
psycopg/psycopg/_copy_async.py
psycopg/psycopg/cursor.py
psycopg/psycopg/cursor_async.py
psycopg/psycopg/generators.py
tests/pq/test_async.py
tests/test_copy.py
tests/test_copy_async.py
tests/test_cursor.py
tests/test_cursor_async.py
tests/test_cursor_client.py
tests/test_cursor_client_async.py
tests/test_cursor_raw.py
tests/test_cursor_raw_async.py
tests/test_cursor_server.py
tests/test_cursor_server_async.py

index d03549464f1e9e2b98af2fb5955bc6c6cd4b3210..0934a63a2e42714acebb5c68821a27e197aa0268 100644 (file)
@@ -81,9 +81,7 @@ class Copy(BaseCopy["Connection[Any]"]):
 
     def __iter__(self) -> Iterator[Buffer]:
         """Implement block-by-block iteration on :sql:`COPY TO`."""
-        while True:
-            if not (data := self.read()):
-                break
+        while data := self.read():
             yield data
 
     def read(self) -> Buffer:
@@ -101,9 +99,7 @@ class Copy(BaseCopy["Connection[Any]"]):
         Note that the records returned will be tuples of unparsed strings or
         bytes, unless data types are specified using `set_types()`.
         """
-        while True:
-            if (record := self.read_row()) is None:
-                break
+        while (record := self.read_row()) is not None:
             yield record
 
     def read_row(self) -> tuple[Any, ...] | None:
@@ -252,9 +248,7 @@ class QueuedLibpqWriter(LibpqWriter):
         The function is designed to be run in a separate task.
         """
         try:
-            while True:
-                if not (data := self._queue.get()):
-                    break
+            while data := self._queue.get():
                 self.connection.wait(copy_to(self._pgconn, data, flush=PREFER_FLUSH))
         except BaseException as ex:
             # Propagate the error to the main thread.
index 0070ba8b55da4d0ad7871005e56822c2a4db6c9a..05ec0f98fb3e30a5d3a6dbb0ccd6410fb8c4a0a8 100644 (file)
@@ -78,9 +78,7 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
 
     async def __aiter__(self) -> AsyncIterator[Buffer]:
         """Implement block-by-block iteration on :sql:`COPY TO`."""
-        while True:
-            if not (data := (await self.read())):
-                break
+        while data := (await self.read()):
             yield data
 
     async def read(self) -> Buffer:
@@ -98,9 +96,7 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
         Note that the records returned will be tuples of unparsed strings or
         bytes, unless data types are specified using `set_types()`.
         """
-        while True:
-            if (record := (await self.read_row())) is None:
-                break
+        while (record := (await self.read_row())) is not None:
             yield record
 
     async def read_row(self) -> tuple[Any, ...] | None:
@@ -251,9 +247,7 @@ class AsyncQueuedLibpqWriter(AsyncLibpqWriter):
         The function is designed to be run in a separate task.
         """
         try:
-            while True:
-                if not (data := (await self._queue.get())):
-                    break
+            while data := (await self._queue.get()):
                 await self.connection.wait(
                     copy_to(self._pgconn, data, flush=PREFER_FLUSH)
                 )
index fe6b28c35ae0170f901e2229732d5a8d4817a8eb..fb2df08ebd896d7be2c83a28a8fb12dbf2837ae1 100644 (file)
@@ -231,9 +231,7 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
         def load(pos: int) -> Row | None:
             return self._tx.load_row(pos, self._make_row)
 
-        while True:
-            if (row := load(self._pos)) is None:
-                break
+        while (row := load(self._pos)) is not None:
             self._pos += 1
             yield row
 
index 0d2886664d3cf4cc741e5423fd84e611116d42ce..4c6bdc55fc1886fe9a5e19ad72443dc902a14178 100644 (file)
@@ -235,9 +235,7 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
         def load(pos: int) -> Row | None:
             return self._tx.load_row(pos, self._make_row)
 
-        while True:
-            if (row := load(self._pos)) is None:
-                break
+        while (row := load(self._pos)) is not None:
             self._pos += 1
             yield row
 
index 0e1fbe9c89c60fe6640ff4ebb96efb2afd097e1f..f7ade8a3544804d152553543d2173179792c446c 100644 (file)
@@ -148,13 +148,10 @@ def _send(pgconn: PGconn) -> PQGen[None]:
     After this generator has finished you may want to cycle using `fetch()`
     to retrieve the results available.
     """
-    while True:
-        if pgconn.flush() == 0:
-            break
+    while pgconn.flush() != 0:
 
-        while True:
-            if ready := (yield WAIT_RW):
-                break
+        while not (ready := (yield WAIT_RW)):
+            continue
 
         if ready & READY_R:
             # This call may read notifies: they will be saved in the
@@ -216,17 +213,15 @@ def _fetch(pgconn: PGconn) -> PQGen[PGresult | None]:
     Return a result from the database (whether success or error).
     """
     if pgconn.is_busy():
-        while True:
-            if (yield WAIT_R):
-                break
+        while not (yield WAIT_R):
+            continue
 
         while True:
             pgconn.consume_input()
             if not pgconn.is_busy():
                 break
-            while True:
-                if (yield WAIT_R):
-                    break
+            while not (yield WAIT_R):
+                continue
 
     _consume_notifies(pgconn)
 
@@ -244,9 +239,8 @@ def _pipeline_communicate(
     results = []
 
     while True:
-        while True:
-            if ready := (yield WAIT_RW):
-                break
+        while not (ready := (yield WAIT_RW)):
+            continue
 
         if ready & READY_R:
             pgconn.consume_input()
@@ -283,9 +277,7 @@ def _pipeline_communicate(
 
 def _consume_notifies(pgconn: PGconn) -> None:
     # Consume notifies
-    while True:
-        if not (n := pgconn.notifies()):
-            break
+    while n := pgconn.notifies():
         if pgconn.notify_handler:
             pgconn.notify_handler(n)
 
@@ -295,13 +287,10 @@ def notifies(pgconn: PGconn) -> PQGen[list[pq.PGnotify]]:
     pgconn.consume_input()
 
     ns = []
-    while True:
-        if n := pgconn.notifies():
-            ns.append(n)
-            if pgconn.notify_handler:
-                pgconn.notify_handler(n)
-        else:
-            break
+    while n := pgconn.notifies():
+        ns.append(n)
+        if pgconn.notify_handler:
+            pgconn.notify_handler(n)
 
     return ns
 
@@ -313,9 +302,8 @@ def copy_from(pgconn: PGconn) -> PQGen[memoryview | PGresult]:
             break
 
         # would block
-        while True:
-            if (yield WAIT_R):
-                break
+        while not (yield WAIT_R):
+            continue
         pgconn.consume_input()
 
     if nbytes > 0:
@@ -342,9 +330,8 @@ def copy_to(pgconn: PGconn, buffer: Buffer, flush: bool = True) -> PQGen[None]:
     # into smaller ones. We prefer to do it there instead of here in order to
     # do it upstream the queue decoupling the writer task from the producer one.
     while pgconn.put_copy_data(buffer) == 0:
-        while True:
-            if (yield WAIT_W):
-                break
+        while not (yield WAIT_W):
+            continue
 
     # Flushing often has a good effect on macOS because memcpy operations
     # seem expensive on this platform so accumulating a large buffer has a
@@ -352,9 +339,8 @@ def copy_to(pgconn: PGconn, buffer: Buffer, flush: bool = True) -> PQGen[None]:
     if flush:
         # Repeat until the message is flushed to the server
         while True:
-            while True:
-                if (yield WAIT_W):
-                    break
+            while not (yield WAIT_W):
+                continue
 
             if pgconn.flush() == 0:
                 break
@@ -363,15 +349,13 @@ def copy_to(pgconn: PGconn, buffer: Buffer, flush: bool = True) -> PQGen[None]:
 def copy_end(pgconn: PGconn, error: bytes | None) -> PQGen[PGresult]:
     # Retry enqueuing end copy message until successful
     while pgconn.put_copy_end(error) == 0:
-        while True:
-            if (yield WAIT_W):
-                break
+        while not (yield WAIT_W):
+            continue
 
     # Repeat until the message is flushed to the server
     while True:
-        while True:
-            if (yield WAIT_W):
-                break
+        while not (yield WAIT_W):
+            continue
 
         if pgconn.flush() == 0:
             break
index 404239fb4824280a6fa0b1215c9f9b7f7039c3e9..dd54efe66b080cbad853a2dbfad260fadbe95c93 100644 (file)
@@ -23,9 +23,7 @@ def test_send_query(pgconn):
 
     # send loop
     waited_on_send = 0
-    while True:
-        if pgconn.flush() == 0:
-            break
+    while pgconn.flush() != 0:
 
         waited_on_send += 1
 
index 949ec0caff5ac19c6e883d51177e805047f347d5..3c64196d43d5f9cadabd2fa4286b85a0ccfe2006 100644 (file)
@@ -144,9 +144,7 @@ def test_copy_out_allchars(conn, format):
     )
     with cur.copy(query) as copy:
         copy.set_types(["text"])
-        while True:
-            if not (row := copy.read_row()):
-                break
+        while row := copy.read_row():
             assert len(row) == 1
             rows.append(row[0])
 
@@ -158,9 +156,7 @@ def test_read_row_notypes(conn, format):
     cur = conn.cursor()
     with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy:
         rows = []
-        while True:
-            if not (row := copy.read_row()):
-                break
+        while row := copy.read_row():
             rows.append(row)
 
     ref = [tuple((py_to_raw(i, format) for i in record)) for record in sample_records]
@@ -733,15 +729,13 @@ def test_copy_to_leaks(conn_cls, dsn, faker, fmt, set_types, method, gc):
                         copy.set_types(faker.types_names)
 
                     if method == "read":
-                        while True:
-                            if not copy.read():
-                                break
+                        while copy.read():
+                            pass
                     elif method == "iter":
                         list(copy)
                     elif method == "row":
-                        while True:
-                            if copy.read_row() is None:
-                                break
+                        while copy.read_row() is not None:
+                            pass
                     elif method == "rows":
                         list(copy.rows())
 
@@ -859,9 +853,7 @@ class DataGenerator:
 
     def blocks(self):
         f = self.file()
-        while True:
-            if not (block := f.read(self.block_size)):
-                break
+        while block := f.read(self.block_size):
             yield block
 
     def assert_data(self):
@@ -874,9 +866,7 @@ class DataGenerator:
 
     def sha(self, f):
         m = hashlib.sha256()
-        while True:
-            if not (block := f.read()):
-                break
+        while block := f.read():
             if isinstance(block, str):
                 block = block.encode()
             m.update(block)
index 290c43c8530ff39a481eaad1276f34f0b9f0081b..7fa8223d63474d8cdd3c802327da506e9f0218ba 100644 (file)
@@ -148,9 +148,7 @@ async def test_copy_out_allchars(aconn, format):
     )
     async with cur.copy(query) as copy:
         copy.set_types(["text"])
-        while True:
-            if not (row := (await copy.read_row())):
-                break
+        while row := (await copy.read_row()):
             assert len(row) == 1
             rows.append(row[0])
 
@@ -164,9 +162,7 @@ async def test_read_row_notypes(aconn, format):
         f"copy ({sample_values}) to stdout (format {format.name})"
     ) as copy:
         rows = []
-        while True:
-            if not (row := (await copy.read_row())):
-                break
+        while row := (await copy.read_row()):
             rows.append(row)
 
     ref = [tuple(py_to_raw(i, format) for i in record) for record in sample_records]
@@ -749,15 +745,13 @@ async def test_copy_to_leaks(aconn_cls, dsn, faker, fmt, set_types, method, gc):
                         copy.set_types(faker.types_names)
 
                     if method == "read":
-                        while True:
-                            if not (await copy.read()):
-                                break
+                        while await copy.read():
+                            pass
                     elif method == "iter":
                         await alist(copy)
                     elif method == "row":
-                        while True:
-                            if (await copy.read_row()) is None:
-                                break
+                        while (await copy.read_row()) is not None:
+                            pass
                     elif method == "rows":
                         await alist(copy.rows())
 
@@ -874,9 +868,7 @@ class DataGenerator:
 
     def blocks(self):
         f = self.file()
-        while True:
-            if not (block := f.read(self.block_size)):
-                break
+        while block := f.read(self.block_size):
             yield block
 
     async def assert_data(self):
@@ -889,9 +881,7 @@ class DataGenerator:
 
     def sha(self, f):
         m = hashlib.sha256()
-        while True:
-            if not (block := f.read()):
-                break
+        while block := f.read():
             if isinstance(block, str):
                 block = block.encode()
             m.update(block)
index cf5551ca0f95551d76c2172be5960d0802e73c73..bb7af9f79bcf3a4bb3b71b4b4ad77d1a3bd34eb6 100644 (file)
@@ -85,13 +85,11 @@ def test_leak(conn_cls, dsn, faker, fmt, fmt_out, fetch, row_factory, gc):
                 cur.execute(faker.select_stmt)
 
                 if fetch == "one":
-                    while True:
-                        if cur.fetchone() is None:
-                            break
+                    while cur.fetchone() is not None:
+                        pass
                 elif fetch == "many":
-                    while True:
-                        if not cur.fetchmany(3):
-                            break
+                    while cur.fetchmany(3):
+                        pass
                 elif fetch == "all":
                     cur.fetchall()
                 elif fetch == "iter":
index fa6b70175da19a23b72a6feb6e5efc7c4b4be417..4f683d1edcee28cc217113dc39ce53301d290d2f 100644 (file)
@@ -86,13 +86,11 @@ async def test_leak(aconn_cls, dsn, faker, fmt, fmt_out, fetch, row_factory, gc)
                 await cur.execute(faker.select_stmt)
 
                 if fetch == "one":
-                    while True:
-                        if (await cur.fetchone()) is None:
-                            break
+                    while (await cur.fetchone()) is not None:
+                        pass
                 elif fetch == "many":
-                    while True:
-                        if not (await cur.fetchmany(3)):
-                            break
+                    while await cur.fetchmany(3):
+                        pass
                 elif fetch == "all":
                     await cur.fetchall()
                 elif fetch == "iter":
index fd87cf551979989f6675c7b8620b682dbde63d33..de4c0854116e0f4ee7e97123fdec68412a9c57dd 100644 (file)
@@ -95,13 +95,11 @@ def test_leak(conn_cls, dsn, faker, fetch, row_factory, gc):
                 cur.execute(faker.select_stmt)
 
                 if fetch == "one":
-                    while True:
-                        if cur.fetchone() is None:
-                            break
+                    while cur.fetchone() is not None:
+                        pass
                 elif fetch == "many":
-                    while True:
-                        if not cur.fetchmany(3):
-                            break
+                    while cur.fetchmany(3):
+                        pass
                 elif fetch == "all":
                     cur.fetchall()
                 elif fetch == "iter":
index a580f0d6662326fb4a7a2e293dc4fd77424e8cb7..7945c1b7eba71265bec6a482d0d06ecf59e349b3 100644 (file)
@@ -96,13 +96,11 @@ async def test_leak(aconn_cls, dsn, faker, fetch, row_factory, gc):
                 await cur.execute(faker.select_stmt)
 
                 if fetch == "one":
-                    while True:
-                        if (await cur.fetchone()) is None:
-                            break
+                    while (await cur.fetchone()) is not None:
+                        pass
                 elif fetch == "many":
-                    while True:
-                        if not (await cur.fetchmany(3)):
-                            break
+                    while await cur.fetchmany(3):
+                        pass
                 elif fetch == "all":
                     await cur.fetchall()
                 elif fetch == "iter":
index 0f2155e3ff6d36b9a230fdc4cc1c49b95d757021..b12128cf5ceda32aec388c2904f005c6cd3ddb0f 100644 (file)
@@ -93,13 +93,11 @@ def test_leak(conn_cls, dsn, faker, fmt, fmt_out, fetch, row_factory, gc):
                     cur.execute(ph(cur, faker.select_stmt))
 
                     if fetch == "one":
-                        while True:
-                            if cur.fetchone() is None:
-                                break
+                        while cur.fetchone() is not None:
+                            pass
                     elif fetch == "many":
-                        while True:
-                            if not cur.fetchmany(3):
-                                break
+                        while cur.fetchmany(3):
+                            pass
                     elif fetch == "all":
                         cur.fetchall()
                     elif fetch == "iter":
index f2f9553ad7576f09d8949b47f1577ca6373c3da9..089d0ff4b987c2cfc7dfd42c44a7b9a42e1a21d2 100644 (file)
@@ -90,13 +90,11 @@ async def test_leak(aconn_cls, dsn, faker, fmt, fmt_out, fetch, row_factory, gc)
                     await cur.execute(ph(cur, faker.select_stmt))
 
                     if fetch == "one":
-                        while True:
-                            if (await cur.fetchone()) is None:
-                                break
+                        while (await cur.fetchone()) is not None:
+                            pass
                     elif fetch == "many":
-                        while True:
-                            if not (await cur.fetchmany(3)):
-                                break
+                        while await cur.fetchmany(3):
+                            pass
                     elif fetch == "all":
                         await cur.fetchall()
                     elif fetch == "iter":
index 365f0b76235e5c31b18c123cb7edd94b18ee0bec..d5342a689035c8ad5b5f28c7fa44ab3e265ea157 100644 (file)
@@ -384,9 +384,7 @@ def test_row_factory(conn):
     cur.execute("select generate_series(1, 3) as x")
     recs = cur.fetchall()
     cur.scroll(0, "absolute")
-    while True:
-        if not (rec := cur.fetchone()):
-            break
+    while rec := cur.fetchone():
         recs.append(rec)
     assert recs == [[1, -1], [1, -2], [1, -3]] * 2
 
index cede008f76162cb1e0bcbb8da108fc34f84e74fa..682859984172a971202108eb4aa6720b4d47876d 100644 (file)
@@ -390,9 +390,7 @@ async def test_row_factory(aconn):
     await cur.execute("select generate_series(1, 3) as x")
     recs = await cur.fetchall()
     await cur.scroll(0, "absolute")
-    while True:
-        if not (rec := (await cur.fetchone())):
-            break
+    while rec := (await cur.fetchone()):
         recs.append(rec)
     assert recs == [[1, -1], [1, -2], [1, -3]] * 2