import logging
from os import getpid
from weakref import ref
-from functools import partial
-from ctypes import Array, pointer, string_at, create_string_buffer, byref
-from ctypes import addressof, c_char_p, c_int, c_size_t, c_ulong
+from ctypes import Array, POINTER, cast, pointer, string_at, create_string_buffer, byref
+from ctypes import addressof, c_char_p, c_int, c_size_t, c_ulong, c_void_p, py_object
from typing import Any, Callable, List, Optional, Sequence, Tuple
from typing import cast as t_cast, TYPE_CHECKING
return impl.PQlibVersion()
-def notice_receiver(
- arg: Any, result_ptr: impl.PGresult_struct, wconn: "ref[PGconn]"
-) -> None:
- pgconn = wconn()
+@impl.PQnoticeReceiver # type: ignore
+def notice_receiver(arg: c_void_p, result_ptr: impl.PGresult_struct) -> None:
+ pgconn = cast(arg, POINTER(py_object)).contents.value()
if not (pgconn and pgconn.notice_handler):
return
"_pgconn_ptr",
"notice_handler",
"notify_handler",
- "_notice_receiver",
+ "_self_ptr",
"_procpid",
"__weakref__",
)
self.notice_handler: Optional[Callable[["abc.PGresult"], None]] = None
self.notify_handler: Optional[Callable[[PGnotify], None]] = None
- self._notice_receiver = impl.PQnoticeReceiver( # type: ignore
- partial(notice_receiver, wconn=ref(self))
- )
- impl.PQsetNoticeReceiver(pgconn_ptr, self._notice_receiver, None)
+ # Keep alive for the lifetime of PGconn
+ self._self_ptr = py_object(ref(self))
+ impl.PQsetNoticeReceiver(pgconn_ptr, notice_receiver, byref(self._self_ptr))
self._procpid = getpid()
import queue
import signal
import threading
+import multiprocessing
import subprocess as sp
from typing import List
t = time.time() - t0
assert proc.returncode == 0
assert 1 < t < 2
+
+
+@pytest.mark.slow
+@pytest.mark.subprocess
+@pytest.mark.skipif(
+ multiprocessing.get_all_start_methods()[0] != "fork",
+ reason="problematic behavior only exhibited via fork",
+)
+def test_segfault_on_fork_close(dsn):
+ # https://github.com/psycopg/psycopg/issues/300
+ script = f"""\
+import gc
+import psycopg
+from multiprocessing import Pool
+
+def test(arg):
+ conn1 = psycopg.connect({dsn!r})
+ conn1.close()
+ conn1 = None
+ gc.collect()
+ return 1
+
+if __name__ == '__main__':
+ conn = psycopg.connect({dsn!r})
+ with Pool(2) as p:
+ pool_result = p.map_async(test, [1, 2])
+ pool_result.wait(timeout=5)
+ if pool_result.ready():
+ print(pool_result.get(timeout=1))
+"""
+ env = dict(os.environ)
+ env["PYTHONFAULTHANDLER"] = "1"
+ out = sp.check_output([sys.executable, "-s", "-c", script], env=env)
+ assert out.decode().rstrip() == "[1, 1]"