]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added async implementation for executemany, fetchmany, fetchall
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 14 Apr 2020 07:03:02 +0000 (19:03 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 14 Apr 2020 10:19:10 +0000 (22:19 +1200)
psycopg3/cursor.py
tests/test_async_cursor.py

index e1b3f877713c87f131acfe2268e9f153d4fc1656..13856b50a88313882385695c23e252aa78396348 100644 (file)
@@ -344,12 +344,61 @@ class AsyncCursor(BaseCursor):
             self._execute_results(results)
         return self
 
+    async def executemany(
+        self, query: Query, vars_seq: Sequence[Params]
+    ) -> "AsyncCursor":
+        async with self.connection.lock:
+            self._start_query()
+            for i, vars in enumerate(vars_seq):
+                if i == 0:
+                    pgq = self._send_prepare(b"", query, vars)
+                    gen = generators.execute(self.connection.pgconn)
+                    (result,) = await self.connection.wait(gen)
+                    if result.status == self.ExecStatus.FATAL_ERROR:
+                        raise e.error_from_result(result)
+                else:
+                    pgq.dump(vars)
+
+                self._send_query_prepared(b"", pgq)
+                gen = generators.execute(self.connection.pgconn)
+                (result,) = await self.connection.wait(gen)
+                self._execute_results((result,))
+
+        return self
+
     async def fetchone(self) -> Optional[Sequence[Any]]:
         rv = self._load_row(self._pos)
         if rv is not None:
             self._pos += 1
         return rv
 
+    async def fetchmany(
+        self, size: Optional[int] = None
+    ) -> List[Sequence[Any]]:
+        if size is None:
+            size = self.arraysize
+
+        rv: List[Sequence[Any]] = []
+        while len(rv) < size:
+            row = self._load_row(self._pos)
+            if row is None:
+                break
+            self._pos += 1
+            rv.append(row)
+
+        return rv
+
+    async def fetchall(self) -> List[Sequence[Any]]:
+        rv: List[Sequence[Any]] = []
+        while 1:
+            row = self._load_row(self._pos)
+            if row is None:
+                break
+            self._pos += 1
+            rv.append(row)
+
+        return rv
+
 
 class NamedCursorMixin:
     pass
index 1f785bdb89a4febcdc24789b27b17ac92b2cdc6a..aa749ee4b055316eef399e25123d3c317b00ec51 100644 (file)
@@ -75,3 +75,109 @@ def test_fetchone(aconn, loop):
     assert row[2] is None
     row = loop.run_until_complete(cur.fetchone())
     assert row is None
+
+
+def test_execute_binary_result(aconn, loop):
+    cur = aconn.cursor(binary=True)
+    loop.run_until_complete(cur.execute("select %s, %s", ["foo", None]))
+    assert cur.pgresult.fformat(0) == 1
+
+    row = loop.run_until_complete(cur.fetchone())
+    assert row[0] == "foo"
+    assert row[1] is None
+    row = loop.run_until_complete(cur.fetchone())
+    assert row is None
+
+
+@pytest.mark.parametrize("encoding", ["utf8", "latin9"])
+def test_query_encode(aconn, loop, encoding):
+    loop.run_until_complete(aconn.set_client_encoding(encoding))
+    cur = aconn.cursor()
+    loop.run_until_complete(cur.execute("select '\u20ac'"))
+    (res,) = loop.run_until_complete(cur.fetchone())
+    assert res == "\u20ac"
+
+
+def test_query_badenc(aconn, loop):
+    loop.run_until_complete(aconn.set_client_encoding("latin1"))
+    cur = aconn.cursor()
+    with pytest.raises(UnicodeEncodeError):
+        loop.run_until_complete(cur.execute("select '\u20ac'"))
+
+
+@pytest.fixture(scope="session")
+def _execmany(svcconn):
+    cur = svcconn.cursor()
+    cur.execute(
+        """
+        drop table if exists execmany;
+        create table execmany (id serial primary key, num integer, data text)
+        """
+    )
+
+
+@pytest.fixture(scope="function")
+def execmany(svcconn, _execmany):
+    cur = svcconn.cursor()
+    cur.execute("truncate table execmany")
+
+
+def test_executemany(aconn, loop, execmany):
+    cur = aconn.cursor()
+    loop.run_until_complete(
+        cur.executemany(
+            "insert into execmany(num, data) values (%s, %s)",
+            [(10, "hello"), (20, "world")],
+        )
+    )
+    loop.run_until_complete(
+        cur.execute("select num, data from execmany order by 1")
+    )
+    rv = loop.run_until_complete(cur.fetchall())
+    assert rv == [(10, "hello"), (20, "world")]
+
+
+def test_executemany_name(aconn, loop, execmany):
+    cur = aconn.cursor()
+    loop.run_until_complete(
+        cur.executemany(
+            "insert into execmany(num, data) values (%(num)s, %(data)s)",
+            [
+                {"num": 11, "data": "hello", "x": 1},
+                {"num": 21, "data": "world"},
+            ],
+        )
+    )
+    loop.run_until_complete(
+        cur.execute("select num, data from execmany order by 1")
+    )
+    rv = loop.run_until_complete(cur.fetchall())
+    assert rv == [(11, "hello"), (21, "world")]
+
+
+@pytest.mark.xfail
+def test_executemany_rowcount(aconn, loop, execmany):
+    cur = aconn.cursor()
+    loop.run_until_complete(
+        cur.executemany(
+            "insert into execmany(num, data) values (%s, %s)",
+            [(10, "hello"), (20, "world")],
+        )
+    )
+    assert cur.rowcount == 2
+
+
+@pytest.mark.parametrize(
+    "query",
+    [
+        "insert into nosuchtable values (%s, %s)",
+        "copy (select %s, %s) to stdout",
+        "wat (%s, %s)",
+    ],
+)
+def test_executemany_badquery(aconn, loop, query):
+    cur = aconn.cursor()
+    with pytest.raises(psycopg3.DatabaseError):
+        loop.run_until_complete(
+            cur.executemany(query, [(10, "hello"), (20, "world")])
+        )