From 872fe952cb35dafda0fad8a24a8980229d2bc7b7 Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Tue, 24 Mar 2020 19:59:29 +1300 Subject: [PATCH] Added async execute --- psycopg3/cursor.py | 131 ++++++++++++++++++++----------------- tests/test_async_cursor.py | 22 +++++++ tests/test_cursor.py | 6 +- 3 files changed, 98 insertions(+), 61 deletions(-) create mode 100644 tests/test_async_cursor.py diff --git a/psycopg3/cursor.py b/psycopg3/cursor.py index e23fc5e01..ae803327d 100644 --- a/psycopg3/cursor.py +++ b/psycopg3/cursor.py @@ -17,64 +17,65 @@ class BaseCursor: self._result = None self._iresult = 0 + def _execute_send(self, query, vars): + # Implement part of execute() before waiting common to sync and async + self._results = [] + self._result = None + self._iresult = 0 + codec = self.conn.codec -class Cursor(BaseCursor): - def execute(self, query, vars=None): - with self.conn.lock: - self._results = [] - self._result = None - self._iresult = 0 - codec = self.conn.codec - - if isinstance(query, str): - query = codec.encode(query)[0] - - # process %% -> % only if there are paramters, even if empty list - if vars is not None: - query, order = query2pg(query, vars, codec) - if vars: - if order is not None: - vars = reorder_params(vars, order) - params, formats = self._adapt_sequence(vars) - self.conn.pgconn.send_query_params( - query, params, param_formats=formats - ) - else: - self.conn.pgconn.send_query(query) - - results = self.conn.wait(self.conn._exec_gen(self.conn.pgconn)) - if not results: - raise exc.InternalError("got no result from the query") - - badstats = {res.status for res in results} - { - ExecStatus.TUPLES_OK, - ExecStatus.COMMAND_OK, - ExecStatus.EMPTY_QUERY, - } - if not badstats: - self._results = results - self._result = results[0] - return self - - if results[-1].status == ExecStatus.FATAL_ERROR: - ecls = exc.class_for_state( - results[-1].error_field(DiagnosticField.SQLSTATE) - ) - raise ecls(error_message(results[-1])) - - elif badstats & { - ExecStatus.COPY_IN, - ExecStatus.COPY_OUT, - ExecStatus.COPY_BOTH, - }: - raise exc.ProgrammingError( - "COPY cannot be used with execute(); use copy() insead" - ) - else: - raise exc.InternalError( - f"got unexpected status from query:" - f" {', '.join(sorted(s.name for s in sorted(badstats)))}" - ) + if isinstance(query, str): + query = codec.encode(query)[0] + + # process %% -> % only if there are paramters, even if empty list + if vars is not None: + query, order = query2pg(query, vars, codec) + if vars: + if order is not None: + vars = reorder_params(vars, order) + params, formats = self._adapt_sequence(vars) + self.conn.pgconn.send_query_params( + query, params, param_formats=formats + ) + else: + self.conn.pgconn.send_query(query) + + return self.conn._exec_gen(self.conn.pgconn) + + def _execute_results(self, results): + # Implement part of execute() after waiting common to sync and async + if not results: + raise exc.InternalError("got no result from the query") + + badstats = {res.status for res in results} - { + ExecStatus.TUPLES_OK, + ExecStatus.COMMAND_OK, + ExecStatus.EMPTY_QUERY, + } + if not badstats: + self._results = results + self._result = results[0] + return + + if results[-1].status == ExecStatus.FATAL_ERROR: + ecls = exc.class_for_state( + results[-1].error_field(DiagnosticField.SQLSTATE) + ) + raise ecls(error_message(results[-1])) + + elif badstats & { + ExecStatus.COPY_IN, + ExecStatus.COPY_OUT, + ExecStatus.COPY_BOTH, + }: + raise exc.ProgrammingError( + "COPY cannot be used with execute(); use copy() insead" + ) + else: + raise exc.InternalError( + f"got unexpected status from query:" + f" {', '.join(sorted(s.name for s in sorted(badstats)))}" + ) def nextset(self): self._iresult += 1 @@ -92,10 +93,22 @@ class Cursor(BaseCursor): return out, fmt +class Cursor(BaseCursor): + def execute(self, query, vars=None): + with self.conn.lock: + gen = self._execute_send(query, vars) + results = self.conn.wait(gen) + self._execute_results(results) + return self + + class AsyncCursor(BaseCursor): async def execute(self, query, vars=None): - with self.conn.lock: - pass + with await self.conn.lock: + gen = self._execute_send(query, vars) + results = await self.conn.wait(gen) + self._execute_results(results) + return self class NamedCursorMixin: diff --git a/tests/test_async_cursor.py b/tests/test_async_cursor.py new file mode 100644 index 000000000..4c0dfd57c --- /dev/null +++ b/tests/test_async_cursor.py @@ -0,0 +1,22 @@ +def test_execute_many(aconn, loop): + cur = aconn.cursor() + rv = loop.run_until_complete(cur.execute("select 'foo'; select 'bar'")) + assert rv is cur + assert len(cur._results) == 2 + assert cur._result.get_value(0, 0) == b"foo" + assert cur.nextset() + assert cur._result.get_value(0, 0) == b"bar" + assert cur.nextset() is None + + +def test_execute_sequence(aconn, loop): + cur = aconn.cursor() + rv = loop.run_until_complete( + cur.execute("select %s, %s, %s", [1, "foo", None]) + ) + assert rv is cur + assert len(cur._results) == 1 + assert cur._result.get_value(0, 0) == b"1" + assert cur._result.get_value(0, 1) == b"foo" + assert cur._result.get_value(0, 2) is None + assert cur.nextset() is None diff --git a/tests/test_cursor.py b/tests/test_cursor.py index 6841c1c0f..52fc3f817 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -1,6 +1,7 @@ def test_execute_many(conn): cur = conn.cursor() - cur.execute("select 'foo'; select 'bar'") + rv = cur.execute("select 'foo'; select 'bar'") + assert rv is cur assert len(cur._results) == 2 assert cur._result.get_value(0, 0) == b"foo" assert cur.nextset() @@ -10,7 +11,8 @@ def test_execute_many(conn): def test_execute_sequence(conn): cur = conn.cursor() - cur.execute("select %s, %s, %s", [1, "foo", None]) + rv = cur.execute("select %s, %s, %s", [1, "foo", None]) + assert rv is cur assert len(cur._results) == 1 assert cur._result.get_value(0, 0) == b"1" assert cur._result.get_value(0, 1) == b"foo" -- 2.47.3