]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added cursor.status property and exposing ExecStatus on the cursor
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 11 Apr 2020 11:38:18 +0000 (23:38 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 11 Apr 2020 11:42:02 +0000 (23:42 +1200)
psycopg3/cursor.py
tests/test_async_cursor.py
tests/test_cursor.py

index 2cf593e2710f98a9785c6bd29a463600f2fea91e..c2ac2da8a03643df1d568d88eb4631377c07e74d 100644 (file)
@@ -9,19 +9,19 @@ from operator import attrgetter
 from typing import Any, List, Mapping, Optional, Sequence, Tuple, TYPE_CHECKING
 
 from . import errors as e
-from .pq import ConnStatus, ExecStatus, PGresult, Format
+from . import pq
 from .utils.queries import query2pg, reorder_params
 from .utils.typing import Query, Params
 
 if TYPE_CHECKING:
     from .connection import BaseConnection, Connection, AsyncConnection
     from .connection import QueryGen
-    from .adapt import DumpersMap, LoadersMap
+    from .adapt import DumpersMap, LoadersMap, Transformer
 
 
 class Column(Sequence[Any]):
     def __init__(
-        self, pgresult: PGresult, index: int, codec: codecs.CodecInfo
+        self, pgresult: pq.PGresult, index: int, codec: codecs.CodecInfo
     ):
         self._pgresult = pgresult
         self._index = index
@@ -58,6 +58,10 @@ class Column(Sequence[Any]):
 
 
 class BaseCursor:
+    ExecStatus = pq.ExecStatus
+
+    _transformer: "Transformer"
+
     def __init__(self, conn: "BaseConnection", binary: bool = False):
         self.conn = conn
         self.binary = binary
@@ -67,31 +71,37 @@ class BaseCursor:
         self.arraysize = 1
         self._closed = False
 
-    def __del__(self):
+    def _reset(self) -> None:
+        self._results: List[pq.PGresult] = []
+        self.pgresult = None
+        self._pos = 0
+        self._iresult = 0
+
+    def __del__(self) -> None:
         self.close()
 
     def close(self) -> None:
         self._closed = True
+        self._reset()
 
     @property
     def closed(self) -> bool:
         return self._closed
 
-    def _reset(self) -> None:
-        from .adapt import Transformer
-
-        self._transformer = Transformer(self)
-        self._results: List[PGresult] = []
-        self.pgresult: Optional[PGresult] = None
-        self._pos = 0
-        self._iresult = 0
+    @property
+    def status(self) -> Optional[pq.ExecStatus]:
+        res = self.pgresult
+        if res is not None:
+            return res.status
+        else:
+            return None
 
     @property
-    def pgresult(self) -> Optional[PGresult]:
+    def pgresult(self) -> Optional[pq.PGresult]:
         return self._pgresult
 
     @pgresult.setter
-    def pgresult(self, result: Optional[PGresult]) -> None:
+    def pgresult(self, result: Optional[pq.PGresult]) -> None:
         self._pgresult = result
         if result is not None and self._transformer is not None:
             self._transformer.set_row_types(
@@ -102,14 +112,14 @@ class BaseCursor:
     @property
     def description(self) -> Optional[List[Column]]:
         res = self.pgresult
-        if res is None or res.status != ExecStatus.TUPLES_OK:
+        if res is None or res.status != self.ExecStatus.TUPLES_OK:
             return None
         return [Column(res, i, self.conn.codec) for i in range(res.nfields)]
 
     @property
     def rowcount(self) -> int:
         res = self.pgresult
-        if res is None or res.status != ExecStatus.TUPLES_OK:
+        if res is None or res.status != self.ExecStatus.TUPLES_OK:
             return -1
         else:
             return res.ntuples
@@ -128,19 +138,22 @@ class BaseCursor:
         """
         Implement part of execute() before waiting common to sync and async
         """
+        from .adapt import Transformer
+
         if self.closed:
             raise e.OperationalError("the cursor is closed")
 
         if self.conn.closed:
             raise e.OperationalError("the connection is closed")
 
-        if self.conn.pgconn.status != ConnStatus.OK:
+        if self.conn.status != self.conn.ConnStatus.OK:
             raise e.InterfaceError(
                 f"cannot execute operations: the connection is"
-                f" in status {self.conn.pgconn.status}"
+                f" in status {self.conn.status}"
             )
 
         self._reset()
+        self._transformer = Transformer(self)
 
         codec = self.conn.codec
 
@@ -161,45 +174,39 @@ class BaseCursor:
                 params,
                 param_formats=formats,
                 param_types=types,
-                result_format=Format(self.binary),
+                result_format=pq.Format(self.binary),
             )
         else:
             # if we don't have to, let's use exec_ as it can run more than
             # one query in one go
             if self.binary:
                 self.conn.pgconn.send_query_params(
-                    query, (), result_format=Format(self.binary)
+                    query, (), result_format=pq.Format(self.binary)
                 )
             else:
                 self.conn.pgconn.send_query(query)
 
         return self.conn._exec_gen(self.conn.pgconn)
 
-    def _execute_results(self, results: List[PGresult]) -> None:
+    def _execute_results(self, results: List[pq.PGresult]) -> None:
         """
         Implement part of execute() after waiting common to sync and async
         """
         if not results:
             raise e.InternalError("got no result from the query")
 
-        badstats = {res.status for res in results} - {
-            ExecStatus.TUPLES_OK,
-            ExecStatus.COMMAND_OK,
-            ExecStatus.EMPTY_QUERY,
-        }
+        S = self.ExecStatus
+        statuses = {res.status for res in results}
+        badstats = statuses - {S.TUPLES_OK, S.COMMAND_OK, S.EMPTY_QUERY}
         if not badstats:
             self._results = results
             self.pgresult = results[0]
             return
 
-        if results[-1].status == ExecStatus.FATAL_ERROR:
+        if results[-1].status == S.FATAL_ERROR:
             raise e.error_from_result(results[-1])
 
-        elif badstats & {
-            ExecStatus.COPY_IN,
-            ExecStatus.COPY_OUT,
-            ExecStatus.COPY_BOTH,
-        }:
+        elif badstats & {S.COPY_IN, S.COPY_OUT, S.COPY_BOTH}:
             raise e.ProgrammingError(
                 "COPY cannot be used with execute(); use copy() insead"
             )
@@ -222,7 +229,7 @@ class BaseCursor:
         res = self.pgresult
         if res is None:
             raise e.ProgrammingError("no result available")
-        elif res.status != ExecStatus.TUPLES_OK:
+        elif res.status != self.ExecStatus.TUPLES_OK:
             raise e.ProgrammingError(
                 "the last operation didn't produce a result"
             )
index 7c258778e1cea521810dc9bdbf49434559451928..707fb87efdf7aba8685d7603ca8b3c13ed636f7c 100644 (file)
@@ -15,8 +15,21 @@ def test_close(aconn, loop):
     assert cur.closed
 
 
+def test_status(aconn, loop):
+    cur = aconn.cursor()
+    assert cur.status is None
+    loop.run_until_complete(cur.execute("reset all"))
+    assert cur.status == cur.ExecStatus.COMMAND_OK
+    loop.run_until_complete(cur.execute("select 1"))
+    assert cur.status == cur.ExecStatus.TUPLES_OK
+    cur.close()
+    assert cur.status is None
+
+
 def test_execute_many_results(aconn, loop):
     cur = aconn.cursor()
+    assert cur.nextset() is None
+
     rv = loop.run_until_complete(cur.execute("select 'foo'; select 'bar'"))
     assert rv is cur
     assert len(cur._results) == 2
@@ -25,6 +38,9 @@ def test_execute_many_results(aconn, loop):
     assert cur.pgresult.get_value(0, 0) == b"bar"
     assert cur.nextset() is None
 
+    cur.close()
+    assert cur.nextset() is None
+
 
 def test_execute_sequence(aconn, loop):
     cur = aconn.cursor()
@@ -37,3 +53,25 @@ def test_execute_sequence(aconn, loop):
     assert cur.pgresult.get_value(0, 1) == b"foo"
     assert cur.pgresult.get_value(0, 2) is None
     assert cur.nextset() is None
+
+
+@pytest.mark.parametrize("query", ["", " ", ";"])
+def test_execute_empty_query(aconn, loop, query):
+    cur = aconn.cursor()
+    loop.run_until_complete(cur.execute(query))
+    assert cur.status == cur.ExecStatus.EMPTY_QUERY
+    with pytest.raises(psycopg3.ProgrammingError):
+        loop.run_until_complete(cur.fetchone())
+
+
+def test_fetchone(aconn, loop):
+    cur = aconn.cursor()
+    loop.run_until_complete(cur.execute("select %s, %s, %s", [1, "foo", None]))
+    assert cur.pgresult.fformat(0) == 0
+
+    row = loop.run_until_complete(cur.fetchone())
+    assert row[0] == 1
+    assert row[1] == "foo"
+    assert row[2] is None
+    row = loop.run_until_complete(cur.fetchone())
+    assert row is None
index 962bed7c2605b99327f8c24e8a77601e868fde61..e9ca7a359c7eed3546e934435fca7b7ee770f23c 100644 (file)
@@ -15,8 +15,21 @@ def test_close(conn):
     assert cur.closed
 
 
+def test_status(conn):
+    cur = conn.cursor()
+    assert cur.status is None
+    cur.execute("reset all")
+    assert cur.status == cur.ExecStatus.COMMAND_OK
+    cur.execute("select 1")
+    assert cur.status == cur.ExecStatus.TUPLES_OK
+    cur.close()
+    assert cur.status is None
+
+
 def test_execute_many_results(conn):
     cur = conn.cursor()
+    assert cur.nextset() is None
+
     rv = cur.execute("select 'foo'; select 'bar'")
     assert rv is cur
     assert len(cur._results) == 2
@@ -25,6 +38,9 @@ def test_execute_many_results(conn):
     assert cur.pgresult.get_value(0, 0) == b"bar"
     assert cur.nextset() is None
 
+    cur.close()
+    assert cur.nextset() is None
+
 
 def test_execute_sequence(conn):
     cur = conn.cursor()
@@ -37,6 +53,15 @@ def test_execute_sequence(conn):
     assert cur.nextset() is None
 
 
+@pytest.mark.parametrize("query", ["", " ", ";"])
+def test_execute_empty_query(conn, query):
+    cur = conn.cursor()
+    cur.execute(query)
+    assert cur.status == cur.ExecStatus.EMPTY_QUERY
+    with pytest.raises(psycopg3.ProgrammingError):
+        cur.fetchone()
+
+
 def test_fetchone(conn):
     cur = conn.cursor()
     cur.execute("select %s, %s, %s", [1, "foo", None])