git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
test: add --concurrency option to benchmark script
author Daniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 23 Dec 2022 04:24:51 +0000 (04:24 +0000)
committer Daniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 29 Dec 2022 04:15:24 +0000 (04:15 +0000)
Investigate the poor parallelism mentioned in #448.

tests/scripts/bench-411.py

index 82ea451d8f86e09d3317558e01f5f5c746bdb07a..30d71684b21a02e53defe0a15ed8d5839ba7753c 100644 (file)
@@ -8,6 +8,7 @@ from enum import Enum
 from typing import Any, Dict, List, Generator
 from argparse import ArgumentParser, Namespace
 from contextlib import contextmanager
+from concurrent.futures import ThreadPoolExecutor
 
 logger = logging.getLogger()
 logging.basicConfig(
@@ -134,15 +135,22 @@ def run_psycopg2(psycopg2: Any, args: Namespace) -> None:
                 cursor.executemany(insert, data)
             conn.commit()
 
-    logger.info(f"running {args.ntests} queries")
-    to_query = random.choices(ids, k=args.ntests)
-    with psycopg2.connect(args.dsn) as conn:
-        with time_log("psycopg2"):
-            for id_ in to_query:
-                with conn.cursor() as cursor:
-                    cursor.execute(select, {"id": id_})
-                    cursor.fetchall()
-                # conn.rollback()
+    def run(i):
+        logger.info(f"thread {i} running {args.ntests} queries")
+        to_query = random.choices(ids, k=args.ntests)
+        with psycopg2.connect(args.dsn) as conn:
+            with time_log("psycopg2"):
+                for id_ in to_query:
+                    with conn.cursor() as cursor:
+                        cursor.execute(select, {"id": id_})
+                        cursor.fetchall()
+                    # conn.rollback()
+
+    if args.concurrency == 1:
+        run(0)
+    else:
+        with ThreadPoolExecutor(max_workers=args.concurrency) as executor:
+            list(executor.map(run, range(args.concurrency)))
 
     if args.drop:
         logger.info("dropping test records")
@@ -164,15 +172,22 @@ def run_psycopg(psycopg: Any, args: Namespace) -> None:
                 cursor.executemany(insert, data)
             conn.commit()
 
-    logger.info(f"running {args.ntests} queries")
-    to_query = random.choices(ids, k=args.ntests)
-    with psycopg.connect(args.dsn) as conn:
-        with time_log("psycopg"):
-            for id_ in to_query:
-                with conn.cursor() as cursor:
-                    cursor.execute(select, {"id": id_})
-                    cursor.fetchall()
-                # conn.rollback()
+    def run(i):
+        logger.info(f"thread {i} running {args.ntests} queries")
+        to_query = random.choices(ids, k=args.ntests)
+        with psycopg.connect(args.dsn) as conn:
+            with time_log("psycopg"):
+                for id_ in to_query:
+                    with conn.cursor() as cursor:
+                        cursor.execute(select, {"id": id_})
+                        cursor.fetchall()
+                    # conn.rollback()
+
+    if args.concurrency == 1:
+        run(0)
+    else:
+        with ThreadPoolExecutor(max_workers=args.concurrency) as executor:
+            list(executor.map(run, range(args.concurrency)))
 
     if args.drop:
         logger.info("dropping test records")
@@ -196,15 +211,22 @@ async def run_psycopg_async(psycopg: Any, args: Namespace) -> None:
                 await cursor.executemany(insert, data)
             await conn.commit()
 
-    logger.info(f"running {args.ntests} queries")
-    to_query = random.choices(ids, k=args.ntests)
-    async with await psycopg.AsyncConnection.connect(args.dsn) as conn:
-        with time_log("psycopg_async"):
-            for id_ in to_query:
-                cursor = await conn.execute(select, {"id": id_})
-                await cursor.fetchall()
-                await cursor.close()
-                # await conn.rollback()
+    async def run(i):
+        logger.info(f"task {i} running {args.ntests} queries")
+        to_query = random.choices(ids, k=args.ntests)
+        async with await psycopg.AsyncConnection.connect(args.dsn) as conn:
+            with time_log("psycopg_async"):
+                for id_ in to_query:
+                    cursor = await conn.execute(select, {"id": id_})
+                    await cursor.fetchall()
+                    await cursor.close()
+                    # await conn.rollback()
+
+    if args.concurrency == 1:
+        await run(0)
+    else:
+        tasks = [run(i) for i in range(args.concurrency)]
+        await asyncio.gather(*tasks)
 
     if args.drop:
         logger.info("dropping test records")
@@ -232,16 +254,23 @@ async def run_asyncpg(asyncpg: Any, args: Namespace) -> None:
             await conn.executemany(a_insert, [tuple(d.values()) for d in data])
         await conn.close()
 
-    logger.info(f"running {args.ntests} queries")
-    to_query = random.choices(ids, k=args.ntests)
-    conn = await asyncpg.connect(args.dsn)
-    with time_log("asyncpg"):
-        for id_ in to_query:
-            tr = conn.transaction()
-            await tr.start()
-            await conn.fetch(a_select, id_)
-            # await tr.rollback()
-    await conn.close()
+    async def run(i):
+        logger.info(f"task {i} running {args.ntests} queries")
+        to_query = random.choices(ids, k=args.ntests)
+        conn = await asyncpg.connect(args.dsn)
+        with time_log("asyncpg"):
+            for id_ in to_query:
+                # tr = conn.transaction()
+                # await tr.start()
+                await conn.fetch(a_select, id_)
+                # await tr.rollback()
+        await conn.close()
+
+    if args.concurrency == 1:
+        await run(0)
+    else:
+        tasks = [run(i) for i in range(args.concurrency)]
+        await asyncio.gather(*tasks)
 
     if args.drop:
         logger.info("dropping test records")
@@ -263,11 +292,20 @@ def parse_cmdline() -> Namespace:
 
     parser.add_argument(
         "--ntests",
+        "-n",
         type=int,
         default=10_000,
         help="number of tests to perform [default: %(default)s]",
     )
 
+    parser.add_argument(
+        "--concurrency",
+        "-c",
+        type=int,
+        default=1,
+        help="number of parallel tasks [default: %(default)s]",
+    )
+
     parser.add_argument(
         "--dsn",
         default=os.environ.get("PSYCOPG_TEST_DSN", ""),