]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix: use a static notification function for PQsetNoticeReceiver.
authorFlorian Apolloner <florian@apolloner.eu>
Sat, 27 Aug 2022 15:36:33 +0000 (17:36 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 28 Aug 2022 11:46:08 +0000 (12:46 +0100)
The connection is passed via the optional *arg.

Fixes #300

psycopg/psycopg/pq/pq_ctypes.py
tests/test_concurrency.py

index d0174adc695cdce35251983d329f2deafb96242d..6d29bd269711654358bcc079b5a9b2a2861be24a 100644 (file)
@@ -12,10 +12,9 @@ import sys
 import logging
 from os import getpid
 from weakref import ref
-from functools import partial
 
-from ctypes import Array, pointer, string_at, create_string_buffer, byref
-from ctypes import addressof, c_char_p, c_int, c_size_t, c_ulong
+from ctypes import Array, POINTER, cast, pointer, string_at, create_string_buffer, byref
+from ctypes import addressof, c_char_p, c_int, c_size_t, c_ulong, c_void_p, py_object
 from typing import Any, Callable, List, Optional, Sequence, Tuple
 from typing import cast as t_cast, TYPE_CHECKING
 
@@ -46,10 +45,9 @@ def version() -> int:
     return impl.PQlibVersion()
 
 
-def notice_receiver(
-    arg: Any, result_ptr: impl.PGresult_struct, wconn: "ref[PGconn]"
-) -> None:
-    pgconn = wconn()
+@impl.PQnoticeReceiver  # type: ignore
+def notice_receiver(arg: c_void_p, result_ptr: impl.PGresult_struct) -> None:
+    pgconn = cast(arg, POINTER(py_object)).contents.value()
     if not (pgconn and pgconn.notice_handler):
         return
 
@@ -71,7 +69,7 @@ class PGconn:
         "_pgconn_ptr",
         "notice_handler",
         "notify_handler",
-        "_notice_receiver",
+        "_self_ptr",
         "_procpid",
         "__weakref__",
     )
@@ -81,10 +79,9 @@ class PGconn:
         self.notice_handler: Optional[Callable[["abc.PGresult"], None]] = None
         self.notify_handler: Optional[Callable[[PGnotify], None]] = None
 
-        self._notice_receiver = impl.PQnoticeReceiver(  # type: ignore
-            partial(notice_receiver, wconn=ref(self))
-        )
-        impl.PQsetNoticeReceiver(pgconn_ptr, self._notice_receiver, None)
+        # Keep alive for the lifetime of PGconn
+        self._self_ptr = py_object(ref(self))
+        impl.PQsetNoticeReceiver(pgconn_ptr, notice_receiver, byref(self._self_ptr))
 
         self._procpid = getpid()
 
index 4a60bbf67b53adf8444911c3955c47c275d931da..eec24f1df1995f13aed1779d3a1e24ebe48adf92 100644 (file)
@@ -8,6 +8,7 @@ import time
 import queue
 import signal
 import threading
+import multiprocessing
 import subprocess as sp
 from typing import List
 
@@ -290,3 +291,37 @@ with psycopg.connect({dsn!r}) as conn:
     t = time.time() - t0
     assert proc.returncode == 0
     assert 1 < t < 2
+
+
+@pytest.mark.slow
+@pytest.mark.subprocess
+@pytest.mark.skipif(
+    multiprocessing.get_all_start_methods()[0] != "fork",
+    reason="problematic behavior only exhibited via fork",
+)
+def test_segfault_on_fork_close(dsn):
+    # https://github.com/psycopg/psycopg/issues/300
+    script = f"""\
+import gc
+import psycopg
+from multiprocessing import Pool
+
+def test(arg):
+    conn1 = psycopg.connect({dsn!r})
+    conn1.close()
+    conn1 = None
+    gc.collect()
+    return 1
+
+if __name__ == '__main__':
+    conn = psycopg.connect({dsn!r})
+    with Pool(2) as p:
+        pool_result = p.map_async(test, [1, 2])
+        pool_result.wait(timeout=5)
+        if pool_result.ready():
+            print(pool_result.get(timeout=1))
+"""
+    env = dict(os.environ)
+    env["PYTHONFAULTHANDLER"] = "1"
+    out = sp.check_output([sys.executable, "-s", "-c", script], env=env)
+    assert out.decode().rstrip() == "[1, 1]"