]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Dropped use of `send()` on Connection.notifies generator
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 11 Nov 2020 19:25:34 +0000 (19:25 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 11 Nov 2020 22:57:47 +0000 (22:57 +0000)
`close()` is good enough.

psycopg3/psycopg3/__init__.py
psycopg3/psycopg3/connection.py
tests/test_concurrency.py
tests/test_concurrency_async.py
tests/test_connection.py

index aeb87d83b700e525b983232e05548f495eaec47b..780beece7ff6344dc9d7258d618b33488f66824e 100644 (file)
@@ -5,7 +5,7 @@ psycopg3 -- PostgreSQL database adapter for Python
 # Copyright (C) 2020 The Psycopg Team
 
 from . import pq
-from .connection import AsyncConnection, Connection
+from .connection import AsyncConnection, Connection, Notify
 
 from .errors import (
     Warning,
@@ -47,7 +47,7 @@ __all__ = (
     ["Warning", "Error", "InterfaceError", "DatabaseError", "DataError"]
     + ["OperationalError", "IntegrityError", "InternalError"]
     + ["ProgrammingError", "NotSupportedError"]
-    + ["AsyncConnection", "Connection", "connect"]
+    + ["AsyncConnection", "Connection", "Notify"]
     + ["BINARY", "DATETIME", "NUMBER", "ROWID", "STRING"]
     + ["Binary", "Date", "DateFromTicks", "Time", "TimeFromTicks"]
     + ["Timestamp", "TimestampFromTicks"]
index 9bbbae04c9af5835df7bbebdccd50ac6ae3f3d5d..c7bf46a8299d6d4620ef21fcd737f38560288a32 100644 (file)
@@ -8,7 +8,7 @@ import logging
 import asyncio
 import threading
 from types import TracebackType
-from typing import Any, AsyncGenerator, Callable, Generator, List, NamedTuple
+from typing import Any, AsyncIterator, Callable, Iterator, List, NamedTuple
 from typing import Optional, Type, cast
 from weakref import ref, ReferenceType
 from functools import partial
@@ -315,7 +315,8 @@ class Connection(BaseConnection):
             if result.status != ExecStatus.TUPLES_OK:
                 raise e.error_from_result(result, encoding=self._pyenc)
 
-    def notifies(self) -> Generator[Optional[Notify], bool, None]:
+    def notifies(self) -> Iterator[Notify]:
+        """Generate a stream of `Notify`"""
         while 1:
             with self.lock:
                 ns = self.wait(notifies(self.pgconn))
@@ -325,9 +326,7 @@ class Connection(BaseConnection):
                     pgn.extra.decode(self._pyenc),
                     pgn.be_pid,
                 )
-                if (yield n):
-                    yield None  # for the send who stopped us
-                    return
+                yield n
 
     def _set_autocommit(self, value: bool) -> None:
         with self.lock:
@@ -445,7 +444,7 @@ class AsyncConnection(BaseConnection):
             if result.status != ExecStatus.TUPLES_OK:
                 raise e.error_from_result(result, encoding=self._pyenc)
 
-    async def notifies(self) -> AsyncGenerator[Optional[Notify], bool]:
+    async def notifies(self) -> AsyncIterator[Notify]:
         while 1:
             async with self.lock:
                 ns = await self.wait(notifies(self.pgconn))
@@ -455,9 +454,7 @@ class AsyncConnection(BaseConnection):
                     pgn.extra.decode(self._pyenc),
                     pgn.be_pid,
                 )
-                if (yield n):
-                    yield None
-                    return
+                yield n
 
     def _set_autocommit(self, value: bool) -> None:
         raise AttributeError(
index 40b6b05a29b12738e4a68c445432d966a871aa4b..2fc88c3172e7e64d4c517d28f8415eb410579b07 100644 (file)
@@ -105,29 +105,33 @@ t.join()
 
 @pytest.mark.slow
 def test_notifies(conn, dsn):
-    nconn = psycopg3.connect(dsn)
+    nconn = psycopg3.connect(dsn, autocommit=True)
     npid = nconn.pgconn.backend_pid
 
     def notifier():
         time.sleep(0.25)
-        nconn.pgconn.exec_(b"notify foo, '1'")
+        nconn.cursor().execute("notify foo, '1'")
         time.sleep(0.25)
-        nconn.pgconn.exec_(b"notify foo, '2'")
-        nconn.close()
+        nconn.cursor().execute("notify foo, '2'")
+
+    conn.autocommit = True
+    conn.cursor().execute("listen foo")
 
-    conn.pgconn.exec_(b"listen foo")
     t0 = time.time()
     t = threading.Thread(target=notifier)
     t.start()
+
     ns = []
     gen = conn.notifies()
     for n in gen:
         ns.append((n, time.time()))
         if len(ns) >= 2:
-            gen.send(True)
+            gen.close()
+
     assert len(ns) == 2
 
     n, t1 = ns[0]
+    assert isinstance(n, psycopg3.Notify)
     assert n.pid == npid
     assert n.channel == "foo"
     assert n.payload == "1"
index 7502574f680b79d273e020d88eb8e7336c8f8dc9..02327e256bbcd5cff28b16cf8bef9e8c1c6ad7a5 100644 (file)
@@ -56,23 +56,26 @@ async def test_concurrent_execution(dsn):
 
 @pytest.mark.slow
 async def test_notifies(aconn, dsn):
-    nconn = await psycopg3.AsyncConnection.connect(dsn)
+    nconn = await psycopg3.AsyncConnection.connect(dsn, autocommit=True)
     npid = nconn.pgconn.backend_pid
 
     async def notifier():
+        cur = await nconn.cursor()
         await asyncio.sleep(0.25)
-        nconn.pgconn.exec_(b"notify foo, '1'")
+        await cur.execute("notify foo, '1'")
         await asyncio.sleep(0.25)
-        nconn.pgconn.exec_(b"notify foo, '2'")
+        await cur.execute("notify foo, '2'")
         await nconn.close()
 
     async def receiver():
-        aconn.pgconn.exec_(b"listen foo")
+        await aconn.set_autocommit(True)
+        cur = await aconn.cursor()
+        await cur.execute("listen foo")
         gen = aconn.notifies()
         async for n in gen:
             ns.append((n, time.time()))
             if len(ns) >= 2:
-                gen.send(True)
+                gen.close()
 
     ns = []
     t0 = time.time()
index 66bbda62c9880c18c8d6d5f4ba43ae19d41cc08e..eabd0aad10cd513010799a8b2a1c088f022e30ed 100644 (file)
@@ -4,7 +4,7 @@ import logging
 import weakref
 
 import psycopg3
-from psycopg3 import Connection
+from psycopg3 import Connection, Notify
 from psycopg3.errors import UndefinedTable
 from psycopg3.conninfo import conninfo_to_dict
 
@@ -384,6 +384,7 @@ def test_notify_handlers(conn):
     assert len(nots1) == 1
     assert len(nots2) == 2
     n = nots2[1]
+    assert isinstance(n, Notify)
     assert n.channel == "foo"
     assert n.payload == "n2"
     assert n.pid == conn.pgconn.backend_pid