]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor: move the BaseCursor class to its own module
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 19 Aug 2023 17:18:59 +0000 (18:18 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 11 Oct 2023 21:45:38 +0000 (23:45 +0200)
psycopg/psycopg/_column.py
psycopg/psycopg/_cursor_base.py [new file with mode: 0644]
psycopg/psycopg/_pipeline.py
psycopg/psycopg/client_cursor.py
psycopg/psycopg/copy.py
psycopg/psycopg/cursor.py
psycopg/psycopg/cursor_async.py
psycopg/psycopg/raw_cursor.py
psycopg/psycopg/rows.py
psycopg/psycopg/server_cursor.py

index 50577e637cb442d7ce1bd17143b34f5154e6a0b4..331df6266bc757942796b1a1a9111cc50abbc4fc 100644 (file)
@@ -8,7 +8,7 @@ from typing import Any, NamedTuple, Optional, Sequence, TYPE_CHECKING
 from operator import attrgetter
 
 if TYPE_CHECKING:
-    from .cursor import BaseCursor
+    from ._cursor_base import BaseCursor
 
 
 class ColumnData(NamedTuple):
diff --git a/psycopg/psycopg/_cursor_base.py b/psycopg/psycopg/_cursor_base.py
new file mode 100644 (file)
index 0000000..1f53a88
--- /dev/null
@@ -0,0 +1,624 @@
+"""
+Psycopg BaseCursor object
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from functools import partial
+from typing import Any, Generic, Iterable, List
+from typing import Optional, NoReturn, Sequence, Tuple, Type
+from typing import TYPE_CHECKING
+
+from . import pq
+from . import adapt
+from . import errors as e
+from .abc import ConnectionType, Query, Params, PQGen
+from .rows import Row, RowMaker
+from ._column import Column
+from .pq.misc import connection_summary
+from ._queries import PostgresQuery, PostgresClientQuery
+from ._encodings import pgconn_encoding
+from ._preparing import Prepare
+from .generators import execute, fetch, send
+
+if TYPE_CHECKING:
+    from .abc import Transformer
+    from .pq.abc import PGconn, PGresult
+
+TEXT = pq.Format.TEXT
+BINARY = pq.Format.BINARY
+
+EMPTY_QUERY = pq.ExecStatus.EMPTY_QUERY
+COMMAND_OK = pq.ExecStatus.COMMAND_OK
+TUPLES_OK = pq.ExecStatus.TUPLES_OK
+COPY_OUT = pq.ExecStatus.COPY_OUT
+COPY_IN = pq.ExecStatus.COPY_IN
+COPY_BOTH = pq.ExecStatus.COPY_BOTH
+FATAL_ERROR = pq.ExecStatus.FATAL_ERROR
+SINGLE_TUPLE = pq.ExecStatus.SINGLE_TUPLE
+PIPELINE_ABORTED = pq.ExecStatus.PIPELINE_ABORTED
+
+ACTIVE = pq.TransactionStatus.ACTIVE
+
+
+class BaseCursor(Generic[ConnectionType, Row]):
+    __slots__ = """
+        _conn format _adapters arraysize _closed _results pgresult _pos
+        _iresult _rowcount _query _tx _last_query _row_factory _make_row
+        _pgconn _execmany_returning
+        __weakref__
+        """.split()
+
+    ExecStatus = pq.ExecStatus
+
+    _tx: "Transformer"
+    _make_row: RowMaker[Row]
+    _pgconn: "PGconn"
+    _query_cls: Type[PostgresQuery] = PostgresQuery
+
+    def __init__(self, connection: ConnectionType):
+        self._conn = connection
+        self.format = TEXT
+        self._pgconn = connection.pgconn
+        self._adapters = adapt.AdaptersMap(connection.adapters)
+        self.arraysize = 1
+        self._closed = False
+        self._last_query: Optional[Query] = None
+        self._reset()
+
+    def _reset(self, reset_query: bool = True) -> None:
+        self._results: List["PGresult"] = []
+        self.pgresult: Optional["PGresult"] = None
+        self._pos = 0
+        self._iresult = 0
+        self._rowcount = -1
+        self._query: Optional[PostgresQuery]
+        # None if executemany() not executing, True/False according to returning state
+        self._execmany_returning: Optional[bool] = None
+        if reset_query:
+            self._query = None
+
+    def __repr__(self) -> str:
+        cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
+        info = connection_summary(self._pgconn)
+        if self._closed:
+            status = "closed"
+        elif self.pgresult:
+            status = pq.ExecStatus(self.pgresult.status).name
+        else:
+            status = "no result"
+        return f"<{cls} [{status}] {info} at 0x{id(self):x}>"
+
+    @property
+    def connection(self) -> ConnectionType:
+        """The connection this cursor is using."""
+        return self._conn
+
+    @property
+    def adapters(self) -> adapt.AdaptersMap:
+        return self._adapters
+
+    @property
+    def closed(self) -> bool:
+        """`True` if the cursor is closed."""
+        return self._closed
+
+    @property
+    def description(self) -> Optional[List[Column]]:
+        """
+        A list of `Column` objects describing the current resultset.
+
+        `!None` if the current resultset didn't return tuples.
+        """
+        res = self.pgresult
+
+        # We return columns if we have nfields, but also if we don't but
+        # the query said we got tuples (mostly to handle the super useful
+        # query "SELECT ;"
+        if res and (
+            res.nfields or res.status == TUPLES_OK or res.status == SINGLE_TUPLE
+        ):
+            return [Column(self, i) for i in range(res.nfields)]
+        else:
+            return None
+
+    @property
+    def rowcount(self) -> int:
+        """Number of records affected by the precedent operation."""
+        return self._rowcount
+
+    @property
+    def rownumber(self) -> Optional[int]:
+        """Index of the next row to fetch in the current result.
+
+        `!None` if there is no result to fetch.
+        """
+        tuples = self.pgresult and self.pgresult.status == TUPLES_OK
+        return self._pos if tuples else None
+
+    def setinputsizes(self, sizes: Sequence[Any]) -> None:
+        # no-op
+        pass
+
+    def setoutputsize(self, size: Any, column: Optional[int] = None) -> None:
+        # no-op
+        pass
+
+    def nextset(self) -> Optional[bool]:
+        """
+        Move to the result set of the next query executed through `executemany()`
+        or to the next result set if `execute()` returned more than one.
+
+        Return `!True` if a new result is available, which will be the one
+        methods `!fetch*()` will operate on.
+        """
+        if self._iresult < len(self._results) - 1:
+            self._select_current_result(self._iresult + 1)
+            return True
+        else:
+            return None
+
+    @property
+    def statusmessage(self) -> Optional[str]:
+        """
+        The command status tag from the last SQL command executed.
+
+        `!None` if the cursor doesn't have a result available.
+        """
+        msg = self.pgresult.command_status if self.pgresult else None
+        return msg.decode() if msg else None
+
+    def _make_row_maker(self) -> RowMaker[Row]:
+        raise NotImplementedError
+
+    #
+    # Generators for the high level operations on the cursor
+    #
+    # Like for sync/async connections, these are implemented as generators
+    # so that different concurrency strategies (threads,asyncio) can use their
+    # own way of waiting (or better, `connection.wait()`).
+    #
+
+    def _execute_gen(
+        self,
+        query: Query,
+        params: Optional[Params] = None,
+        *,
+        prepare: Optional[bool] = None,
+        binary: Optional[bool] = None,
+    ) -> PQGen[None]:
+        """Generator implementing `Cursor.execute()`."""
+        yield from self._start_query(query)
+        pgq = self._convert_query(query, params)
+        yield from self._maybe_prepare_gen(pgq, prepare=prepare, binary=binary)
+        if self._conn._pipeline:
+            yield from self._conn._pipeline._communicate_gen()
+
+        self._last_query = query
+
+        for cmd in self._conn._prepared.get_maintenance_commands():
+            yield from self._conn._exec_command(cmd)
+
+    def _executemany_gen_pipeline(
+        self, query: Query, params_seq: Iterable[Params], returning: bool
+    ) -> PQGen[None]:
+        """
+        Generator implementing `Cursor.executemany()` with pipelines available.
+        """
+        pipeline = self._conn._pipeline
+        assert pipeline
+
+        yield from self._start_query(query)
+        if not returning:
+            self._rowcount = 0
+
+        assert self._execmany_returning is None
+        self._execmany_returning = returning
+
+        first = True
+        for params in params_seq:
+            if first:
+                pgq = self._convert_query(query, params)
+                self._query = pgq
+                first = False
+            else:
+                pgq.dump(params)
+
+            yield from self._maybe_prepare_gen(pgq, prepare=True)
+            yield from pipeline._communicate_gen()
+
+        self._last_query = query
+
+        if returning:
+            yield from pipeline._fetch_gen(flush=True)
+
+        for cmd in self._conn._prepared.get_maintenance_commands():
+            yield from self._conn._exec_command(cmd)
+
+    def _executemany_gen_no_pipeline(
+        self, query: Query, params_seq: Iterable[Params], returning: bool
+    ) -> PQGen[None]:
+        """
+        Generator implementing `Cursor.executemany()` with pipelines not available.
+        """
+        yield from self._start_query(query)
+        if not returning:
+            self._rowcount = 0
+
+        assert self._execmany_returning is None
+        self._execmany_returning = returning
+
+        first = True
+        for params in params_seq:
+            if first:
+                pgq = self._convert_query(query, params)
+                self._query = pgq
+                first = False
+            else:
+                pgq.dump(params)
+
+            yield from self._maybe_prepare_gen(pgq, prepare=True)
+
+        self._last_query = query
+
+        for cmd in self._conn._prepared.get_maintenance_commands():
+            yield from self._conn._exec_command(cmd)
+
+    def _maybe_prepare_gen(
+        self,
+        pgq: PostgresQuery,
+        *,
+        prepare: Optional[bool] = None,
+        binary: Optional[bool] = None,
+    ) -> PQGen[None]:
+        # Check if the query is prepared or needs preparing
+        prep, name = self._get_prepared(pgq, prepare)
+        if prep is Prepare.NO:
+            # The query must be executed without preparing
+            self._execute_send(pgq, binary=binary)
+        else:
+            # If the query is not already prepared, prepare it.
+            if prep is Prepare.SHOULD:
+                self._send_prepare(name, pgq)
+                if not self._conn._pipeline:
+                    (result,) = yield from execute(self._pgconn)
+                    if result.status == FATAL_ERROR:
+                        raise e.error_from_result(result, encoding=self._encoding)
+            # Then execute it.
+            self._send_query_prepared(name, pgq, binary=binary)
+
+        # Update the prepare state of the query.
+        # If an operation requires to flush our prepared statements cache,
+        # it will be added to the maintenance commands to execute later.
+        key = self._conn._prepared.maybe_add_to_cache(pgq, prep, name)
+
+        if self._conn._pipeline:
+            queued = None
+            if key is not None:
+                queued = (key, prep, name)
+            self._conn._pipeline.result_queue.append((self, queued))
+            return
+
+        # run the query
+        results = yield from execute(self._pgconn)
+
+        if key is not None:
+            self._conn._prepared.validate(key, prep, name, results)
+
+        self._check_results(results)
+        self._set_results(results)
+
+    def _get_prepared(
+        self, pgq: PostgresQuery, prepare: Optional[bool] = None
+    ) -> Tuple[Prepare, bytes]:
+        return self._conn._prepared.get(pgq, prepare)
+
+    def _stream_send_gen(
+        self,
+        query: Query,
+        params: Optional[Params] = None,
+        *,
+        binary: Optional[bool] = None,
+    ) -> PQGen[None]:
+        """Generator to send the query for `Cursor.stream()`."""
+        yield from self._start_query(query)
+        pgq = self._convert_query(query, params)
+        self._execute_send(pgq, binary=binary, force_extended=True)
+        self._pgconn.set_single_row_mode()
+        self._last_query = query
+        yield from send(self._pgconn)
+
+    def _stream_fetchone_gen(self, first: bool) -> PQGen[Optional["PGresult"]]:
+        res = yield from fetch(self._pgconn)
+        if res is None:
+            return None
+
+        status = res.status
+        if status == SINGLE_TUPLE:
+            self.pgresult = res
+            self._tx.set_pgresult(res, set_loaders=first)
+            if first:
+                self._make_row = self._make_row_maker()
+            return res
+
+        elif status == TUPLES_OK or status == COMMAND_OK:
+            # End of single row results
+            while res:
+                res = yield from fetch(self._pgconn)
+            if status != TUPLES_OK:
+                raise e.ProgrammingError(
+                    "the operation in stream() didn't produce a result"
+                )
+            return None
+
+        else:
+            # Errors, unexpected values
+            return self._raise_for_result(res)
+
+    def _start_query(self, query: Optional[Query] = None) -> PQGen[None]:
+        """Generator to start the processing of a query.
+
+        It is implemented as generator because it may send additional queries,
+        such as `begin`.
+        """
+        if self.closed:
+            raise e.InterfaceError("the cursor is closed")
+
+        self._reset()
+        if not self._last_query or (self._last_query is not query):
+            self._last_query = None
+            self._tx = adapt.Transformer(self)
+        yield from self._conn._start_query()
+
+    def _start_copy_gen(
+        self, statement: Query, params: Optional[Params] = None
+    ) -> PQGen[None]:
+        """Generator implementing sending a command for `Cursor.copy()."""
+
+        # The connection gets in an unrecoverable state if we attempt COPY in
+        # pipeline mode. Forbid it explicitly.
+        if self._conn._pipeline:
+            raise e.NotSupportedError("COPY cannot be used in pipeline mode")
+
+        yield from self._start_query()
+
+        # Merge the params client-side
+        if params:
+            pgq = PostgresClientQuery(self._tx)
+            pgq.convert(statement, params)
+            statement = pgq.query
+
+        query = self._convert_query(statement)
+
+        self._execute_send(query, binary=False)
+        results = yield from execute(self._pgconn)
+        if len(results) != 1:
+            raise e.ProgrammingError("COPY cannot be mixed with other operations")
+
+        self._check_copy_result(results[0])
+        self._set_results(results)
+
+    def _execute_send(
+        self,
+        query: PostgresQuery,
+        *,
+        force_extended: bool = False,
+        binary: Optional[bool] = None,
+    ) -> None:
+        """
+        Implement part of execute() before waiting common to sync and async.
+
+        This is not a generator, but a normal non-blocking function.
+        """
+        if binary is None:
+            fmt = self.format
+        else:
+            fmt = BINARY if binary else TEXT
+
+        self._query = query
+
+        if self._conn._pipeline:
+            # In pipeline mode always use PQsendQueryParams - see #314
+            # Multiple statements in the same query are not allowed anyway.
+            self._conn._pipeline.command_queue.append(
+                partial(
+                    self._pgconn.send_query_params,
+                    query.query,
+                    query.params,
+                    param_formats=query.formats,
+                    param_types=query.types,
+                    result_format=fmt,
+                )
+            )
+        elif force_extended or query.params or fmt == BINARY:
+            self._pgconn.send_query_params(
+                query.query,
+                query.params,
+                param_formats=query.formats,
+                param_types=query.types,
+                result_format=fmt,
+            )
+        else:
+            # If we can, let's use simple query protocol,
+            # as it can execute more than one statement in a single query.
+            self._pgconn.send_query(query.query)
+
+    def _convert_query(
+        self, query: Query, params: Optional[Params] = None
+    ) -> PostgresQuery:
+        pgq = self._query_cls(self._tx)
+        pgq.convert(query, params)
+        return pgq
+
+    def _check_results(self, results: List["PGresult"]) -> None:
+        """
+        Verify that the results of a query are valid.
+
+        Verify that the query returned at least one result and that they all
+        represent a valid result from the database.
+        """
+        if not results:
+            raise e.InternalError("got no result from the query")
+
+        for res in results:
+            status = res.status
+            if status != TUPLES_OK and status != COMMAND_OK and status != EMPTY_QUERY:
+                self._raise_for_result(res)
+
+    def _raise_for_result(self, result: "PGresult") -> NoReturn:
+        """
+        Raise an appropriate error message for an unexpected database result
+        """
+        status = result.status
+        if status == FATAL_ERROR:
+            raise e.error_from_result(result, encoding=self._encoding)
+        elif status == PIPELINE_ABORTED:
+            raise e.PipelineAborted("pipeline aborted")
+        elif status == COPY_IN or status == COPY_OUT or status == COPY_BOTH:
+            raise e.ProgrammingError(
+                "COPY cannot be used with this method; use copy() instead"
+            )
+        else:
+            raise e.InternalError(
+                "unexpected result status from query:" f" {pq.ExecStatus(status).name}"
+            )
+
+    def _select_current_result(
+        self, i: int, format: Optional[pq.Format] = None
+    ) -> None:
+        """
+        Select one of the results in the cursor as the active one.
+        """
+        self._iresult = i
+        res = self.pgresult = self._results[i]
+
+        # Note: the only reason to override format is to correctly set
+        # binary loaders on server-side cursors, because send_describe_portal
+        # only returns a text result.
+        self._tx.set_pgresult(res, format=format)
+
+        self._pos = 0
+
+        if res.status == TUPLES_OK:
+            self._rowcount = self.pgresult.ntuples
+
+        # COPY_OUT has never info about nrows. We need such result for the
+        # columns in order to return a `description`, but not overwrite the
+        # cursor rowcount (which was set by the Copy object).
+        elif res.status != COPY_OUT:
+            nrows = self.pgresult.command_tuples
+            self._rowcount = nrows if nrows is not None else -1
+
+        self._make_row = self._make_row_maker()
+
+    def _set_results(self, results: List["PGresult"]) -> None:
+        if self._execmany_returning is None:
+            # Received from execute()
+            self._results[:] = results
+            self._select_current_result(0)
+
+        else:
+            # Received from executemany()
+            if self._execmany_returning:
+                first_batch = not self._results
+                self._results.extend(results)
+                if first_batch:
+                    self._select_current_result(0)
+            else:
+                # In non-returning case, set rowcount to the cumulated number of
+                # rows of executed queries.
+                for res in results:
+                    self._rowcount += res.command_tuples or 0
+
+    def _send_prepare(self, name: bytes, query: PostgresQuery) -> None:
+        if self._conn._pipeline:
+            self._conn._pipeline.command_queue.append(
+                partial(
+                    self._pgconn.send_prepare,
+                    name,
+                    query.query,
+                    param_types=query.types,
+                )
+            )
+            self._conn._pipeline.result_queue.append(None)
+        else:
+            self._pgconn.send_prepare(name, query.query, param_types=query.types)
+
+    def _send_query_prepared(
+        self, name: bytes, pgq: PostgresQuery, *, binary: Optional[bool] = None
+    ) -> None:
+        if binary is None:
+            fmt = self.format
+        else:
+            fmt = BINARY if binary else TEXT
+
+        if self._conn._pipeline:
+            self._conn._pipeline.command_queue.append(
+                partial(
+                    self._pgconn.send_query_prepared,
+                    name,
+                    pgq.params,
+                    param_formats=pgq.formats,
+                    result_format=fmt,
+                )
+            )
+        else:
+            self._pgconn.send_query_prepared(
+                name, pgq.params, param_formats=pgq.formats, result_format=fmt
+            )
+
+    def _check_result_for_fetch(self) -> None:
+        if self.closed:
+            raise e.InterfaceError("the cursor is closed")
+        res = self.pgresult
+        if not res:
+            raise e.ProgrammingError("no result available")
+
+        status = res.status
+        if status == TUPLES_OK:
+            return
+        elif status == FATAL_ERROR:
+            raise e.error_from_result(res, encoding=self._encoding)
+        elif status == PIPELINE_ABORTED:
+            raise e.PipelineAborted("pipeline aborted")
+        else:
+            raise e.ProgrammingError("the last operation didn't produce a result")
+
+    def _check_copy_result(self, result: "PGresult") -> None:
+        """
+        Check that the value returned in a copy() operation is a legit COPY.
+        """
+        status = result.status
+        if status == COPY_IN or status == COPY_OUT:
+            return
+        elif status == FATAL_ERROR:
+            raise e.error_from_result(result, encoding=self._encoding)
+        else:
+            raise e.ProgrammingError(
+                "copy() should be used only with COPY ... TO STDOUT or COPY ..."
+                f" FROM STDIN statements, got {pq.ExecStatus(status).name}"
+            )
+
+    def _scroll(self, value: int, mode: str) -> None:
+        self._check_result_for_fetch()
+        assert self.pgresult
+        if mode == "relative":
+            newpos = self._pos + value
+        elif mode == "absolute":
+            newpos = value
+        else:
+            raise ValueError(f"bad mode: {mode}. It should be 'relative' or 'absolute'")
+        if not 0 <= newpos < self.pgresult.ntuples:
+            raise IndexError("position out of bound")
+        self._pos = newpos
+
+    def _close(self) -> None:
+        """Non-blocking part of closing. Common to sync/async."""
+        # Don't reset the query because it may be useful to investigate after
+        # an error.
+        self._reset(reset_query=False)
+        self._closed = True
+
+    @property
+    def _encoding(self) -> str:
+        return pgconn_encoding(self._pgconn)
index 297e87c1d6fb0e9ad1abfd8379b8035bcf1e09f3..2223f491ac6bdd79e7a6543fc1138648c9aa109a 100644 (file)
@@ -20,7 +20,7 @@ from .generators import pipeline_communicate, fetch_many, send
 
 if TYPE_CHECKING:
     from .pq.abc import PGresult
-    from .cursor import BaseCursor
+    from ._cursor_base import BaseCursor
     from .connection import BaseConnection, Connection
     from .connection_async import AsyncConnection
 
index 77cdd4416abcc8ab135a4dff9f1d73e4dc7c4c52..24d7b45cb241df2dab76051b48fadf3c2759a108 100644 (file)
@@ -14,8 +14,9 @@ from . import adapt
 from . import errors as e
 from .abc import ConnectionType, Query, Params
 from .rows import Row
-from .cursor import BaseCursor, Cursor
+from .cursor import Cursor
 from ._preparing import Prepare
+from ._cursor_base import BaseCursor
 from .cursor_async import AsyncCursor
 
 if TYPE_CHECKING:
index d52e9b93d3dc03705b0c38f4afc9e830645c7345..bf54e90be90274ca97437a9d3a630a567fd8de97 100644 (file)
@@ -24,7 +24,8 @@ from ._encodings import pgconn_encoding
 from .generators import copy_from, copy_to, copy_end
 
 if TYPE_CHECKING:
-    from .cursor import BaseCursor, Cursor
+    from .cursor import Cursor
+    from ._cursor_base import BaseCursor
     from .cursor_async import AsyncCursor
     from .connection import Connection  # noqa: F401
     from .connection_async import AsyncConnection  # noqa: F401
index 7353f558d356164b185ce163f36f780c58b9387e..c26c73abfbdf24f68049b4b56631a408b2a4482e 100644 (file)
 """
-psycopg cursor objects
+Psycopg Cursor object
 """
 
 # Copyright (C) 2020 The Psycopg Team
 
-from functools import partial
 from types import TracebackType
-from typing import Any, Generic, Iterable, Iterator, List
-from typing import Optional, NoReturn, Sequence, Tuple, Type, TypeVar
+from typing import Any, Iterable, Iterator, List, Optional, Type, TypeVar
 from typing import overload, TYPE_CHECKING
 from contextlib import contextmanager
 
 from . import pq
-from . import adapt
 from . import errors as e
-from .abc import ConnectionType, Query, Params, PQGen
+from .abc import Query, Params
 from .copy import Copy, Writer as CopyWriter
 from .rows import Row, RowMaker, RowFactory
-from ._column import Column
-from .pq.misc import connection_summary
-from ._queries import PostgresQuery, PostgresClientQuery
 from ._pipeline import Pipeline
-from ._encodings import pgconn_encoding
-from ._preparing import Prepare
-from .generators import execute, fetch, send
+from ._cursor_base import BaseCursor
 
 if TYPE_CHECKING:
-    from .abc import Transformer
-    from .pq.abc import PGconn, PGresult
     from .connection import Connection
 
-TEXT = pq.Format.TEXT
-BINARY = pq.Format.BINARY
-
-EMPTY_QUERY = pq.ExecStatus.EMPTY_QUERY
-COMMAND_OK = pq.ExecStatus.COMMAND_OK
-TUPLES_OK = pq.ExecStatus.TUPLES_OK
-COPY_OUT = pq.ExecStatus.COPY_OUT
-COPY_IN = pq.ExecStatus.COPY_IN
-COPY_BOTH = pq.ExecStatus.COPY_BOTH
-FATAL_ERROR = pq.ExecStatus.FATAL_ERROR
-SINGLE_TUPLE = pq.ExecStatus.SINGLE_TUPLE
-PIPELINE_ABORTED = pq.ExecStatus.PIPELINE_ABORTED
-
 ACTIVE = pq.TransactionStatus.ACTIVE
 
 
-class BaseCursor(Generic[ConnectionType, Row]):
-    __slots__ = """
-        _conn format _adapters arraysize _closed _results pgresult _pos
-        _iresult _rowcount _query _tx _last_query _row_factory _make_row
-        _pgconn _execmany_returning
-        __weakref__
-        """.split()
-
-    ExecStatus = pq.ExecStatus
-
-    _tx: "Transformer"
-    _make_row: RowMaker[Row]
-    _pgconn: "PGconn"
-    _query_cls: Type[PostgresQuery] = PostgresQuery
-
-    def __init__(self, connection: ConnectionType):
-        self._conn = connection
-        self.format = TEXT
-        self._pgconn = connection.pgconn
-        self._adapters = adapt.AdaptersMap(connection.adapters)
-        self.arraysize = 1
-        self._closed = False
-        self._last_query: Optional[Query] = None
-        self._reset()
-
-    def _reset(self, reset_query: bool = True) -> None:
-        self._results: List["PGresult"] = []
-        self.pgresult: Optional["PGresult"] = None
-        self._pos = 0
-        self._iresult = 0
-        self._rowcount = -1
-        self._query: Optional[PostgresQuery]
-        # None if executemany() not executing, True/False according to returning state
-        self._execmany_returning: Optional[bool] = None
-        if reset_query:
-            self._query = None
-
-    def __repr__(self) -> str:
-        cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
-        info = connection_summary(self._pgconn)
-        if self._closed:
-            status = "closed"
-        elif self.pgresult:
-            status = pq.ExecStatus(self.pgresult.status).name
-        else:
-            status = "no result"
-        return f"<{cls} [{status}] {info} at 0x{id(self):x}>"
-
-    @property
-    def connection(self) -> ConnectionType:
-        """The connection this cursor is using."""
-        return self._conn
-
-    @property
-    def adapters(self) -> adapt.AdaptersMap:
-        return self._adapters
-
-    @property
-    def closed(self) -> bool:
-        """`True` if the cursor is closed."""
-        return self._closed
-
-    @property
-    def description(self) -> Optional[List[Column]]:
-        """
-        A list of `Column` objects describing the current resultset.
-
-        `!None` if the current resultset didn't return tuples.
-        """
-        res = self.pgresult
-
-        # We return columns if we have nfields, but also if we don't but
-        # the query said we got tuples (mostly to handle the super useful
-        # query "SELECT ;"
-        if res and (
-            res.nfields or res.status == TUPLES_OK or res.status == SINGLE_TUPLE
-        ):
-            return [Column(self, i) for i in range(res.nfields)]
-        else:
-            return None
-
-    @property
-    def rowcount(self) -> int:
-        """Number of records affected by the precedent operation."""
-        return self._rowcount
-
-    @property
-    def rownumber(self) -> Optional[int]:
-        """Index of the next row to fetch in the current result.
-
-        `!None` if there is no result to fetch.
-        """
-        tuples = self.pgresult and self.pgresult.status == TUPLES_OK
-        return self._pos if tuples else None
-
-    def setinputsizes(self, sizes: Sequence[Any]) -> None:
-        # no-op
-        pass
-
-    def setoutputsize(self, size: Any, column: Optional[int] = None) -> None:
-        # no-op
-        pass
-
-    def nextset(self) -> Optional[bool]:
-        """
-        Move to the result set of the next query executed through `executemany()`
-        or to the next result set if `execute()` returned more than one.
-
-        Return `!True` if a new result is available, which will be the one
-        methods `!fetch*()` will operate on.
-        """
-        if self._iresult < len(self._results) - 1:
-            self._select_current_result(self._iresult + 1)
-            return True
-        else:
-            return None
-
-    @property
-    def statusmessage(self) -> Optional[str]:
-        """
-        The command status tag from the last SQL command executed.
-
-        `!None` if the cursor doesn't have a result available.
-        """
-        msg = self.pgresult.command_status if self.pgresult else None
-        return msg.decode() if msg else None
-
-    def _make_row_maker(self) -> RowMaker[Row]:
-        raise NotImplementedError
-
-    #
-    # Generators for the high level operations on the cursor
-    #
-    # Like for sync/async connections, these are implemented as generators
-    # so that different concurrency strategies (threads,asyncio) can use their
-    # own way of waiting (or better, `connection.wait()`).
-    #
-
-    def _execute_gen(
-        self,
-        query: Query,
-        params: Optional[Params] = None,
-        *,
-        prepare: Optional[bool] = None,
-        binary: Optional[bool] = None,
-    ) -> PQGen[None]:
-        """Generator implementing `Cursor.execute()`."""
-        yield from self._start_query(query)
-        pgq = self._convert_query(query, params)
-        yield from self._maybe_prepare_gen(pgq, prepare=prepare, binary=binary)
-        if self._conn._pipeline:
-            yield from self._conn._pipeline._communicate_gen()
-
-        self._last_query = query
-
-        for cmd in self._conn._prepared.get_maintenance_commands():
-            yield from self._conn._exec_command(cmd)
-
-    def _executemany_gen_pipeline(
-        self, query: Query, params_seq: Iterable[Params], returning: bool
-    ) -> PQGen[None]:
-        """
-        Generator implementing `Cursor.executemany()` with pipelines available.
-        """
-        pipeline = self._conn._pipeline
-        assert pipeline
-
-        yield from self._start_query(query)
-        if not returning:
-            self._rowcount = 0
-
-        assert self._execmany_returning is None
-        self._execmany_returning = returning
-
-        first = True
-        for params in params_seq:
-            if first:
-                pgq = self._convert_query(query, params)
-                self._query = pgq
-                first = False
-            else:
-                pgq.dump(params)
-
-            yield from self._maybe_prepare_gen(pgq, prepare=True)
-            yield from pipeline._communicate_gen()
-
-        self._last_query = query
-
-        if returning:
-            yield from pipeline._fetch_gen(flush=True)
-
-        for cmd in self._conn._prepared.get_maintenance_commands():
-            yield from self._conn._exec_command(cmd)
-
-    def _executemany_gen_no_pipeline(
-        self, query: Query, params_seq: Iterable[Params], returning: bool
-    ) -> PQGen[None]:
-        """
-        Generator implementing `Cursor.executemany()` with pipelines not available.
-        """
-        yield from self._start_query(query)
-        if not returning:
-            self._rowcount = 0
-
-        assert self._execmany_returning is None
-        self._execmany_returning = returning
-
-        first = True
-        for params in params_seq:
-            if first:
-                pgq = self._convert_query(query, params)
-                self._query = pgq
-                first = False
-            else:
-                pgq.dump(params)
-
-            yield from self._maybe_prepare_gen(pgq, prepare=True)
-
-        self._last_query = query
-
-        for cmd in self._conn._prepared.get_maintenance_commands():
-            yield from self._conn._exec_command(cmd)
-
-    def _maybe_prepare_gen(
-        self,
-        pgq: PostgresQuery,
-        *,
-        prepare: Optional[bool] = None,
-        binary: Optional[bool] = None,
-    ) -> PQGen[None]:
-        # Check if the query is prepared or needs preparing
-        prep, name = self._get_prepared(pgq, prepare)
-        if prep is Prepare.NO:
-            # The query must be executed without preparing
-            self._execute_send(pgq, binary=binary)
-        else:
-            # If the query is not already prepared, prepare it.
-            if prep is Prepare.SHOULD:
-                self._send_prepare(name, pgq)
-                if not self._conn._pipeline:
-                    (result,) = yield from execute(self._pgconn)
-                    if result.status == FATAL_ERROR:
-                        raise e.error_from_result(result, encoding=self._encoding)
-            # Then execute it.
-            self._send_query_prepared(name, pgq, binary=binary)
-
-        # Update the prepare state of the query.
-        # If an operation requires to flush our prepared statements cache,
-        # it will be added to the maintenance commands to execute later.
-        key = self._conn._prepared.maybe_add_to_cache(pgq, prep, name)
-
-        if self._conn._pipeline:
-            queued = None
-            if key is not None:
-                queued = (key, prep, name)
-            self._conn._pipeline.result_queue.append((self, queued))
-            return
-
-        # run the query
-        results = yield from execute(self._pgconn)
-
-        if key is not None:
-            self._conn._prepared.validate(key, prep, name, results)
-
-        self._check_results(results)
-        self._set_results(results)
-
-    def _get_prepared(
-        self, pgq: PostgresQuery, prepare: Optional[bool] = None
-    ) -> Tuple[Prepare, bytes]:
-        return self._conn._prepared.get(pgq, prepare)
-
-    def _stream_send_gen(
-        self,
-        query: Query,
-        params: Optional[Params] = None,
-        *,
-        binary: Optional[bool] = None,
-    ) -> PQGen[None]:
-        """Generator to send the query for `Cursor.stream()`."""
-        yield from self._start_query(query)
-        pgq = self._convert_query(query, params)
-        self._execute_send(pgq, binary=binary, force_extended=True)
-        self._pgconn.set_single_row_mode()
-        self._last_query = query
-        yield from send(self._pgconn)
-
-    def _stream_fetchone_gen(self, first: bool) -> PQGen[Optional["PGresult"]]:
-        res = yield from fetch(self._pgconn)
-        if res is None:
-            return None
-
-        status = res.status
-        if status == SINGLE_TUPLE:
-            self.pgresult = res
-            self._tx.set_pgresult(res, set_loaders=first)
-            if first:
-                self._make_row = self._make_row_maker()
-            return res
-
-        elif status == TUPLES_OK or status == COMMAND_OK:
-            # End of single row results
-            while res:
-                res = yield from fetch(self._pgconn)
-            if status != TUPLES_OK:
-                raise e.ProgrammingError(
-                    "the operation in stream() didn't produce a result"
-                )
-            return None
-
-        else:
-            # Errors, unexpected values
-            return self._raise_for_result(res)
-
-    def _start_query(self, query: Optional[Query] = None) -> PQGen[None]:
-        """Generator to start the processing of a query.
-
-        It is implemented as generator because it may send additional queries,
-        such as `begin`.
-        """
-        if self.closed:
-            raise e.InterfaceError("the cursor is closed")
-
-        self._reset()
-        if not self._last_query or (self._last_query is not query):
-            self._last_query = None
-            self._tx = adapt.Transformer(self)
-        yield from self._conn._start_query()
-
-    def _start_copy_gen(
-        self, statement: Query, params: Optional[Params] = None
-    ) -> PQGen[None]:
-        """Generator implementing sending a command for `Cursor.copy()."""
-
-        # The connection gets in an unrecoverable state if we attempt COPY in
-        # pipeline mode. Forbid it explicitly.
-        if self._conn._pipeline:
-            raise e.NotSupportedError("COPY cannot be used in pipeline mode")
-
-        yield from self._start_query()
-
-        # Merge the params client-side
-        if params:
-            pgq = PostgresClientQuery(self._tx)
-            pgq.convert(statement, params)
-            statement = pgq.query
-
-        query = self._convert_query(statement)
-
-        self._execute_send(query, binary=False)
-        results = yield from execute(self._pgconn)
-        if len(results) != 1:
-            raise e.ProgrammingError("COPY cannot be mixed with other operations")
-
-        self._check_copy_result(results[0])
-        self._set_results(results)
-
-    def _execute_send(
-        self,
-        query: PostgresQuery,
-        *,
-        force_extended: bool = False,
-        binary: Optional[bool] = None,
-    ) -> None:
-        """
-        Implement part of execute() before waiting common to sync and async.
-
-        This is not a generator, but a normal non-blocking function.
-        """
-        if binary is None:
-            fmt = self.format
-        else:
-            fmt = BINARY if binary else TEXT
-
-        self._query = query
-
-        if self._conn._pipeline:
-            # In pipeline mode always use PQsendQueryParams - see #314
-            # Multiple statements in the same query are not allowed anyway.
-            self._conn._pipeline.command_queue.append(
-                partial(
-                    self._pgconn.send_query_params,
-                    query.query,
-                    query.params,
-                    param_formats=query.formats,
-                    param_types=query.types,
-                    result_format=fmt,
-                )
-            )
-        elif force_extended or query.params or fmt == BINARY:
-            self._pgconn.send_query_params(
-                query.query,
-                query.params,
-                param_formats=query.formats,
-                param_types=query.types,
-                result_format=fmt,
-            )
-        else:
-            # If we can, let's use simple query protocol,
-            # as it can execute more than one statement in a single query.
-            self._pgconn.send_query(query.query)
-
-    def _convert_query(
-        self, query: Query, params: Optional[Params] = None
-    ) -> PostgresQuery:
-        pgq = self._query_cls(self._tx)
-        pgq.convert(query, params)
-        return pgq
-
-    def _check_results(self, results: List["PGresult"]) -> None:
-        """
-        Verify that the results of a query are valid.
-
-        Verify that the query returned at least one result and that they all
-        represent a valid result from the database.
-        """
-        if not results:
-            raise e.InternalError("got no result from the query")
-
-        for res in results:
-            status = res.status
-            if status != TUPLES_OK and status != COMMAND_OK and status != EMPTY_QUERY:
-                self._raise_for_result(res)
-
-    def _raise_for_result(self, result: "PGresult") -> NoReturn:
-        """
-        Raise an appropriate error message for an unexpected database result
-        """
-        status = result.status
-        if status == FATAL_ERROR:
-            raise e.error_from_result(result, encoding=self._encoding)
-        elif status == PIPELINE_ABORTED:
-            raise e.PipelineAborted("pipeline aborted")
-        elif status == COPY_IN or status == COPY_OUT or status == COPY_BOTH:
-            raise e.ProgrammingError(
-                "COPY cannot be used with this method; use copy() instead"
-            )
-        else:
-            raise e.InternalError(
-                "unexpected result status from query:" f" {pq.ExecStatus(status).name}"
-            )
-
-    def _select_current_result(
-        self, i: int, format: Optional[pq.Format] = None
-    ) -> None:
-        """
-        Select one of the results in the cursor as the active one.
-        """
-        self._iresult = i
-        res = self.pgresult = self._results[i]
-
-        # Note: the only reason to override format is to correctly set
-        # binary loaders on server-side cursors, because send_describe_portal
-        # only returns a text result.
-        self._tx.set_pgresult(res, format=format)
-
-        self._pos = 0
-
-        if res.status == TUPLES_OK:
-            self._rowcount = self.pgresult.ntuples
-
-        # COPY_OUT has never info about nrows. We need such result for the
-        # columns in order to return a `description`, but not overwrite the
-        # cursor rowcount (which was set by the Copy object).
-        elif res.status != COPY_OUT:
-            nrows = self.pgresult.command_tuples
-            self._rowcount = nrows if nrows is not None else -1
-
-        self._make_row = self._make_row_maker()
-
-    def _set_results(self, results: List["PGresult"]) -> None:
-        if self._execmany_returning is None:
-            # Received from execute()
-            self._results[:] = results
-            self._select_current_result(0)
-
-        else:
-            # Received from executemany()
-            if self._execmany_returning:
-                first_batch = not self._results
-                self._results.extend(results)
-                if first_batch:
-                    self._select_current_result(0)
-            else:
-                # In non-returning case, set rowcount to the cumulated number of
-                # rows of executed queries.
-                for res in results:
-                    self._rowcount += res.command_tuples or 0
-
-    def _send_prepare(self, name: bytes, query: PostgresQuery) -> None:
-        if self._conn._pipeline:
-            self._conn._pipeline.command_queue.append(
-                partial(
-                    self._pgconn.send_prepare,
-                    name,
-                    query.query,
-                    param_types=query.types,
-                )
-            )
-            self._conn._pipeline.result_queue.append(None)
-        else:
-            self._pgconn.send_prepare(name, query.query, param_types=query.types)
-
-    def _send_query_prepared(
-        self, name: bytes, pgq: PostgresQuery, *, binary: Optional[bool] = None
-    ) -> None:
-        if binary is None:
-            fmt = self.format
-        else:
-            fmt = BINARY if binary else TEXT
-
-        if self._conn._pipeline:
-            self._conn._pipeline.command_queue.append(
-                partial(
-                    self._pgconn.send_query_prepared,
-                    name,
-                    pgq.params,
-                    param_formats=pgq.formats,
-                    result_format=fmt,
-                )
-            )
-        else:
-            self._pgconn.send_query_prepared(
-                name, pgq.params, param_formats=pgq.formats, result_format=fmt
-            )
-
-    def _check_result_for_fetch(self) -> None:
-        if self.closed:
-            raise e.InterfaceError("the cursor is closed")
-        res = self.pgresult
-        if not res:
-            raise e.ProgrammingError("no result available")
-
-        status = res.status
-        if status == TUPLES_OK:
-            return
-        elif status == FATAL_ERROR:
-            raise e.error_from_result(res, encoding=self._encoding)
-        elif status == PIPELINE_ABORTED:
-            raise e.PipelineAborted("pipeline aborted")
-        else:
-            raise e.ProgrammingError("the last operation didn't produce a result")
-
-    def _check_copy_result(self, result: "PGresult") -> None:
-        """
-        Check that the value returned in a copy() operation is a legit COPY.
-        """
-        status = result.status
-        if status == COPY_IN or status == COPY_OUT:
-            return
-        elif status == FATAL_ERROR:
-            raise e.error_from_result(result, encoding=self._encoding)
-        else:
-            raise e.ProgrammingError(
-                "copy() should be used only with COPY ... TO STDOUT or COPY ..."
-                f" FROM STDIN statements, got {pq.ExecStatus(status).name}"
-            )
-
-    def _scroll(self, value: int, mode: str) -> None:
-        self._check_result_for_fetch()
-        assert self.pgresult
-        if mode == "relative":
-            newpos = self._pos + value
-        elif mode == "absolute":
-            newpos = value
-        else:
-            raise ValueError(f"bad mode: {mode}. It should be 'relative' or 'absolute'")
-        if not 0 <= newpos < self.pgresult.ntuples:
-            raise IndexError("position out of bound")
-        self._pos = newpos
-
-    def _close(self) -> None:
-        """Non-blocking part of closing. Common to sync/async."""
-        # Don't reset the query because it may be useful to investigate after
-        # an error.
-        self._reset(reset_query=False)
-        self._closed = True
-
-    @property
-    def _encoding(self) -> str:
-        return pgconn_encoding(self._pgconn)
-
-
 class Cursor(BaseCursor["Connection[Any]", Row]):
     __module__ = "psycopg"
     __slots__ = ()
index 58fce6420a7fef0bc1d04d39a60c616c0eea81a4..5289bb62a492840119ee336ec5e95c0e9ec83bbf 100644 (file)
@@ -1,12 +1,12 @@
 """
-psycopg async cursor objects
+Psycopg AsyncCursor object
 """
 
 # Copyright (C) 2020 The Psycopg Team
 
 from types import TracebackType
-from typing import Any, AsyncIterator, Iterable, List
-from typing import Optional, Type, TypeVar, TYPE_CHECKING, overload
+from typing import Any, AsyncIterator, Iterable, List, Optional, Type, TypeVar
+from typing import TYPE_CHECKING, overload
 from contextlib import asynccontextmanager
 
 from . import pq
@@ -14,8 +14,8 @@ from . import errors as e
 from .abc import Query, Params
 from .copy import AsyncCopy, AsyncWriter as AsyncCopyWriter
 from .rows import Row, RowMaker, AsyncRowFactory
-from .cursor import BaseCursor
 from ._pipeline import Pipeline
+from ._cursor_base import BaseCursor
 
 if TYPE_CHECKING:
     from .connection_async import AsyncConnection
index 7fc4090bf34685777329ef83cf719690a26121d7..d0984da7e0f5947a62a0cd48ff48dae312aaea3a 100644 (file)
@@ -10,9 +10,10 @@ from .abc import ConnectionType, Query, Params
 from .sql import Composable
 from .rows import Row
 from ._enums import PyFormat
-from .cursor import BaseCursor, Cursor
+from .cursor import Cursor
 from .cursor_async import AsyncCursor
 from ._queries import PostgresQuery
+from ._cursor_base import BaseCursor
 
 if TYPE_CHECKING:
     from typing import Any  # noqa: F401
index 5655a43dd802b84f76b18a51658cb7568afa1fb7..f9b78e5e2850a8b312b05adcb314b5c31745245d 100644 (file)
@@ -15,7 +15,8 @@ from . import errors as e
 from ._encodings import _as_python_identifier
 
 if TYPE_CHECKING:
-    from .cursor import BaseCursor, Cursor
+    from .cursor import Cursor
+    from ._cursor_base import BaseCursor
     from .cursor_async import AsyncCursor
     from psycopg.pq.abc import PGresult
 
index 7a86e599d45c6bdb225ca311091c36a351b8b3e0..eada346f9eefa08d36bdadc25bdf2eac52f2dc34 100644 (file)
@@ -13,8 +13,9 @@ from . import sql
 from . import errors as e
 from .abc import ConnectionType, Query, Params, PQGen
 from .rows import Row, RowFactory, AsyncRowFactory
-from .cursor import BaseCursor, Cursor
+from .cursor import Cursor
 from .generators import execute
+from ._cursor_base import BaseCursor
 from .cursor_async import AsyncCursor
 
 if TYPE_CHECKING: