]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
test: add free-threading-specific tests
authorLysandros Nikolaou <lisandrosnik@gmail.com>
Fri, 28 Nov 2025 01:12:43 +0000 (02:12 +0100)
committerGitHub <noreply@github.com>
Fri, 28 Nov 2025 01:12:43 +0000 (01:12 +0000)
Add specific concurrency tests for the public interface of connections and
cursors

PR #1211, related to free-threading support (#1095).

.github/workflows/tests.yml
psycopg_c/build_backend/psycopg_build_ext.py
tests/fix_gc.py
tests/pool/test_pool.py
tests/pool/test_pool_async.py
tests/test_free_threading.py [new file with mode: 0644]
tests/test_notify.py
tests/test_notify_async.py
tests/utils.py

index f58933f9249d59d629ed022f67eb5ed157f3bc45..3c3adce41c0b19df61fb6e01b44adfceb9f99b3a 100644 (file)
@@ -34,12 +34,14 @@ jobs:
           - {impl: python, python: "3.12", postgres: "postgres:16", libpq: newest}
           - {impl: python, python: "3.13", postgres: "postgres:13"}
           - {impl: python, python: "3.14", postgres: "postgres:14"}
+          - {impl: python, python: "3.14t", postgres: "postgres:14"}
 
           - {impl: c, python: "3.10", postgres: "postgres:13", libpq: master}
           - {impl: c, python: "3.11", postgres: "postgres:15", libpq: oldest}
           - {impl: c, python: "3.12", postgres: "postgres:16"}
           - {impl: c, python: "3.13", postgres: "postgres:17", libpq: newest}
           - {impl: c, python: "3.14", postgres: "postgres:18"}
+          - {impl: c, python: "3.14t", postgres: "postgres:18"}
 
           - {impl: python, python: "3.10", ext: gevent, postgres: "postgres:17"}
           - {impl: python, python: "3.10", ext: dns, postgres: "postgres:14"}
index fb066cf85b4b2838fff292bf9c536d689f03cc2e..b4293fe7dba33943f637f2b011a8a36793ae1b38 100644 (file)
@@ -61,6 +61,7 @@ class psycopg_build_ext(build_ext):
                 language_level=3,
                 compiler_directives={
                     "always_allow_keywords": False,
+                    "freethreading_compatible": True,
                 },
                 annotate=False,  # enable to get an html view of the C module
             )
index 0cf1686b4576d98607b12dff8cc9746515d3637b..7138166146e4311a44f957fe297b4b9690bc465d 100644 (file)
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import gc
 import sys
+import sysconfig
 
 import pytest
 
@@ -71,7 +72,7 @@ def fixture_gc():
 
     **Note:** This will skip tests on PyPy.
     """
-    if sys.implementation.name == "pypy":
+    if sys.implementation.name == "pypy" or sysconfig.get_config_var("Py_GIL_DISABLED"):
         pytest.skip(reason="depends on refcount semantics")
     return GCFixture()
 
@@ -83,4 +84,6 @@ def gc_collect():
 
     **Note:** This will *not* skip tests on PyPy.
     """
+    if sysconfig.get_config_var("Py_GIL_DISABLED"):
+        pytest.skip(reason="depends on refcount semantics")
     return GCFixture.collect
index 965e5e34d0bdba9f8d0f18bef6fab95666112b30..856b9914ac7cf7e7caf33270caaca3c42ae49dec 100644 (file)
@@ -15,7 +15,7 @@ import psycopg
 from psycopg.pq import TransactionStatus
 from psycopg.rows import Row, TupleRow, class_row
 
-from ..utils import assert_type, set_autocommit
+from ..utils import assert_type, set_autocommit, skip_free_threaded
 from ..acompat import Event, gather, skip_sync, sleep, spawn
 from .test_pool_common import delay_connection
 
@@ -861,6 +861,7 @@ def test_check_max_lifetime(dsn):
 
 
 @pytest.mark.slow
+@skip_free_threaded("timing not accurate under the free-threaded build")
 def test_stats_connect(proxy, monkeypatch):
     proxy.start()
     delay_connection(monkeypatch, 0.2)
index e1c52ec273720231a2a78295b987a5243ea69bec..dc91c2f1b15964791536d10b96bb9184926ec3d8 100644 (file)
@@ -12,7 +12,7 @@ import psycopg
 from psycopg.pq import TransactionStatus
 from psycopg.rows import Row, TupleRow, class_row
 
-from ..utils import assert_type, set_autocommit
+from ..utils import assert_type, set_autocommit, skip_free_threaded
 from ..acompat import AEvent, asleep, gather, skip_sync, spawn
 from .test_pool_common_async import delay_connection
 
@@ -862,6 +862,7 @@ async def test_check_max_lifetime(dsn):
 
 
 @pytest.mark.slow
+@skip_free_threaded("timing not accurate under the free-threaded build")
 async def test_stats_connect(proxy, monkeypatch):
     proxy.start()
     delay_connection(monkeypatch, 0.2)
diff --git a/tests/test_free_threading.py b/tests/test_free_threading.py
new file mode 100644 (file)
index 0000000..cfe8090
--- /dev/null
@@ -0,0 +1,184 @@
+import threading
+from concurrent.futures import ThreadPoolExecutor
+
+import pytest
+
+import psycopg
+
+from ._test_connection import testctx  # noqa: F401  # fixture
+
+
+@pytest.mark.slow
+@pytest.mark.usefixtures("testctx")
+def test_concurrent_connection_insert(conn):
+    nthreads = 10
+    barrier = threading.Barrier(parties=nthreads)
+
+    def worker(i):
+        barrier.wait()
+        with conn.cursor() as cur:
+            cur.execute("insert into testctx values (%s)", (i,))
+
+    with ThreadPoolExecutor(max_workers=nthreads) as tpe:
+        futures = [tpe.submit(worker, i) for i in range(100)]
+        for future in futures:
+            future.result()  # to verify nothing raises
+
+    with conn.cursor() as cur:
+        cur.execute("select id from testctx")
+        data = set(cur)
+
+    assert data == set((i,) for i in range(100))
+
+
+@pytest.mark.slow
+@pytest.mark.usefixtures("testctx")
+def test_concurrent_connection_select(conn):
+    nthreads = 10
+    barrier = threading.Barrier(parties=nthreads)
+
+    with conn.cursor() as cur:
+        cur.execute("insert into testctx values (1), (2), (3)")
+
+    def worker():
+        barrier.wait()
+        with conn.cursor() as cur:
+            cur.execute("select id from testctx")
+            assert cur.fetchall() == [(1,), (2,), (3,)]
+
+    with ThreadPoolExecutor(max_workers=nthreads) as tpe:
+        futures = [tpe.submit(worker) for _ in range(100)]
+        for future in futures:
+            future.result()  # to verify nothing raises
+
+
+@pytest.mark.slow
+@pytest.mark.usefixtures("testctx")
+def test_concurrent_connection_update(conn):
+    nthreads = 10
+    barrier = threading.Barrier(parties=nthreads)
+
+    with conn.cursor() as cur:
+        cur.execute("insert into testctx values (0)")
+
+    def worker():
+        barrier.wait()
+        with conn.cursor() as cur:
+            cur.execute("update testctx set id = id + 1")
+
+    with ThreadPoolExecutor(max_workers=nthreads) as tpe:
+        futures = [tpe.submit(worker) for _ in range(100)]
+        for future in futures:
+            future.result()  # to verify nothing raises
+
+    with conn.cursor() as cur:
+        cur.execute("select id from testctx")
+        assert cur.fetchone()[0] == 100
+
+
+@pytest.mark.slow
+@pytest.mark.usefixtures("testctx")
+def test_concurrent_connection_cursors_share_transaction_state(conn):
+    with conn.cursor() as cur:
+        cur.execute("insert into testctx values (1)")
+    conn.commit()
+
+    barrier = threading.Barrier(parties=2)
+    row_added = threading.Event()
+    row_read = threading.Event()
+    transaction_rolled_back = threading.Event()
+
+    def writer():
+        """Thread that inserts a new row but doesn't commit"""
+        barrier.wait()
+        with conn.cursor() as cur:
+            cur.execute("insert into testctx values (2)")
+        row_added.set()
+        row_read.wait()
+        conn.rollback()
+        transaction_rolled_back.set()
+
+    def reader():
+        """Thread that should see uncommitted changes from writer"""
+        barrier.wait()
+
+        row_added.wait()
+        with conn.cursor() as cur:
+            cur.execute("select id from testctx order by id")
+            data = [row[0] for row in cur.fetchall()]
+            reader_saw = data
+        row_read.set()
+        transaction_rolled_back.wait()
+        with conn.cursor() as cur:
+            cur.execute("select id from testctx order by id")
+            assert [row[0] for row in cur.fetchall()] == [1]
+
+        return reader_saw
+
+    with ThreadPoolExecutor(max_workers=2) as tpe:
+        t1 = tpe.submit(writer)
+        t2 = tpe.submit(reader)
+        t1.result()  # No exception
+        assert t2.result() == [1, 2]  # No exception + correct data
+
+
+@pytest.mark.slow
+@pytest.mark.usefixtures("testctx")
+def test_error_in_one_cursor_affects_all_cursors(conn):
+    with conn.cursor() as cur:
+        cur.execute("insert into testctx values (1)")
+    conn.commit()
+
+    error_happened = threading.Event()
+
+    def cause_error():
+        with pytest.raises(psycopg.errors.UndefinedTable):
+            with conn.cursor() as cur:
+                cur.execute("SELECT * FROM nonexistent_table")
+        error_happened.set()
+
+    def try_query_after_error():
+        error_happened.wait()
+
+        with pytest.raises(psycopg.errors.InFailedSqlTransaction):
+            with conn.cursor() as cur:
+                cur.execute("select id from testctx")
+
+        # After rollback, should work again
+        conn.rollback()
+        with conn.cursor() as cur:
+            cur.execute("select id from testctx")
+            assert [row[0] for row in cur.fetchall()] == [1]
+
+    with ThreadPoolExecutor(max_workers=2) as tpe:
+        t1 = tpe.submit(cause_error)
+        t2 = tpe.submit(try_query_after_error)
+        t1.result()
+        t2.result()
+
+
+@pytest.mark.slow
+def test_same_cursor_from_multiple_threads_no_crash(conn):
+    """
+    This is only there to verify that there's no hard crash.
+    All exceptions are fine.
+    """
+    nthreads = 10
+    barrier = threading.Barrier(parties=nthreads)
+
+    cur = conn.cursor()
+
+    def worker():
+        """Multiple threads trying to use the same cursor"""
+        barrier.wait()
+        try:
+            cur.execute("select 1")
+        except Exception:
+            pass
+
+    with ThreadPoolExecutor(max_workers=nthreads) as tpe:
+        futures = [tpe.submit(worker) for _ in range(100)]
+        for future in futures:
+            future.result()
+
+    cur.close()
index fde74058dc76bab153b1b05f4b7dcce05487c3f0..af7724d9fa0b94e7d891d8947736d45278972f74 100644 (file)
@@ -9,6 +9,7 @@ import pytest
 
 from psycopg import Notify
 
