]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added async execute
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 24 Mar 2020 06:59:29 +0000 (19:59 +1300)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 24 Mar 2020 08:05:54 +0000 (21:05 +1300)
psycopg3/cursor.py
tests/test_async_cursor.py [new file with mode: 0644]
tests/test_cursor.py

index e23fc5e01d2d8f6f87db56e8a4200585e376619c..ae803327ddb2cf8421d79275c31ed8a150448763 100644 (file)
@@ -17,64 +17,65 @@ class BaseCursor:
         self._result = None
         self._iresult = 0
 
+    def _execute_send(self, query, vars):
+        # Implement part of execute() before waiting common to sync and async
+        self._results = []
+        self._result = None
+        self._iresult = 0
+        codec = self.conn.codec
 
-class Cursor(BaseCursor):
-    def execute(self, query, vars=None):
-        with self.conn.lock:
-            self._results = []
-            self._result = None
-            self._iresult = 0
-            codec = self.conn.codec
-
-            if isinstance(query, str):
-                query = codec.encode(query)[0]
-
-            # process %% -> % only if there are paramters, even if empty list
-            if vars is not None:
-                query, order = query2pg(query, vars, codec)
-            if vars:
-                if order is not None:
-                    vars = reorder_params(vars, order)
-                params, formats = self._adapt_sequence(vars)
-                self.conn.pgconn.send_query_params(
-                    query, params, param_formats=formats
-                )
-            else:
-                self.conn.pgconn.send_query(query)
-
-            results = self.conn.wait(self.conn._exec_gen(self.conn.pgconn))
-            if not results:
-                raise exc.InternalError("got no result from the query")
-
-            badstats = {res.status for res in results} - {
-                ExecStatus.TUPLES_OK,
-                ExecStatus.COMMAND_OK,
-                ExecStatus.EMPTY_QUERY,
-            }
-            if not badstats:
-                self._results = results
-                self._result = results[0]
-                return self
-
-            if results[-1].status == ExecStatus.FATAL_ERROR:
-                ecls = exc.class_for_state(
-                    results[-1].error_field(DiagnosticField.SQLSTATE)
-                )
-                raise ecls(error_message(results[-1]))
-
-            elif badstats & {
-                ExecStatus.COPY_IN,
-                ExecStatus.COPY_OUT,
-                ExecStatus.COPY_BOTH,
-            }:
-                raise exc.ProgrammingError(
-                    "COPY cannot be used with execute(); use copy() insead"
-                )
-            else:
-                raise exc.InternalError(
-                    f"got unexpected status from query:"
-                    f" {', '.join(sorted(s.name for s in sorted(badstats)))}"
-                )
+        if isinstance(query, str):
+            query = codec.encode(query)[0]
+
+        # process %% -> % only if there are paramters, even if empty list
+        if vars is not None:
+            query, order = query2pg(query, vars, codec)
+        if vars:
+            if order is not None:
+                vars = reorder_params(vars, order)
+            params, formats = self._adapt_sequence(vars)
+            self.conn.pgconn.send_query_params(
+                query, params, param_formats=formats
+            )
+        else:
+            self.conn.pgconn.send_query(query)
+
+        return self.conn._exec_gen(self.conn.pgconn)
+
+    def _execute_results(self, results):
+        # Implement part of execute() after waiting common to sync and async
+        if not results:
+            raise exc.InternalError("got no result from the query")
+
+        badstats = {res.status for res in results} - {
+            ExecStatus.TUPLES_OK,
+            ExecStatus.COMMAND_OK,
+            ExecStatus.EMPTY_QUERY,
+        }
+        if not badstats:
+            self._results = results
+            self._result = results[0]
+            return
+
+        if results[-1].status == ExecStatus.FATAL_ERROR:
+            ecls = exc.class_for_state(
+                results[-1].error_field(DiagnosticField.SQLSTATE)
+            )
+            raise ecls(error_message(results[-1]))
+
+        elif badstats & {
+            ExecStatus.COPY_IN,
+            ExecStatus.COPY_OUT,
+            ExecStatus.COPY_BOTH,
+        }:
+            raise exc.ProgrammingError(
+                "COPY cannot be used with execute(); use copy() insead"
+            )
+        else:
+            raise exc.InternalError(
+                f"got unexpected status from query:"
+                f" {', '.join(sorted(s.name for s in sorted(badstats)))}"
+            )
 
     def nextset(self):
         self._iresult += 1
@@ -92,10 +93,22 @@ class Cursor(BaseCursor):
         return out, fmt
 
 
+class Cursor(BaseCursor):
+    def execute(self, query, vars=None):
+        with self.conn.lock:
+            gen = self._execute_send(query, vars)
+            results = self.conn.wait(gen)
+            self._execute_results(results)
+        return self
+
+
 class AsyncCursor(BaseCursor):
     async def execute(self, query, vars=None):
-        with self.conn.lock:
-            pass
+        with await self.conn.lock:
+            gen = self._execute_send(query, vars)
+            results = await self.conn.wait(gen)
+            self._execute_results(results)
+        return self
 
 
 class NamedCursorMixin:
diff --git a/tests/test_async_cursor.py b/tests/test_async_cursor.py
new file mode 100644 (file)
index 0000000..4c0dfd5
--- /dev/null
@@ -0,0 +1,22 @@
+def test_execute_many(aconn, loop):
+    cur = aconn.cursor()
+    rv = loop.run_until_complete(cur.execute("select 'foo'; select 'bar'"))
+    assert rv is cur
+    assert len(cur._results) == 2
+    assert cur._result.get_value(0, 0) == b"foo"
+    assert cur.nextset()
+    assert cur._result.get_value(0, 0) == b"bar"
+    assert cur.nextset() is None
+
+
+def test_execute_sequence(aconn, loop):
+    cur = aconn.cursor()
+    rv = loop.run_until_complete(
+        cur.execute("select %s, %s, %s", [1, "foo", None])
+    )
+    assert rv is cur
+    assert len(cur._results) == 1
+    assert cur._result.get_value(0, 0) == b"1"
+    assert cur._result.get_value(0, 1) == b"foo"
+    assert cur._result.get_value(0, 2) is None
+    assert cur.nextset() is None
index 6841c1c0f0f1bd1160af776d0883b0bb74dd0b50..52fc3f8172f9124a1016278c055cdf8c2525c084 100644 (file)
@@ -1,6 +1,7 @@
 def test_execute_many(conn):
     cur = conn.cursor()
-    cur.execute("select 'foo'; select 'bar'")
+    rv = cur.execute("select 'foo'; select 'bar'")
+    assert rv is cur
     assert len(cur._results) == 2
     assert cur._result.get_value(0, 0) == b"foo"
     assert cur.nextset()
@@ -10,7 +11,8 @@ def test_execute_many(conn):
 
 def test_execute_sequence(conn):
     cur = conn.cursor()
-    cur.execute("select %s, %s, %s", [1, "foo", None])
+    rv = cur.execute("select %s, %s, %s", [1, "foo", None])
+    assert rv is cur
     assert len(cur._results) == 1
     assert cur._result.get_value(0, 0) == b"1"
     assert cur._result.get_value(0, 1) == b"foo"