]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fetchone() made async on async cursor
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 8 Apr 2020 10:10:12 +0000 (22:10 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 8 Apr 2020 10:10:27 +0000 (22:10 +1200)
psycopg3/cursor.py
tests/test_async_connection.py

index 6c809db73e5c4c097a37d779a9e21f8fc7eb7b42..f878daf7f093cd354257f2b80c8eb691e3ef521a 100644 (file)
@@ -131,12 +131,6 @@ class BaseCursor:
         else:
             return None
 
-    def fetchone(self) -> Optional[Sequence[Any]]:
-        rv = self._cast_row(self._pos)
-        if rv is not None:
-            self._pos += 1
-        return rv
-
     def _cast_row(self, n: int) -> Optional[Tuple[Any, ...]]:
         res = self.pgresult
         if res is None:
@@ -164,6 +158,12 @@ class Cursor(BaseCursor):
             self._execute_results(results)
         return self
 
+    def fetchone(self) -> Optional[Sequence[Any]]:
+        rv = self._cast_row(self._pos)
+        if rv is not None:
+            self._pos += 1
+        return rv
+
 
 class AsyncCursor(BaseCursor):
     conn: "AsyncConnection"
@@ -180,6 +180,12 @@ class AsyncCursor(BaseCursor):
             self._execute_results(results)
         return self
 
+    async def fetchone(self) -> Optional[Sequence[Any]]:
+        rv = self._cast_row(self._pos)
+        if rv is not None:
+            self._pos += 1
+        return rv
+
 
 class NamedCursorMixin:
     pass
index df1116566070b1c1817650dc5a562f9eab230604..d9255552304b91d5aa1234a85b78e9343880760b 100644 (file)
@@ -41,7 +41,7 @@ def test_rollback(loop, pq, aconn):
 def test_get_encoding(aconn, loop):
     cur = aconn.cursor()
     loop.run_until_complete(cur.execute("show client_encoding"))
-    (enc,) = cur.fetchone()
+    (enc,) = loop.run_until_complete(cur.fetchone())
     assert enc == aconn.encoding
 
 
@@ -59,7 +59,7 @@ def test_set_encoding(aconn, loop):
     assert aconn.encoding == newenc
     cur = aconn.cursor()
     loop.run_until_complete(cur.execute("show client_encoding"))
-    (enc,) = cur.fetchone()
+    (enc,) = loop.run_until_complete(cur.fetchone())
     assert enc == newenc