self._result = None
self._iresult = 0
+ def _execute_send(self, query, vars):
+ # Implement part of execute() before waiting common to sync and async
+ self._results = []
+ self._result = None
+ self._iresult = 0
+ codec = self.conn.codec
-class Cursor(BaseCursor):
- def execute(self, query, vars=None):
- with self.conn.lock:
- self._results = []
- self._result = None
- self._iresult = 0
- codec = self.conn.codec
-
- if isinstance(query, str):
- query = codec.encode(query)[0]
-
- # process %% -> % only if there are paramters, even if empty list
- if vars is not None:
- query, order = query2pg(query, vars, codec)
- if vars:
- if order is not None:
- vars = reorder_params(vars, order)
- params, formats = self._adapt_sequence(vars)
- self.conn.pgconn.send_query_params(
- query, params, param_formats=formats
- )
- else:
- self.conn.pgconn.send_query(query)
-
- results = self.conn.wait(self.conn._exec_gen(self.conn.pgconn))
- if not results:
- raise exc.InternalError("got no result from the query")
-
- badstats = {res.status for res in results} - {
- ExecStatus.TUPLES_OK,
- ExecStatus.COMMAND_OK,
- ExecStatus.EMPTY_QUERY,
- }
- if not badstats:
- self._results = results
- self._result = results[0]
- return self
-
- if results[-1].status == ExecStatus.FATAL_ERROR:
- ecls = exc.class_for_state(
- results[-1].error_field(DiagnosticField.SQLSTATE)
- )
- raise ecls(error_message(results[-1]))
-
- elif badstats & {
- ExecStatus.COPY_IN,
- ExecStatus.COPY_OUT,
- ExecStatus.COPY_BOTH,
- }:
- raise exc.ProgrammingError(
- "COPY cannot be used with execute(); use copy() insead"
- )
- else:
- raise exc.InternalError(
- f"got unexpected status from query:"
- f" {', '.join(sorted(s.name for s in sorted(badstats)))}"
- )
+ if isinstance(query, str):
+ query = codec.encode(query)[0]
+
+ # process %% -> % only if there are paramters, even if empty list
+ if vars is not None:
+ query, order = query2pg(query, vars, codec)
+ if vars:
+ if order is not None:
+ vars = reorder_params(vars, order)
+ params, formats = self._adapt_sequence(vars)
+ self.conn.pgconn.send_query_params(
+ query, params, param_formats=formats
+ )
+ else:
+ self.conn.pgconn.send_query(query)
+
+ return self.conn._exec_gen(self.conn.pgconn)
+
+ def _execute_results(self, results):
+ # Implement part of execute() after waiting common to sync and async
+ if not results:
+ raise exc.InternalError("got no result from the query")
+
+ badstats = {res.status for res in results} - {
+ ExecStatus.TUPLES_OK,
+ ExecStatus.COMMAND_OK,
+ ExecStatus.EMPTY_QUERY,
+ }
+ if not badstats:
+ self._results = results
+ self._result = results[0]
+ return
+
+ if results[-1].status == ExecStatus.FATAL_ERROR:
+ ecls = exc.class_for_state(
+ results[-1].error_field(DiagnosticField.SQLSTATE)
+ )
+ raise ecls(error_message(results[-1]))
+
+ elif badstats & {
+ ExecStatus.COPY_IN,
+ ExecStatus.COPY_OUT,
+ ExecStatus.COPY_BOTH,
+ }:
+ raise exc.ProgrammingError(
+ "COPY cannot be used with execute(); use copy() insead"
+ )
+ else:
+ raise exc.InternalError(
+ f"got unexpected status from query:"
+ f" {', '.join(sorted(s.name for s in sorted(badstats)))}"
+ )
def nextset(self):
self._iresult += 1
return out, fmt
+class Cursor(BaseCursor):
+ def execute(self, query, vars=None):
+ with self.conn.lock:
+ gen = self._execute_send(query, vars)
+ results = self.conn.wait(gen)
+ self._execute_results(results)
+ return self
+
+
class AsyncCursor(BaseCursor):
async def execute(self, query, vars=None):
- with self.conn.lock:
- pass
+ with await self.conn.lock:
+ gen = self._execute_send(query, vars)
+ results = await self.conn.wait(gen)
+ self._execute_results(results)
+ return self
class NamedCursorMixin:
--- /dev/null
+def test_execute_many(aconn, loop):
+ cur = aconn.cursor()
+ rv = loop.run_until_complete(cur.execute("select 'foo'; select 'bar'"))
+ assert rv is cur
+ assert len(cur._results) == 2
+ assert cur._result.get_value(0, 0) == b"foo"
+ assert cur.nextset()
+ assert cur._result.get_value(0, 0) == b"bar"
+ assert cur.nextset() is None
+
+
+def test_execute_sequence(aconn, loop):
+ cur = aconn.cursor()
+ rv = loop.run_until_complete(
+ cur.execute("select %s, %s, %s", [1, "foo", None])
+ )
+ assert rv is cur
+ assert len(cur._results) == 1
+ assert cur._result.get_value(0, 0) == b"1"
+ assert cur._result.get_value(0, 1) == b"foo"
+ assert cur._result.get_value(0, 2) is None
+ assert cur.nextset() is None
def test_execute_many(conn):
cur = conn.cursor()
- cur.execute("select 'foo'; select 'bar'")
+ rv = cur.execute("select 'foo'; select 'bar'")
+ assert rv is cur
assert len(cur._results) == 2
assert cur._result.get_value(0, 0) == b"foo"
assert cur.nextset()
def test_execute_sequence(conn):
cur = conn.cursor()
- cur.execute("select %s, %s, %s", [1, "foo", None])
+ rv = cur.execute("select %s, %s, %s", [1, "foo", None])
+ assert rv is cur
assert len(cur._results) == 1
assert cur._result.get_value(0, 0) == b"1"
assert cur._result.get_value(0, 1) == b"foo"