+from .utils import skip_free_threaded
 from .acompat import Event, gather, sleep, spawn
 
 pytestmark = pytest.mark.crdb_skip("notify")
@@ -223,6 +224,7 @@ def test_notifies_blocking(conn):
 
 
 @pytest.mark.slow
+@skip_free_threaded("warnings are context-local in the free-threaded build >= 3.14")
 def test_generator_and_handler(conn, conn_cls, dsn, recwarn):
     # NOTE: we don't support generator+handlers anymore. So, if in the future
     # this behaviour will change, we will not consider it a regression. However
index 6faf9be4f1c0ab96ddb360f339857e763618a992..65679abe209077cb8256d31f85bd46b3bbc9fff0 100644 (file)
@@ -6,6 +6,7 @@ import pytest
 
 from psycopg import Notify
 
+from .utils import skip_free_threaded
 from .acompat import AEvent, alist, asleep, gather, spawn
 
 pytestmark = pytest.mark.crdb_skip("notify")
@@ -219,6 +220,7 @@ async def test_notifies_blocking(aconn):
 
 
 @pytest.mark.slow
+@skip_free_threaded("warnings are context-local in the free-threaded build >= 3.14")
 async def test_generator_and_handler(aconn, aconn_cls, dsn, recwarn):
     # NOTE: we don't support generator+handlers anymore. So, if in the future
     # this behaviour will change, we will not consider it a regression. However
index e6c5799180b136170bc0eba6c8d386d18cfadb96..191e584bf6fa29146c44b847ab5462993b7f248a 100644 (file)
@@ -5,6 +5,7 @@ import sys
 import asyncio
 import operator
 import selectors
+import sysconfig
 from typing import Any
 from contextlib import contextmanager
 from collections.abc import Callable
@@ -201,3 +202,10 @@ def asyncio_run(coro: Any, *, debug: bool | None = None) -> Any:
             asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
 
     return asyncio.run(coro, debug=debug, **kwargs)
+
+
+def skip_free_threaded(reason="unsafe under the free-threaded build"):
+    return pytest.mark.skipif(
+        sysconfig.get_config_var("Py_GIL_DISABLED"),
+        reason=reason,
+    )