From: Daniele Varrazzo Date: Mon, 1 Apr 2024 23:18:50 +0000 (+0000) Subject: fix: lock the connection during a 'notifies()' call X-Git-Tag: 3.2.0~12^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=refs%2Fpull%2F760%2Fhead;p=thirdparty%2Fpsycopg.git fix: lock the connection during a 'notifies()' call With the previous implementation, it was possible to sneak an execute() while the generator is consumed. This gives the false impression that it's possible to use the connection while listening (see #756), which is false for reason better explored in #757. Therefore, lock the connection while listening to notifications. If someone wants to mix commands with listening on the same connection, they should do it collaboratively with an adequately short notifies() timeout. --- diff --git a/psycopg/psycopg/connection.py b/psycopg/psycopg/connection.py index 8051a1b5c..9558e29c6 100644 --- a/psycopg/psycopg/connection.py +++ b/psycopg/psycopg/connection.py @@ -326,7 +326,7 @@ class Connection(BaseConnection[Row]): notifications arrives in the same packet. """ # Allow interrupting the wait with a signal by reducing a long timeout - # into shorter interval. + # into shorter intervals. if timeout is not None: deadline = monotonic() + timeout interval = min(timeout, _WAIT_INTERVAL) @@ -336,34 +336,33 @@ class Connection(BaseConnection[Row]): nreceived = 0 - while True: - # Collect notifications. Also get the connection encoding if any - # notification is received to makes sure that they are consistent. - try: - with self.lock: + with self.lock: + enc = self.pgconn._encoding + while True: + try: ns = self.wait(notifies(self.pgconn), interval=interval) - if ns: - enc = self.pgconn._encoding - except e._NO_TRACEBACK as ex: - raise ex.with_traceback(None) + except e._NO_TRACEBACK as ex: + raise ex.with_traceback(None) - # Emit the notifications received. - for pgn in ns: - n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid) - yield n - nreceived += 1 - - # Stop if we have received enough notifications. - if stop_after is not None and nreceived >= stop_after: - break + # Emit the notifications received. + for pgn in ns: + n = Notify( + pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid + ) + yield n + nreceived += 1 - # Check the deadline after the loop to ensure that timeout=0 - # polls at least once. - if deadline: - interval = min(_WAIT_INTERVAL, deadline - monotonic()) - if interval < 0.0: + # Stop if we have received enough notifications. + if stop_after is not None and nreceived >= stop_after: break + # Check the deadline after the loop to ensure that timeout=0 + # polls at least once. + if deadline: + interval = min(_WAIT_INTERVAL, deadline - monotonic()) + if interval < 0.0: + break + @contextmanager def pipeline(self) -> Iterator[Pipeline]: """Context manager to switch the connection into pipeline mode.""" diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py index 8252594b5..f82704e55 100644 --- a/psycopg/psycopg/connection_async.py +++ b/psycopg/psycopg/connection_async.py @@ -346,7 +346,7 @@ class AsyncConnection(BaseConnection[Row]): notifications arrives in the same packet. """ # Allow interrupting the wait with a signal by reducing a long timeout - # into shorter interval. + # into shorter intervals. if timeout is not None: deadline = monotonic() + timeout interval = min(timeout, _WAIT_INTERVAL) @@ -356,34 +356,33 @@ class AsyncConnection(BaseConnection[Row]): nreceived = 0 - while True: - # Collect notifications. Also get the connection encoding if any - # notification is received to makes sure that they are consistent. - try: - async with self.lock: + async with self.lock: + enc = self.pgconn._encoding + while True: + try: ns = await self.wait(notifies(self.pgconn), interval=interval) - if ns: - enc = self.pgconn._encoding - except e._NO_TRACEBACK as ex: - raise ex.with_traceback(None) + except e._NO_TRACEBACK as ex: + raise ex.with_traceback(None) - # Emit the notifications received. - for pgn in ns: - n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid) - yield n - nreceived += 1 - - # Stop if we have received enough notifications. - if stop_after is not None and nreceived >= stop_after: - break + # Emit the notifications received. + for pgn in ns: + n = Notify( + pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid + ) + yield n + nreceived += 1 - # Check the deadline after the loop to ensure that timeout=0 - # polls at least once. - if deadline: - interval = min(_WAIT_INTERVAL, deadline - monotonic()) - if interval < 0.0: + # Stop if we have received enough notifications. + if stop_after is not None and nreceived >= stop_after: break + # Check the deadline after the loop to ensure that timeout=0 + # polls at least once. + if deadline: + interval = min(_WAIT_INTERVAL, deadline - monotonic()) + if interval < 0.0: + break + @asynccontextmanager async def pipeline(self) -> AsyncIterator[AsyncPipeline]: """Context manager to switch the connection into pipeline mode.""" diff --git a/tests/test_notify.py b/tests/test_notify.py index c67b33115..e8bf6c91c 100644 --- a/tests/test_notify.py +++ b/tests/test_notify.py @@ -193,3 +193,25 @@ def test_stop_after_batch(conn_cls, conn, dsn): assert ns[1].payload == "2" finally: gather(worker) + + +@pytest.mark.slow +def test_notifies_blocking(conn): + + def listener(): + for _ in conn.notifies(timeout=1): + pass + + worker = spawn(listener) + try: + # Make sure the listener is listening + if not conn.lock.locked(): + sleep(0.01) + + t0 = time() + conn.execute("select 1") + dt = time() - t0 + finally: + gather(worker) + + assert dt > 0.5 diff --git a/tests/test_notify_async.py b/tests/test_notify_async.py index f4f0901d6..97e162846 100644 --- a/tests/test_notify_async.py +++ b/tests/test_notify_async.py @@ -190,3 +190,24 @@ async def test_stop_after_batch(aconn_cls, aconn, dsn): assert ns[1].payload == "2" finally: await gather(worker) + + +@pytest.mark.slow +async def test_notifies_blocking(aconn): + async def listener(): + async for _ in aconn.notifies(timeout=1): + pass + + worker = spawn(listener) + try: + # Make sure the listener is listening + if not aconn.lock.locked(): + await asleep(0.01) + + t0 = time() + await aconn.execute("select 1") + dt = time() - t0 + finally: + await gather(worker) + + assert dt > 0.5