]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added first implementation of executemany based on prepared queries
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 11 Apr 2020 15:55:32 +0000 (03:55 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 11 Apr 2020 15:56:30 +0000 (03:56 +1200)
There's a lot of repetition here

psycopg3/cursor.py
tests/test_cursor.py

index 545732c22134cddda7859431239d34089eebe899..6dfad11521ef8bf1ff42280f8e98ae36e3993fd8 100644 (file)
@@ -191,7 +191,7 @@ class BaseCursor:
 
         return generators.execute(self.connection.pgconn)
 
-    def _execute_results(self, results: List[pq.PGresult]) -> None:
+    def _execute_results(self, results: Sequence[pq.PGresult]) -> None:
         """
         Implement part of execute() after waiting common to sync and async
         """
@@ -202,7 +202,7 @@ class BaseCursor:
         statuses = {res.status for res in results}
         badstats = statuses - {S.TUPLES_OK, S.COMMAND_OK, S.EMPTY_QUERY}
         if not badstats:
-            self._results = results
+            self._results = list(results)
             self.pgresult = results[0]
             return
 
@@ -219,6 +219,75 @@ class BaseCursor:
                 f" {', '.join(sorted(s.name for s in sorted(badstats)))}"
             )
 
+    def _send_prepare(
+        self, name: bytes, query: Query, vars: Optional[Params]
+    ) -> "PQGen[List[pq.PGresult]]":
+        """
+        Implement part of execute() before waiting common to sync and async
+        """
+        from .adapt import Transformer
+
+        if self.closed:
+            raise e.OperationalError("the cursor is closed")
+
+        if self.connection.closed:
+            raise e.OperationalError("the connection is closed")
+
+        if self.connection.status != self.connection.ConnStatus.OK:
+            raise e.InterfaceError(
+                f"cannot execute operations: the connection is"
+                f" in status {self.connection.status}"
+            )
+
+        self._reset()
+        self._transformer = Transformer(self)
+
+        codec = self.connection.codec
+
+        if isinstance(query, str):
+            query = codec.encode(query)[0]
+
+        # process %% -> % only if there are paramters, even if empty list
+        if vars is not None:
+            query, formats, order = query2pg(query, vars, codec)
+
+        if order is not None:
+            assert isinstance(vars, Mapping)
+            vars = reorder_params(vars, order)
+        assert isinstance(vars, Sequence)
+        params, types = self._transformer.dump_sequence(vars, formats)
+        self.connection.pgconn.send_prepare(
+            name, query, param_types=types,
+        )
+        self._order = order
+        self._formats = formats
+        return generators.execute(self.connection.pgconn)
+
+    def _send_query_prepared(
+        self, name: bytes, vars: Optional[Params]
+    ) -> "PQGen[List[pq.PGresult]]":
+        if self.connection.closed:
+            raise e.OperationalError("the connection is closed")
+
+        if self.connection.status != self.connection.ConnStatus.OK:
+            raise e.InterfaceError(
+                f"cannot execute operations: the connection is"
+                f" in status {self.connection.status}"
+            )
+
+        if self._order is not None:
+            assert isinstance(vars, Mapping)
+            vars = reorder_params(vars, self._order)
+        assert isinstance(vars, Sequence)
+        params, types = self._transformer.dump_sequence(vars, self._formats)
+        self.connection.pgconn.send_query_prepared(
+            name,
+            params,
+            param_formats=self._formats,
+            result_format=pq.Format(self.binary),
+        )
+        return generators.execute(self.connection.pgconn)
+
     def nextset(self) -> Optional[bool]:
         self._iresult += 1
         if self._iresult < len(self._results):
@@ -264,10 +333,17 @@ class Cursor(BaseCursor):
         self, query: Query, vars_seq: Sequence[Params]
     ) -> "Cursor":
         with self.connection.lock:
-            for vars in vars_seq:
-                gen = self._execute_send(query, vars)
-                results = self.connection.wait(gen)
-                self._execute_results(results)
+            for i, vars in enumerate(vars_seq):
+                if i == 0:
+                    gen = self._send_prepare(b"", query, vars)
+                    (result,) = self.connection.wait(gen)
+                    if result.status == self.ExecStatus.FATAL_ERROR:
+                        raise e.error_from_result(result)
+
+                gen = self._send_query_prepared(b"", vars)
+                (result,) = self.connection.wait(gen)
+                self._execute_results((result,))
+
         return self
 
     def fetchone(self) -> Optional[Sequence[Any]]:
index e9ca7a359c7eed3546e934435fca7b7ee770f23c..86b589cbe6c9b8b28166faa0341507a08f1debc2 100644 (file)
@@ -100,3 +100,64 @@ def test_query_badenc(conn):
     cur = conn.cursor()
     with pytest.raises(UnicodeEncodeError):
         cur.execute("select '\u20ac'")
+
+
+@pytest.fixture(scope="session")
+def _execmany(svcconn):
+    cur = svcconn.cursor()
+    cur.execute(
+        """
+        drop table if exists execmany;
+        create table execmany (id serial primary key, num integer, data text)
+        """
+    )
+
+
+@pytest.fixture(scope="function")
+def execmany(svcconn, _execmany):
+    cur = svcconn.cursor()
+    cur.execute("truncate table execmany")
+
+
+def test_executemany(conn, execmany):
+    cur = conn.cursor()
+    cur.executemany(
+        "insert into execmany(num, data) values (%s, %s)",
+        [(10, "hello"), (20, "world")],
+    )
+    cur.execute("select num, data from execmany order by 1")
+    assert cur.fetchall() == [(10, "hello"), (20, "world")]
+
+
+def test_executemany_name(conn, execmany):
+    cur = conn.cursor()
+    cur.executemany(
+        "insert into execmany(num, data) values (%(num)s, %(data)s)",
+        [{"num": 11, "data": "hello", "x": 1}, {"num": 21, "data": "world"}],
+    )
+    cur.execute("select num, data from execmany order by 1")
+    assert cur.fetchall() == [(11, "hello"), (21, "world")]
+
+
+@pytest.mark.xfail
+def test_executemany_rowcount(conn, execmany):
+    cur = conn.cursor()
+    cur.executemany(
+        "insert into execmany(num, data) values (%s, %s)",
+        [(10, "hello"), (20, "world")],
+    )
+    assert cur.rowcount == 2
+
+
+@pytest.mark.parametrize(
+    "query",
+    [
+        "insert into nosuchtable values (%s, %s)",
+        "copy (select %s, %s) to stdout",
+        "wat (%s, %s)",
+    ],
+)
+def test_executemany_badquery(conn, query):
+    cur = conn.cursor()
+    with pytest.raises(psycopg3.DatabaseError):
+        cur.executemany(query, [(10, "hello"), (20, "world")])