]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor: use set literals
authorDenis Laxalde <denis.laxalde@dalibo.com>
Tue, 4 Jun 2024 07:40:30 +0000 (09:40 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 4 Jun 2024 19:03:55 +0000 (21:03 +0200)
As suggested by pyupgrade --py38-plus.

psycopg/psycopg/types/array.py
tests/_test_transaction.py
tests/pool/test_pool.py
tests/pool/test_pool_async.py
tests/pool/test_pool_common.py
tests/pool/test_pool_common_async.py
tools/async_to_sync.py
tools/bump_version.py

index 59a09ab545de098ee9947510e5aca69abd352798..78f3eaeff6d0a98a1bb68851f2c7cd2c3af4b8bc 100644 (file)
@@ -63,7 +63,7 @@ class BaseListDumper(RecursiveDumper):
             # More than one type in the list. It might be still good, as long
             # as they dump with the same oid (e.g. IPv4Network, IPv6Network).
             dumpers = [self._tx.get_dumper(item, format) for item in types.values()]
-            oids = set(d.oid for d in dumpers)
+            oids = {d.oid for d in dumpers}
             if len(oids) == 1:
                 t, v = types.popitem()
             else:
index 005e48ccae241aee8051f9b0a2d06c67d7b9eccd..29e7ad70d77ffd76d845957a2008fa2a03b7474e 100644 (file)
@@ -37,14 +37,14 @@ def inserted(conn):
     sql = "SELECT * FROM test_table"
     if isinstance(conn, psycopg.Connection):
         rows = conn.cursor().execute(sql).fetchall()
-        return set(v for (v,) in rows)
+        return {v for (v,) in rows}
     else:
 
         async def f():
             cur = conn.cursor()
             await cur.execute(sql)
             rows = await cur.fetchall()
-            return set(v for (v,) in rows)
+            return {v for (v,) in rows}
 
         return f()
 
index d5946f2d4d3ea902642b699dfaa16f192bdabb35..a3e991d1fb4671b6938ff4e40ddae790ac02e5ab 100644 (file)
@@ -637,7 +637,7 @@ def test_uniform_use(dsn):
                 counts[id(conn)] += 1
 
     assert len(counts) == 4
-    assert set(counts.values()) == set([2])
+    assert set(counts.values()) == {2}
 
 
 @pytest.mark.slow
@@ -712,7 +712,7 @@ def test_check(dsn, caplog):
             pid = conn.info.backend_pid
 
         p.wait(1.0)
-        pids = set((conn.info.backend_pid for conn in p._pool))
+        pids = {conn.info.backend_pid for conn in p._pool}
         assert pid in pids
         conn.close()
 
@@ -720,7 +720,7 @@ def test_check(dsn, caplog):
         p.check()
         assert len(caplog.records) == 1
         p.wait(1.0)
-        pids2 = set((conn.info.backend_pid for conn in p._pool))
+        pids2 = {conn.info.backend_pid for conn in p._pool}
         assert len(pids & pids2) == 3
         assert pid not in pids2
 
index 160d7119d6fd5027d67d74cd341e70baf8e60f15..bdcedff4949883e6b81aeeecb333b660930031d4 100644 (file)
@@ -640,7 +640,7 @@ async def test_uniform_use(dsn):
                 counts[id(conn)] += 1
 
     assert len(counts) == 4
-    assert set(counts.values()) == set([2])
+    assert set(counts.values()) == {2}
 
 
 @pytest.mark.slow
@@ -714,7 +714,7 @@ async def test_check(dsn, caplog):
             pid = conn.info.backend_pid
 
         await p.wait(1.0)
-        pids = set(conn.info.backend_pid for conn in p._pool)
+        pids = {conn.info.backend_pid for conn in p._pool}
         assert pid in pids
         await conn.close()
 
@@ -722,7 +722,7 @@ async def test_check(dsn, caplog):
         await p.check()
         assert len(caplog.records) == 1
         await p.wait(1.0)
-        pids2 = set(conn.info.backend_pid for conn in p._pool)
+        pids2 = {conn.info.backend_pid for conn in p._pool}
         assert len(pids & pids2) == 3
         assert pid not in pids2
 
index 60ef84b899aa20e892df144955a4a8d584ca0b20..1f2230cf0203962cc5e524b6eff44c5b94bf79f0 100644 (file)
@@ -182,7 +182,7 @@ def test_queue(pool_cls, dsn):
     for got, want in zip(times, want_times):
         assert got == pytest.approx(want, 0.2), times
 
-    assert len(set((r[2] for r in results))) == 2, results
+    assert len({r[2] for r in results}) == 2, results
 
 
 @pytest.mark.slow
@@ -272,7 +272,7 @@ def test_dead_client(pool_cls, dsn):
         gather(*ts)
 
         sleep(0.2)
-        assert set(results) == set([0, 1, 3, 4])
+        assert set(results) == {0, 1, 3, 4}
         if pool_cls is pool.ConnectionPool:
             assert len(p._pool) == 2  # no connection was lost
 
index 01cc957a900a25b396b019dde087a09b4988161c..4ee23270cb0c525a3f2bc15904e65d128e205e0a 100644 (file)
@@ -192,7 +192,7 @@ async def test_queue(pool_cls, dsn):
     for got, want in zip(times, want_times):
         assert got == pytest.approx(want, 0.2), times
 
-    assert len(set(r[2] for r in results)) == 2, results
+    assert len({r[2] for r in results}) == 2, results
 
 
 @pytest.mark.slow
@@ -283,7 +283,7 @@ async def test_dead_client(pool_cls, dsn):
         await gather(*ts)
 
         await asleep(0.2)
-        assert set(results) == set([0, 1, 3, 4])
+        assert set(results) == {0, 1, 3, 4}
         if pool_cls is pool.AsyncConnectionPool:
             assert len(p._pool) == 2  # no connection was lost
 
index 41d88fa5b962a4cc8c1b4cb4d6914e87a7e94647..ba6fd94be48c16c71a85634d4a65b371fc435464 100755 (executable)
@@ -144,9 +144,7 @@ def check(outputs: list[str]) -> int:
     if not maybe_conv:
         logger.error("no file to check? Maybe this script bitrot?")
         return 1
-    unk_conv = sorted(
-        set(maybe_conv) - set(fn.replace("_async", "") for fn in ALL_INPUTS)
-    )
+    unk_conv = sorted(set(maybe_conv) - {fn.replace("_async", "") for fn in ALL_INPUTS})
     if unk_conv:
         logger.error(
             "files converted by %s but not included in ALL_INPUTS: %s",
index dd371d7ba3bda1410b1b3f85966d0977c4790644..65b6e0220a371e7fbaf93bc64ec3c7dd6e707a6e 100755 (executable)
@@ -84,7 +84,7 @@ class Bumper:
 
     @cached_property
     def current_version(self) -> Version:
-        versions = set(self._parse_version_from_file(f) for f in self.package.ini_files)
+        versions = {self._parse_version_from_file(f) for f in self.package.ini_files}
         if len(versions) > 1:
             raise ValueError(
                 f"inconsistent versions ({', '.join(map(str, sorted(versions)))})"