]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor: move BaseConnetion in its own module
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 2 Sep 2023 16:31:18 +0000 (17:31 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 11 Oct 2023 21:45:38 +0000 (23:45 +0200)
This commit is brought you by: flight delayed 6 hours for technical failure.

22 files changed:
psycopg/psycopg/__init__.py
psycopg/psycopg/_adapters_map.py
psycopg/psycopg/_connection_base.py [new file with mode: 0644]
psycopg/psycopg/_encodings.py
psycopg/psycopg/_pipeline.py
psycopg/psycopg/_py_transformer.py
psycopg/psycopg/_typeinfo.py
psycopg/psycopg/abc.py
psycopg/psycopg/adapt.py
psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py
psycopg/psycopg/types/composite.py
psycopg/psycopg/types/datetime.py
psycopg/psycopg/types/enum.py
psycopg/psycopg/types/multirange.py
psycopg/psycopg/types/range.py
psycopg_c/psycopg_c/_psycopg.pyi
psycopg_pool/psycopg_pool/base.py
tests/test_connection.py
tests/test_connection_async.py
tests/test_module.py
tests/test_psycopg_dbapi20.py

index ff7d398a1d0494c424e6845bb7179cec4f76ef18..abeb0793042bbd9d8b65e1b021f21d5fd2afd832 100644 (file)
@@ -19,12 +19,13 @@ from .errors import InternalError, ProgrammingError, NotSupportedError
 from ._column import Column
 from .conninfo import ConnectionInfo
 from ._pipeline import Pipeline, AsyncPipeline
-from .connection import BaseConnection, Connection, Notify
+from .connection import Connection
 from .transaction import Rollback, Transaction, AsyncTransaction
 from .cursor_async import AsyncCursor
 from .server_cursor import AsyncServerCursor, ServerCursor
 from .client_cursor import AsyncClientCursor, ClientCursor
 from .raw_cursor import AsyncRawCursor, RawCursor
+from ._connection_base import BaseConnection, Notify
 from .connection_async import AsyncConnection
 
 from . import dbapi20
index 1c8981f58fa4e287f665a9df4fd861cb9617b351..fae5cb54521169f8ff729d4960e180e1a5d3eb1a 100644 (file)
@@ -15,7 +15,7 @@ from ._cmodule import _psycopg
 from ._typeinfo import TypesRegistry
 
 if TYPE_CHECKING:
-    from .connection import BaseConnection
+    from ._connection_base import BaseConnection
 
 RV = TypeVar("RV")
 
diff --git a/psycopg/psycopg/_connection_base.py b/psycopg/psycopg/_connection_base.py
new file mode 100644 (file)
index 0000000..fb8db5d
--- /dev/null
@@ -0,0 +1,648 @@
+"""
+psycopg connection objects
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import logging
+from typing import Callable, Generic
+from typing import List, NamedTuple, Optional, Type, TypeVar, Tuple, Union
+from typing import TYPE_CHECKING
+from weakref import ref, ReferenceType
+from warnings import warn
+from functools import partial
+from typing_extensions import TypeAlias
+
+from . import pq
+from . import errors as e
+from . import postgres
+from . import generators
+from .abc import ConnectionType, PQGen, PQGenConn, Query
+from .sql import Composable, SQL
+from ._tpc import Xid
+from .rows import Row
+from .adapt import AdaptersMap
+from ._enums import IsolationLevel
+from ._compat import LiteralString
+from .pq.misc import connection_summary
+from .conninfo import ConnectionInfo
+from ._pipeline import BasePipeline
+from ._encodings import pgconn_encoding
+from ._preparing import PrepareManager
+
+if TYPE_CHECKING:
+    from .pq.abc import PGconn, PGresult
+    from psycopg_pool.base import BasePool
+
+# Row Type variable for Cursor (when it needs to be distinguished from the
+# connection's one)
+CursorRow = TypeVar("CursorRow")
+
+TEXT = pq.Format.TEXT
+BINARY = pq.Format.BINARY
+
+OK = pq.ConnStatus.OK
+BAD = pq.ConnStatus.BAD
+
+COMMAND_OK = pq.ExecStatus.COMMAND_OK
+TUPLES_OK = pq.ExecStatus.TUPLES_OK
+FATAL_ERROR = pq.ExecStatus.FATAL_ERROR
+
+IDLE = pq.TransactionStatus.IDLE
+INTRANS = pq.TransactionStatus.INTRANS
+
+logger = logging.getLogger("psycopg")
+
+
+class Notify(NamedTuple):
+    """An asynchronous notification received from the database."""
+
+    channel: str
+    """The name of the channel on which the notification was received."""
+
+    payload: str
+    """The message attached to the notification."""
+
+    pid: int
+    """The PID of the backend process which sent the notification."""
+
+
+Notify.__module__ = "psycopg"
+
+NoticeHandler: TypeAlias = Callable[[e.Diagnostic], None]
+NotifyHandler: TypeAlias = Callable[[Notify], None]
+
+
+class BaseConnection(Generic[Row]):
+    """
+    Base class for different types of connections.
+
+    Share common functionalities such as access to the wrapped PGconn, but
+    allow different interfaces (sync/async).
+    """
+
+    # DBAPI2 exposed exceptions
+    Warning = e.Warning
+    Error = e.Error
+    InterfaceError = e.InterfaceError
+    DatabaseError = e.DatabaseError
+    DataError = e.DataError
+    OperationalError = e.OperationalError
+    IntegrityError = e.IntegrityError
+    InternalError = e.InternalError
+    ProgrammingError = e.ProgrammingError
+    NotSupportedError = e.NotSupportedError
+
+    # Enums useful for the connection
+    ConnStatus = pq.ConnStatus
+    TransactionStatus = pq.TransactionStatus
+
+    def __init__(self, pgconn: "PGconn"):
+        self.pgconn = pgconn
+        self._autocommit = False
+
+        # None, but set to a copy of the global adapters map as soon as requested.
+        self._adapters: Optional[AdaptersMap] = None
+
+        self._notice_handlers: List[NoticeHandler] = []
+        self._notify_handlers: List[NotifyHandler] = []
+
+        # Number of transaction blocks currently entered
+        self._num_transactions = 0
+
+        self._closed = False  # closed by an explicit close()
+        self._prepared: PrepareManager = PrepareManager()
+        self._tpc: Optional[Tuple[Xid, bool]] = None  # xid, prepared
+
+        wself = ref(self)
+        pgconn.notice_handler = partial(BaseConnection._notice_handler, wself)
+        pgconn.notify_handler = partial(BaseConnection._notify_handler, wself)
+
+        # Attribute is only set if the connection is from a pool so we can tell
+        # apart a connection in the pool too (when _pool = None)
+        self._pool: Optional["BasePool"]
+
+        self._pipeline: Optional[BasePipeline] = None
+
+        # Time after which the connection should be closed
+        self._expire_at: float
+
+        self._isolation_level: Optional[IsolationLevel] = None
+        self._read_only: Optional[bool] = None
+        self._deferrable: Optional[bool] = None
+        self._begin_statement = b""
+
+    def __del__(self) -> None:
+        # If fails on connection we might not have this attribute yet
+        if not hasattr(self, "pgconn"):
+            return
+
+        # Connection correctly closed
+        if self.closed:
+            return
+
+        # Connection in a pool so terminating with the program is normal
+        if hasattr(self, "_pool"):
+            return
+
+        warn(
+            f"connection {self} was deleted while still open."
+            " Please use 'with' or '.close()' to close the connection",
+            ResourceWarning,
+        )
+
+    def __repr__(self) -> str:
+        cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
+        info = connection_summary(self.pgconn)
+        return f"<{cls} {info} at 0x{id(self):x}>"
+
+    @property
+    def closed(self) -> bool:
+        """`!True` if the connection is closed."""
+        return self.pgconn.status == BAD
+
+    @property
+    def broken(self) -> bool:
+        """
+        `!True` if the connection was interrupted.
+
+        A broken connection is always `closed`, but wasn't closed in a clean
+        way, such as using `close()` or a `!with` block.
+        """
+        return self.pgconn.status == BAD and not self._closed
+
+    @property
+    def autocommit(self) -> bool:
+        """The autocommit state of the connection."""
+        return self._autocommit
+
+    @autocommit.setter
+    def autocommit(self, value: bool) -> None:
+        self._set_autocommit(value)
+
+    def _set_autocommit(self, value: bool) -> None:
+        raise NotImplementedError
+
+    def _set_autocommit_gen(self, value: bool) -> PQGen[None]:
+        yield from self._check_intrans_gen("autocommit")
+        self._autocommit = bool(value)
+
+    @property
+    def isolation_level(self) -> Optional[IsolationLevel]:
+        """
+        The isolation level of the new transactions started on the connection.
+        """
+        return self._isolation_level
+
+    @isolation_level.setter
+    def isolation_level(self, value: Optional[IsolationLevel]) -> None:
+        self._set_isolation_level(value)
+
+    def _set_isolation_level(self, value: Optional[IsolationLevel]) -> None:
+        raise NotImplementedError
+
+    def _set_isolation_level_gen(self, value: Optional[IsolationLevel]) -> PQGen[None]:
+        yield from self._check_intrans_gen("isolation_level")
+        self._isolation_level = IsolationLevel(value) if value is not None else None
+        self._begin_statement = b""
+
+    @property
+    def read_only(self) -> Optional[bool]:
+        """
+        The read-only state of the new transactions started on the connection.
+        """
+        return self._read_only
+
+    @read_only.setter
+    def read_only(self, value: Optional[bool]) -> None:
+        self._set_read_only(value)
+
+    def _set_read_only(self, value: Optional[bool]) -> None:
+        raise NotImplementedError
+
+    def _set_read_only_gen(self, value: Optional[bool]) -> PQGen[None]:
+        yield from self._check_intrans_gen("read_only")
+        self._read_only = bool(value) if value is not None else None
+        self._begin_statement = b""
+
+    @property
+    def deferrable(self) -> Optional[bool]:
+        """
+        The deferrable state of the new transactions started on the connection.
+        """
+        return self._deferrable
+
+    @deferrable.setter
+    def deferrable(self, value: Optional[bool]) -> None:
+        self._set_deferrable(value)
+
+    def _set_deferrable(self, value: Optional[bool]) -> None:
+        raise NotImplementedError
+
+    def _set_deferrable_gen(self, value: Optional[bool]) -> PQGen[None]:
+        yield from self._check_intrans_gen("deferrable")
+        self._deferrable = bool(value) if value is not None else None
+        self._begin_statement = b""
+
+    def _check_intrans_gen(self, attribute: str) -> PQGen[None]:
+        # Raise an exception if we are in a transaction
+        status = self.pgconn.transaction_status
+        if status == IDLE and self._pipeline:
+            yield from self._pipeline._sync_gen()
+            status = self.pgconn.transaction_status
+        if status != IDLE:
+            if self._num_transactions:
+                raise e.ProgrammingError(
+                    f"can't change {attribute!r} now: "
+                    "connection.transaction() context in progress"
+                )
+            else:
+                raise e.ProgrammingError(
+                    f"can't change {attribute!r} now: "
+                    "connection in transaction status "
+                    f"{pq.TransactionStatus(status).name}"
+                )
+
+    @property
+    def info(self) -> ConnectionInfo:
+        """A `ConnectionInfo` attribute to inspect connection properties."""
+        return ConnectionInfo(self.pgconn)
+
+    @property
+    def adapters(self) -> AdaptersMap:
+        if not self._adapters:
+            self._adapters = AdaptersMap(postgres.adapters)
+
+        return self._adapters
+
+    @property
+    def connection(self) -> "BaseConnection[Row]":
+        # implement the AdaptContext protocol
+        return self
+
+    def fileno(self) -> int:
+        """Return the file descriptor of the connection.
+
+        This function allows to use the connection as file-like object in
+        functions waiting for readiness, such as the ones defined in the
+        `selectors` module.
+        """
+        return self.pgconn.socket
+
+    def cancel(self) -> None:
+        """Cancel the current operation on the connection."""
+        # No-op if the connection is closed
+        # this allows to use the method as callback handler without caring
+        # about its life.
+        if self.closed:
+            return
+
+        if self._tpc and self._tpc[1]:
+            raise e.ProgrammingError(
+                "cancel() cannot be used with a prepared two-phase transaction"
+            )
+
+        self._try_cancel(self.pgconn)
+
+    @classmethod
+    def _try_cancel(cls, pgconn: "PGconn") -> None:
+        try:
+            # Can fail if the connection is closed
+            c = pgconn.get_cancel()
+        except Exception as ex:
+            logger.warning("couldn't try to cancel query: %s", ex)
+        else:
+            c.cancel()
+
+    def add_notice_handler(self, callback: NoticeHandler) -> None:
+        """
+        Register a callable to be invoked when a notice message is received.
+
+        :param callback: the callback to call upon message received.
+        :type callback: Callable[[~psycopg.errors.Diagnostic], None]
+        """
+        self._notice_handlers.append(callback)
+
+    def remove_notice_handler(self, callback: NoticeHandler) -> None:
+        """
+        Unregister a notice message callable previously registered.
+
+        :param callback: the callback to remove.
+        :type callback: Callable[[~psycopg.errors.Diagnostic], None]
+        """
+        self._notice_handlers.remove(callback)
+
+    @staticmethod
+    def _notice_handler(
+        wself: "ReferenceType[BaseConnection[Row]]", res: "PGresult"
+    ) -> None:
+        self = wself()
+        if not (self and self._notice_handlers):
+            return
+
+        diag = e.Diagnostic(res, pgconn_encoding(self.pgconn))
+        for cb in self._notice_handlers:
+            try:
+                cb(diag)
+            except Exception as ex:
+                logger.exception("error processing notice callback '%s': %s", cb, ex)
+
+    def add_notify_handler(self, callback: NotifyHandler) -> None:
+        """
+        Register a callable to be invoked whenever a notification is received.
+
+        :param callback: the callback to call upon notification received.
+        :type callback: Callable[[~psycopg.Notify], None]
+        """
+        self._notify_handlers.append(callback)
+
+    def remove_notify_handler(self, callback: NotifyHandler) -> None:
+        """
+        Unregister a notification callable previously registered.
+
+        :param callback: the callback to remove.
+        :type callback: Callable[[~psycopg.Notify], None]
+        """
+        self._notify_handlers.remove(callback)
+
+    @staticmethod
+    def _notify_handler(
+        wself: "ReferenceType[BaseConnection[Row]]", pgn: pq.PGnotify
+    ) -> None:
+        self = wself()
+        if not (self and self._notify_handlers):
+            return
+
+        enc = pgconn_encoding(self.pgconn)
+        n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid)
+        for cb in self._notify_handlers:
+            cb(n)
+
+    @property
+    def prepare_threshold(self) -> Optional[int]:
+        """
+        Number of times a query is executed before it is prepared.
+
+        - If it is set to 0, every query is prepared the first time it is
+          executed.
+        - If it is set to `!None`, prepared statements are disabled on the
+          connection.
+
+        Default value: 5
+        """
+        return self._prepared.prepare_threshold
+
+    @prepare_threshold.setter
+    def prepare_threshold(self, value: Optional[int]) -> None:
+        self._prepared.prepare_threshold = value
+
+    @property
+    def prepared_max(self) -> int:
+        """
+        Maximum number of prepared statements on the connection.
+
+        Default value: 100
+        """
+        return self._prepared.prepared_max
+
+    @prepared_max.setter
+    def prepared_max(self, value: int) -> None:
+        self._prepared.prepared_max = value
+
+    # Generators to perform high-level operations on the connection
+    #
+    # These operations are expressed in terms of non-blocking generators
+    # and the task of waiting when needed (when the generators yield) is left
+    # to the connections subclass, which might wait either in blocking mode
+    # or through asyncio.
+    #
+    # All these generators assume exclusive access to the connection: subclasses
+    # should have a lock and hold it before calling and consuming them.
+
+    @classmethod
+    def _connect_gen(
+        cls: Type[ConnectionType],
+        conninfo: str = "",
+        *,
+        autocommit: bool = False,
+    ) -> PQGenConn[ConnectionType]:
+        """Generator to connect to the database and create a new instance."""
+        pgconn = yield from generators.connect(conninfo)
+        conn = cls(pgconn)
+        conn._autocommit = bool(autocommit)
+        return conn
+
+    def _exec_command(
+        self, command: Query, result_format: pq.Format = TEXT
+    ) -> PQGen[Optional["PGresult"]]:
+        """
+        Generator to send a command and receive the result to the backend.
+
+        Only used to implement internal commands such as "commit", with eventual
+        arguments bound client-side. The cursor can do more complex stuff.
+        """
+        self._check_connection_ok()
+
+        if isinstance(command, str):
+            command = command.encode(pgconn_encoding(self.pgconn))
+        elif isinstance(command, Composable):
+            command = command.as_bytes(self)
+
+        if self._pipeline:
+            cmd = partial(
+                self.pgconn.send_query_params,
+                command,
+                None,
+                result_format=result_format,
+            )
+            self._pipeline.command_queue.append(cmd)
+            self._pipeline.result_queue.append(None)
+            return None
+
+        self.pgconn.send_query_params(command, None, result_format=result_format)
+
+        result = (yield from generators.execute(self.pgconn))[-1]
+        if result.status != COMMAND_OK and result.status != TUPLES_OK:
+            if result.status == FATAL_ERROR:
+                raise e.error_from_result(result, encoding=pgconn_encoding(self.pgconn))
+            else:
+                raise e.InterfaceError(
+                    f"unexpected result {pq.ExecStatus(result.status).name}"
+                    f" from command {command.decode()!r}"
+                )
+        return result
+
+    def _check_connection_ok(self) -> None:
+        if self.pgconn.status == OK:
+            return
+
+        if self.pgconn.status == BAD:
+            raise e.OperationalError("the connection is closed")
+        raise e.InterfaceError(
+            "cannot execute operations: the connection is"
+            f" in status {self.pgconn.status}"
+        )
+
+    def _start_query(self) -> PQGen[None]:
+        """Generator to start a transaction if necessary."""
+        if self._autocommit:
+            return
+
+        if self.pgconn.transaction_status != IDLE:
+            return
+
+        yield from self._exec_command(self._get_tx_start_command())
+        if self._pipeline:
+            yield from self._pipeline._sync_gen()
+
+    def _get_tx_start_command(self) -> bytes:
+        if self._begin_statement:
+            return self._begin_statement
+
+        parts = [b"BEGIN"]
+
+        if self.isolation_level is not None:
+            val = IsolationLevel(self.isolation_level)
+            parts.append(b"ISOLATION LEVEL")
+            parts.append(val.name.replace("_", " ").encode())
+
+        if self.read_only is not None:
+            parts.append(b"READ ONLY" if self.read_only else b"READ WRITE")
+
+        if self.deferrable is not None:
+            parts.append(b"DEFERRABLE" if self.deferrable else b"NOT DEFERRABLE")
+
+        self._begin_statement = b" ".join(parts)
+        return self._begin_statement
+
+    def _commit_gen(self) -> PQGen[None]:
+        """Generator implementing `Connection.commit()`."""
+        if self._num_transactions:
+            raise e.ProgrammingError(
+                "Explicit commit() forbidden within a Transaction "
+                "context. (Transaction will be automatically committed "
+                "on successful exit from context.)"
+            )
+        if self._tpc:
+            raise e.ProgrammingError(
+                "commit() cannot be used during a two-phase transaction"
+            )
+        if self.pgconn.transaction_status == IDLE:
+            return
+
+        yield from self._exec_command(b"COMMIT")
+
+        if self._pipeline:
+            yield from self._pipeline._sync_gen()
+
+    def _rollback_gen(self) -> PQGen[None]:
+        """Generator implementing `Connection.rollback()`."""
+        if self._num_transactions:
+            raise e.ProgrammingError(
+                "Explicit rollback() forbidden within a Transaction "
+                "context. (Either raise Rollback() or allow "
+                "an exception to propagate out of the context.)"
+            )
+        if self._tpc:
+            raise e.ProgrammingError(
+                "rollback() cannot be used during a two-phase transaction"
+            )
+
+        # Get out of a "pipeline aborted" state
+        if self._pipeline:
+            yield from self._pipeline._sync_gen()
+
+        if self.pgconn.transaction_status == IDLE:
+            return
+
+        yield from self._exec_command(b"ROLLBACK")
+        self._prepared.clear()
+        for cmd in self._prepared.get_maintenance_commands():
+            yield from self._exec_command(cmd)
+
+        if self._pipeline:
+            yield from self._pipeline._sync_gen()
+
+    def xid(self, format_id: int, gtrid: str, bqual: str) -> Xid:
+        """
+        Returns a `Xid` to pass to the `!tpc_*()` methods of this connection.
+
+        The argument types and constraints are explained in
+        :ref:`two-phase-commit`.
+
+        The values passed to the method will be available on the returned
+        object as the members `~Xid.format_id`, `~Xid.gtrid`, `~Xid.bqual`.
+        """
+        self._check_tpc()
+        return Xid.from_parts(format_id, gtrid, bqual)
+
+    def _tpc_begin_gen(self, xid: Union[Xid, str]) -> PQGen[None]:
+        self._check_tpc()
+
+        if not isinstance(xid, Xid):
+            xid = Xid.from_string(xid)
+
+        if self.pgconn.transaction_status != IDLE:
+            raise e.ProgrammingError(
+                "can't start two-phase transaction: connection in status"
+                f" {pq.TransactionStatus(self.pgconn.transaction_status).name}"
+            )
+
+        if self._autocommit:
+            raise e.ProgrammingError(
+                "can't use two-phase transactions in autocommit mode"
+            )
+
+        self._tpc = (xid, False)
+        yield from self._exec_command(self._get_tx_start_command())
+
+    def _tpc_prepare_gen(self) -> PQGen[None]:
+        if not self._tpc:
+            raise e.ProgrammingError(
+                "'tpc_prepare()' must be called inside a two-phase transaction"
+            )
+        if self._tpc[1]:
+            raise e.ProgrammingError(
+                "'tpc_prepare()' cannot be used during a prepared two-phase transaction"
+            )
+        xid = self._tpc[0]
+        self._tpc = (xid, True)
+        yield from self._exec_command(SQL("PREPARE TRANSACTION {}").format(str(xid)))
+        if self._pipeline:
+            yield from self._pipeline._sync_gen()
+
+    def _tpc_finish_gen(
+        self, action: LiteralString, xid: Union[Xid, str, None]
+    ) -> PQGen[None]:
+        fname = f"tpc_{action.lower()}()"
+        if xid is None:
+            if not self._tpc:
+                raise e.ProgrammingError(
+                    f"{fname} without xid must must be"
+                    " called inside a two-phase transaction"
+                )
+            xid = self._tpc[0]
+        else:
+            if self._tpc:
+                raise e.ProgrammingError(
+                    f"{fname} with xid must must be called"
+                    " outside a two-phase transaction"
+                )
+            if not isinstance(xid, Xid):
+                xid = Xid.from_string(xid)
+
+        if self._tpc and not self._tpc[1]:
+            meth: Callable[[], PQGen[None]]
+            meth = getattr(self, f"_{action.lower()}_gen")
+            self._tpc = None
+            yield from meth()
+        else:
+            yield from self._exec_command(
+                SQL("{} PREPARED {}").format(SQL(action), str(xid))
+            )
+            self._tpc = None
+
+    def _check_tpc(self) -> None:
+        """Raise NotSupportedError if TPC is not supported."""
+        # TPC supported on every supported PostgreSQL version.
+        pass
index 876acb975df7d57362b093cfc7e71e02bc9da985..3f14c260869993bdf81534771722870bf57fadd9 100644 (file)
@@ -15,7 +15,7 @@ from ._compat import cache
 
 if TYPE_CHECKING:
     from .pq.abc import PGconn
-    from .connection import BaseConnection
+    from ._connection_base import BaseConnection
 
 OK = ConnStatus.OK
 
index 2223f491ac6bdd79e7a6543fc1138648c9aa109a..ff7228eeea3f814ee5f9a3f02f918aef5dd43941 100644 (file)
@@ -21,7 +21,8 @@ from .generators import pipeline_communicate, fetch_many, send
 if TYPE_CHECKING:
     from .pq.abc import PGresult
     from ._cursor_base import BaseCursor
-    from .connection import BaseConnection, Connection
+    from .connection import Connection
+    from ._connection_base import BaseConnection
     from .connection_async import AsyncConnection
 
 
index 0438725c3b9e3ab9568202493cf74d710e215d51..7174893c5e804c65182ed69ff253b615a367432c 100644 (file)
@@ -25,7 +25,7 @@ from ._encodings import pgconn_encoding
 if TYPE_CHECKING:
     from .adapt import AdaptersMap
     from .pq.abc import PGresult
-    from .connection import BaseConnection
+    from ._connection_base import BaseConnection
 
 DumperCache: TypeAlias = Dict[DumperKey, abc.Dumper]
 OidDumperCache: TypeAlias = Dict[int, abc.Dumper]
index dcbb2c0950821f7f4229b34ae018c746d1eefa1d..141ed04271682113890d3cf55c36f41f21b5d1aa 100644 (file)
@@ -18,8 +18,9 @@ from .rows import dict_row
 from ._encodings import conn_encoding
 
 if TYPE_CHECKING:
-    from .connection import BaseConnection, Connection
+    from .connection import Connection
     from .connection_async import AsyncConnection
+    from ._connection_base import BaseConnection
 
 T = TypeVar("T", bound="TypeInfo")
 RegistryKey: TypeAlias = Union[str, int, Tuple[type, int]]
index 0ee2037ac5d5c7805c33ff30c3b7e4de4aff19af..8edcdc10711b8e35d06e3227c23ac8319399df78 100644 (file)
@@ -18,8 +18,8 @@ if TYPE_CHECKING:
     from .rows import Row, RowMaker
     from .pq.abc import PGresult
     from .waiting import Wait, Ready
-    from .connection import BaseConnection
     from ._adapters_map import AdaptersMap
+    from ._connection_base import BaseConnection
 
 NoneType: type = type(None)
 
index f09bd6a84e3d3c3443b966a32a746a93453bb6ed..31a7104296d660bd8f3b834b19afa70de78056f0 100644 (file)
@@ -15,7 +15,7 @@ from ._transformer import Transformer as Transformer
 from ._adapters_map import AdaptersMap as AdaptersMap  # noqa: F401
 
 if TYPE_CHECKING:
-    from .connection import BaseConnection
+    from ._connection_base import BaseConnection
 
 Buffer = abc.Buffer
 
index a6571d5d7b683321ede9c5c98e3f90c08374a470..73724f616f23f1ea2abcca0e8c36e5517a2779eb 100644 (file)
@@ -7,656 +7,41 @@ psycopg connection objects
 import logging
 import threading
 from types import TracebackType
-from typing import Any, Callable, cast, Dict, Generator, Generic, Iterator
-from typing import List, NamedTuple, Optional, Type, TypeVar, Tuple, Union
+from typing import Any, cast, Dict, Generator, Iterator
+from typing import List, Optional, Type, TypeVar, Union
 from typing import overload, TYPE_CHECKING
-from weakref import ref, ReferenceType
-from warnings import warn
-from functools import partial
 from contextlib import contextmanager
-from typing_extensions import TypeAlias
 
 from . import pq
 from . import errors as e
 from . import waiting
-from . import postgres
-from .abc import AdaptContext, ConnectionType, Params, Query, RV
+from .abc import AdaptContext, Params, Query, RV
 from .abc import PQGen, PQGenConn
-from .sql import Composable, SQL
 from ._tpc import Xid
 from .rows import Row, RowFactory, tuple_row, TupleRow, args_row
 from .adapt import AdaptersMap
 from ._enums import IsolationLevel
 from .cursor import Cursor
-from ._compat import LiteralString
-from .pq.misc import connection_summary
-from .conninfo import make_conninfo, conninfo_to_dict, ConnectionInfo
-from ._pipeline import BasePipeline, Pipeline
-from .generators import notifies, connect, execute
+from .conninfo import make_conninfo, conninfo_to_dict
+from ._pipeline import Pipeline
+from .generators import notifies
 from ._encodings import pgconn_encoding
-from ._preparing import PrepareManager
 from .transaction import Transaction
 from .server_cursor import ServerCursor
+from ._connection_base import BaseConnection, CursorRow, Notify
 
 if TYPE_CHECKING:
-    from .pq.abc import PGconn, PGresult
-    from psycopg_pool.base import BasePool
-
-
-# Row Type variable for Cursor (when it needs to be distinguished from the
-# connection's one)
-CursorRow = TypeVar("CursorRow")
+    from .pq.abc import PGconn
 
 TEXT = pq.Format.TEXT
 BINARY = pq.Format.BINARY
 
-OK = pq.ConnStatus.OK
-BAD = pq.ConnStatus.BAD
-
-COMMAND_OK = pq.ExecStatus.COMMAND_OK
-TUPLES_OK = pq.ExecStatus.TUPLES_OK
-FATAL_ERROR = pq.ExecStatus.FATAL_ERROR
-
 IDLE = pq.TransactionStatus.IDLE
 INTRANS = pq.TransactionStatus.INTRANS
 
 logger = logging.getLogger("psycopg")
 
 
-class Notify(NamedTuple):
-    """An asynchronous notification received from the database."""
-
-    channel: str
-    """The name of the channel on which the notification was received."""
-
-    payload: str
-    """The message attached to the notification."""
-
-    pid: int
-    """The PID of the backend process which sent the notification."""
-
-
-Notify.__module__ = "psycopg"
-
-NoticeHandler: TypeAlias = Callable[[e.Diagnostic], None]
-NotifyHandler: TypeAlias = Callable[[Notify], None]
-
-
-class BaseConnection(Generic[Row]):
-    """
-    Base class for different types of connections.
-
-    Share common functionalities such as access to the wrapped PGconn, but
-    allow different interfaces (sync/async).
-    """
-
-    # DBAPI2 exposed exceptions
-    Warning = e.Warning
-    Error = e.Error
-    InterfaceError = e.InterfaceError
-    DatabaseError = e.DatabaseError
-    DataError = e.DataError
-    OperationalError = e.OperationalError
-    IntegrityError = e.IntegrityError
-    InternalError = e.InternalError
-    ProgrammingError = e.ProgrammingError
-    NotSupportedError = e.NotSupportedError
-
-    # Enums useful for the connection
-    ConnStatus = pq.ConnStatus
-    TransactionStatus = pq.TransactionStatus
-
-    def __init__(self, pgconn: "PGconn"):
-        self.pgconn = pgconn
-        self._autocommit = False
-
-        # None, but set to a copy of the global adapters map as soon as requested.
-        self._adapters: Optional[AdaptersMap] = None
-
-        self._notice_handlers: List[NoticeHandler] = []
-        self._notify_handlers: List[NotifyHandler] = []
-
-        # Number of transaction blocks currently entered
-        self._num_transactions = 0
-
-        self._closed = False  # closed by an explicit close()
-        self._prepared: PrepareManager = PrepareManager()
-        self._tpc: Optional[Tuple[Xid, bool]] = None  # xid, prepared
-
-        wself = ref(self)
-        pgconn.notice_handler = partial(BaseConnection._notice_handler, wself)
-        pgconn.notify_handler = partial(BaseConnection._notify_handler, wself)
-
-        # Attribute is only set if the connection is from a pool so we can tell
-        # apart a connection in the pool too (when _pool = None)
-        self._pool: Optional["BasePool"]
-
-        self._pipeline: Optional[BasePipeline] = None
-
-        # Time after which the connection should be closed
-        self._expire_at: float
-
-        self._isolation_level: Optional[IsolationLevel] = None
-        self._read_only: Optional[bool] = None
-        self._deferrable: Optional[bool] = None
-        self._begin_statement = b""
-
-    def __del__(self) -> None:
-        # If fails on connection we might not have this attribute yet
-        if not hasattr(self, "pgconn"):
-            return
-
-        # Connection correctly closed
-        if self.closed:
-            return
-
-        # Connection in a pool so terminating with the program is normal
-        if hasattr(self, "_pool"):
-            return
-
-        warn(
-            f"connection {self} was deleted while still open."
-            " Please use 'with' or '.close()' to close the connection",
-            ResourceWarning,
-        )
-
-    def __repr__(self) -> str:
-        cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
-        info = connection_summary(self.pgconn)
-        return f"<{cls} {info} at 0x{id(self):x}>"
-
-    @property
-    def closed(self) -> bool:
-        """`!True` if the connection is closed."""
-        return self.pgconn.status == BAD
-
-    @property
-    def broken(self) -> bool:
-        """
-        `!True` if the connection was interrupted.
-
-        A broken connection is always `closed`, but wasn't closed in a clean
-        way, such as using `close()` or a `!with` block.
-        """
-        return self.pgconn.status == BAD and not self._closed
-
-    @property
-    def autocommit(self) -> bool:
-        """The autocommit state of the connection."""
-        return self._autocommit
-
-    @autocommit.setter
-    def autocommit(self, value: bool) -> None:
-        self._set_autocommit(value)
-
-    def _set_autocommit(self, value: bool) -> None:
-        raise NotImplementedError
-
-    def _set_autocommit_gen(self, value: bool) -> PQGen[None]:
-        yield from self._check_intrans_gen("autocommit")
-        self._autocommit = bool(value)
-
-    @property
-    def isolation_level(self) -> Optional[IsolationLevel]:
-        """
-        The isolation level of the new transactions started on the connection.
-        """
-        return self._isolation_level
-
-    @isolation_level.setter
-    def isolation_level(self, value: Optional[IsolationLevel]) -> None:
-        self._set_isolation_level(value)
-
-    def _set_isolation_level(self, value: Optional[IsolationLevel]) -> None:
-        raise NotImplementedError
-
-    def _set_isolation_level_gen(self, value: Optional[IsolationLevel]) -> PQGen[None]:
-        yield from self._check_intrans_gen("isolation_level")
-        self._isolation_level = IsolationLevel(value) if value is not None else None
-        self._begin_statement = b""
-
-    @property
-    def read_only(self) -> Optional[bool]:
-        """
-        The read-only state of the new transactions started on the connection.
-        """
-        return self._read_only
-
-    @read_only.setter
-    def read_only(self, value: Optional[bool]) -> None:
-        self._set_read_only(value)
-
-    def _set_read_only(self, value: Optional[bool]) -> None:
-        raise NotImplementedError
-
-    def _set_read_only_gen(self, value: Optional[bool]) -> PQGen[None]:
-        yield from self._check_intrans_gen("read_only")
-        self._read_only = bool(value) if value is not None else None
-        self._begin_statement = b""
-
-    @property
-    def deferrable(self) -> Optional[bool]:
-        """
-        The deferrable state of the new transactions started on the connection.
-        """
-        return self._deferrable
-
-    @deferrable.setter
-    def deferrable(self, value: Optional[bool]) -> None:
-        self._set_deferrable(value)
-
-    def _set_deferrable(self, value: Optional[bool]) -> None:
-        raise NotImplementedError
-
-    def _set_deferrable_gen(self, value: Optional[bool]) -> PQGen[None]:
-        yield from self._check_intrans_gen("deferrable")
-        self._deferrable = bool(value) if value is not None else None
-        self._begin_statement = b""
-
-    def _check_intrans_gen(self, attribute: str) -> PQGen[None]:
-        # Raise an exception if we are in a transaction
-        status = self.pgconn.transaction_status
-        if status == IDLE and self._pipeline:
-            yield from self._pipeline._sync_gen()
-            status = self.pgconn.transaction_status
-        if status != IDLE:
-            if self._num_transactions:
-                raise e.ProgrammingError(
-                    f"can't change {attribute!r} now: "
-                    "connection.transaction() context in progress"
-                )
-            else:
-                raise e.ProgrammingError(
-                    f"can't change {attribute!r} now: "
-                    "connection in transaction status "
-                    f"{pq.TransactionStatus(status).name}"
-                )
-
-    @property
-    def info(self) -> ConnectionInfo:
-        """A `ConnectionInfo` attribute to inspect connection properties."""
-        return ConnectionInfo(self.pgconn)
-
-    @property
-    def adapters(self) -> AdaptersMap:
-        if not self._adapters:
-            self._adapters = AdaptersMap(postgres.adapters)
-
-        return self._adapters
-
-    @property
-    def connection(self) -> "BaseConnection[Row]":
-        # implement the AdaptContext protocol
-        return self
-
-    def fileno(self) -> int:
-        """Return the file descriptor of the connection.
-
-        This function allows to use the connection as file-like object in
-        functions waiting for readiness, such as the ones defined in the
-        `selectors` module.
-        """
-        return self.pgconn.socket
-
-    def cancel(self) -> None:
-        """Cancel the current operation on the connection."""
-        # No-op if the connection is closed
-        # this allows to use the method as callback handler without caring
-        # about its life.
-        if self.closed:
-            return
-
-        if self._tpc and self._tpc[1]:
-            raise e.ProgrammingError(
-                "cancel() cannot be used with a prepared two-phase transaction"
-            )
-
-        self._try_cancel(self.pgconn)
-
-    @classmethod
-    def _try_cancel(cls, pgconn: "PGconn") -> None:
-        try:
-            # Can fail if the connection is closed
-            c = pgconn.get_cancel()
-        except Exception as ex:
-            logger.warning("couldn't try to cancel query: %s", ex)
-        else:
-            c.cancel()
-
-    def add_notice_handler(self, callback: NoticeHandler) -> None:
-        """
-        Register a callable to be invoked when a notice message is received.
-
-        :param callback: the callback to call upon message received.
-        :type callback: Callable[[~psycopg.errors.Diagnostic], None]
-        """
-        self._notice_handlers.append(callback)
-
-    def remove_notice_handler(self, callback: NoticeHandler) -> None:
-        """
-        Unregister a notice message callable previously registered.
-
-        :param callback: the callback to remove.
-        :type callback: Callable[[~psycopg.errors.Diagnostic], None]
-        """
-        self._notice_handlers.remove(callback)
-
-    @staticmethod
-    def _notice_handler(
-        wself: "ReferenceType[BaseConnection[Row]]", res: "PGresult"
-    ) -> None:
-        self = wself()
-        if not (self and self._notice_handlers):
-            return
-
-        diag = e.Diagnostic(res, pgconn_encoding(self.pgconn))
-        for cb in self._notice_handlers:
-            try:
-                cb(diag)
-            except Exception as ex:
-                logger.exception("error processing notice callback '%s': %s", cb, ex)
-
-    def add_notify_handler(self, callback: NotifyHandler) -> None:
-        """
-        Register a callable to be invoked whenever a notification is received.
-
-        :param callback: the callback to call upon notification received.
-        :type callback: Callable[[~psycopg.Notify], None]
-        """
-        self._notify_handlers.append(callback)
-
-    def remove_notify_handler(self, callback: NotifyHandler) -> None:
-        """
-        Unregister a notification callable previously registered.
-
-        :param callback: the callback to remove.
-        :type callback: Callable[[~psycopg.Notify], None]
-        """
-        self._notify_handlers.remove(callback)
-
-    @staticmethod
-    def _notify_handler(
-        wself: "ReferenceType[BaseConnection[Row]]", pgn: pq.PGnotify
-    ) -> None:
-        self = wself()
-        if not (self and self._notify_handlers):
-            return
-
-        enc = pgconn_encoding(self.pgconn)
-        n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid)
-        for cb in self._notify_handlers:
-            cb(n)
-
-    @property
-    def prepare_threshold(self) -> Optional[int]:
-        """
-        Number of times a query is executed before it is prepared.
-
-        - If it is set to 0, every query is prepared the first time it is
-          executed.
-        - If it is set to `!None`, prepared statements are disabled on the
-          connection.
-
-        Default value: 5
-        """
-        return self._prepared.prepare_threshold
-
-    @prepare_threshold.setter
-    def prepare_threshold(self, value: Optional[int]) -> None:
-        self._prepared.prepare_threshold = value
-
-    @property
-    def prepared_max(self) -> int:
-        """
-        Maximum number of prepared statements on the connection.
-
-        Default value: 100
-        """
-        return self._prepared.prepared_max
-
-    @prepared_max.setter
-    def prepared_max(self, value: int) -> None:
-        self._prepared.prepared_max = value
-
-    # Generators to perform high-level operations on the connection
-    #
-    # These operations are expressed in terms of non-blocking generators
-    # and the task of waiting when needed (when the generators yield) is left
-    # to the connections subclass, which might wait either in blocking mode
-    # or through asyncio.
-    #
-    # All these generators assume exclusive access to the connection: subclasses
-    # should have a lock and hold it before calling and consuming them.
-
-    @classmethod
-    def _connect_gen(
-        cls: Type[ConnectionType],
-        conninfo: str = "",
-        *,
-        autocommit: bool = False,
-    ) -> PQGenConn[ConnectionType]:
-        """Generator to connect to the database and create a new instance."""
-        pgconn = yield from connect(conninfo)
-        conn = cls(pgconn)
-        conn._autocommit = bool(autocommit)
-        return conn
-
-    def _exec_command(
-        self, command: Query, result_format: pq.Format = TEXT
-    ) -> PQGen[Optional["PGresult"]]:
-        """
-        Generator to send a command and receive the result to the backend.
-
-        Only used to implement internal commands such as "commit", with eventual
-        arguments bound client-side. The cursor can do more complex stuff.
-        """
-        self._check_connection_ok()
-
-        if isinstance(command, str):
-            command = command.encode(pgconn_encoding(self.pgconn))
-        elif isinstance(command, Composable):
-            command = command.as_bytes(self)
-
-        if self._pipeline:
-            cmd = partial(
-                self.pgconn.send_query_params,
-                command,
-                None,
-                result_format=result_format,
-            )
-            self._pipeline.command_queue.append(cmd)
-            self._pipeline.result_queue.append(None)
-            return None
-
-        self.pgconn.send_query_params(command, None, result_format=result_format)
-
-        result = (yield from execute(self.pgconn))[-1]
-        if result.status != COMMAND_OK and result.status != TUPLES_OK:
-            if result.status == FATAL_ERROR:
-                raise e.error_from_result(result, encoding=pgconn_encoding(self.pgconn))
-            else:
-                raise e.InterfaceError(
-                    f"unexpected result {pq.ExecStatus(result.status).name}"
-                    f" from command {command.decode()!r}"
-                )
-        return result
-
-    def _check_connection_ok(self) -> None:
-        if self.pgconn.status == OK:
-            return
-
-        if self.pgconn.status == BAD:
-            raise e.OperationalError("the connection is closed")
-        raise e.InterfaceError(
-            "cannot execute operations: the connection is"
-            f" in status {self.pgconn.status}"
-        )
-
-    def _start_query(self) -> PQGen[None]:
-        """Generator to start a transaction if necessary."""
-        if self._autocommit:
-            return
-
-        if self.pgconn.transaction_status != IDLE:
-            return
-
-        yield from self._exec_command(self._get_tx_start_command())
-        if self._pipeline:
-            yield from self._pipeline._sync_gen()
-
-    def _get_tx_start_command(self) -> bytes:
-        if self._begin_statement:
-            return self._begin_statement
-
-        parts = [b"BEGIN"]
-
-        if self.isolation_level is not None:
-            val = IsolationLevel(self.isolation_level)
-            parts.append(b"ISOLATION LEVEL")
-            parts.append(val.name.replace("_", " ").encode())
-
-        if self.read_only is not None:
-            parts.append(b"READ ONLY" if self.read_only else b"READ WRITE")
-
-        if self.deferrable is not None:
-            parts.append(b"DEFERRABLE" if self.deferrable else b"NOT DEFERRABLE")
-
-        self._begin_statement = b" ".join(parts)
-        return self._begin_statement
-
-    def _commit_gen(self) -> PQGen[None]:
-        """Generator implementing `Connection.commit()`."""
-        if self._num_transactions:
-            raise e.ProgrammingError(
-                "Explicit commit() forbidden within a Transaction "
-                "context. (Transaction will be automatically committed "
-                "on successful exit from context.)"
-            )
-        if self._tpc:
-            raise e.ProgrammingError(
-                "commit() cannot be used during a two-phase transaction"
-            )
-        if self.pgconn.transaction_status == IDLE:
-            return
-
-        yield from self._exec_command(b"COMMIT")
-
-        if self._pipeline:
-            yield from self._pipeline._sync_gen()
-
-    def _rollback_gen(self) -> PQGen[None]:
-        """Generator implementing `Connection.rollback()`."""
-        if self._num_transactions:
-            raise e.ProgrammingError(
-                "Explicit rollback() forbidden within a Transaction "
-                "context. (Either raise Rollback() or allow "
-                "an exception to propagate out of the context.)"
-            )
-        if self._tpc:
-            raise e.ProgrammingError(
-                "rollback() cannot be used during a two-phase transaction"
-            )
-
-        # Get out of a "pipeline aborted" state
-        if self._pipeline:
-            yield from self._pipeline._sync_gen()
-
-        if self.pgconn.transaction_status == IDLE:
-            return
-
-        yield from self._exec_command(b"ROLLBACK")
-        self._prepared.clear()
-        for cmd in self._prepared.get_maintenance_commands():
-            yield from self._exec_command(cmd)
-
-        if self._pipeline:
-            yield from self._pipeline._sync_gen()
-
-    def xid(self, format_id: int, gtrid: str, bqual: str) -> Xid:
-        """
-        Returns a `Xid` to pass to the `!tpc_*()` methods of this connection.
-
-        The argument types and constraints are explained in
-        :ref:`two-phase-commit`.
-
-        The values passed to the method will be available on the returned
-        object as the members `~Xid.format_id`, `~Xid.gtrid`, `~Xid.bqual`.
-        """
-        self._check_tpc()
-        return Xid.from_parts(format_id, gtrid, bqual)
-
-    def _tpc_begin_gen(self, xid: Union[Xid, str]) -> PQGen[None]:
-        self._check_tpc()
-
-        if not isinstance(xid, Xid):
-            xid = Xid.from_string(xid)
-
-        if self.pgconn.transaction_status != IDLE:
-            raise e.ProgrammingError(
-                "can't start two-phase transaction: connection in status"
-                f" {pq.TransactionStatus(self.pgconn.transaction_status).name}"
-            )
-
-        if self._autocommit:
-            raise e.ProgrammingError(
-                "can't use two-phase transactions in autocommit mode"
-            )
-
-        self._tpc = (xid, False)
-        yield from self._exec_command(self._get_tx_start_command())
-
-    def _tpc_prepare_gen(self) -> PQGen[None]:
-        if not self._tpc:
-            raise e.ProgrammingError(
-                "'tpc_prepare()' must be called inside a two-phase transaction"
-            )
-        if self._tpc[1]:
-            raise e.ProgrammingError(
-                "'tpc_prepare()' cannot be used during a prepared two-phase transaction"
-            )
-        xid = self._tpc[0]
-        self._tpc = (xid, True)
-        yield from self._exec_command(SQL("PREPARE TRANSACTION {}").format(str(xid)))
-        if self._pipeline:
-            yield from self._pipeline._sync_gen()
-
-    def _tpc_finish_gen(
-        self, action: LiteralString, xid: Union[Xid, str, None]
-    ) -> PQGen[None]:
-        fname = f"tpc_{action.lower()}()"
-        if xid is None:
-            if not self._tpc:
-                raise e.ProgrammingError(
-                    f"{fname} without xid must must be"
-                    " called inside a two-phase transaction"
-                )
-            xid = self._tpc[0]
-        else:
-            if self._tpc:
-                raise e.ProgrammingError(
-                    f"{fname} with xid must must be called"
-                    " outside a two-phase transaction"
-                )
-            if not isinstance(xid, Xid):
-                xid = Xid.from_string(xid)
-
-        if self._tpc and not self._tpc[1]:
-            meth: Callable[[], PQGen[None]]
-            meth = getattr(self, f"_{action.lower()}_gen")
-            self._tpc = None
-            yield from meth()
-        else:
-            yield from self._exec_command(
-                SQL("{} PREPARED {}").format(SQL(action), str(xid))
-            )
-            self._tpc = None
-
-    def _check_tpc(self) -> None:
-        """Raise NotSupportedError if TPC is not supported."""
-        # TPC supported on every supported PostgreSQL version.
-        pass
-
-
 class Connection(BaseConnection[Row]):
     """
     Wrapper for a connection to the database.
index 416d00cee9cf7da99e8433ffcf631739827c4ea0..bdae30d4596d614e4929c22f0595ccc22c1a1f6a 100644 (file)
@@ -23,11 +23,11 @@ from ._enums import IsolationLevel
 from .conninfo import make_conninfo, conninfo_to_dict, resolve_hostaddr_async
 from ._pipeline import AsyncPipeline
 from ._encodings import pgconn_encoding
-from .connection import BaseConnection, CursorRow, Notify
 from .generators import notifies
 from .transaction import AsyncTransaction
 from .cursor_async import AsyncCursor
 from .server_cursor import AsyncServerCursor
+from ._connection_base import BaseConnection, CursorRow, Notify
 
 if TYPE_CHECKING:
     from .pq.abc import PGconn
index 2824287addd2505aff8ab8cf7e16fad9153aa26e..d116273c40fdbbe9f3e0e0b6362e462b3cb8db3e 100644 (file)
@@ -22,7 +22,7 @@ from .._typeinfo import TypeInfo
 from .._encodings import _as_python_identifier
 
 if TYPE_CHECKING:
-    from ..connection import BaseConnection
+    from .._connection_base import BaseConnection
 
 _struct_oidlen = struct.Struct("!Ii")
 _pack_oidlen = cast(Callable[[int, int], bytes], _struct_oidlen.pack)
index 401b0b4c8daff4b31cdce0d660ba402f5125346b..614bfca32591df704d8fe250fbed5273367c9bf9 100644 (file)
@@ -18,7 +18,7 @@ from ..errors import InterfaceError, DataError
 from .._struct import pack_int4, pack_int8, unpack_int4, unpack_int8
 
 if TYPE_CHECKING:
-    from ..connection import BaseConnection
+    from .._connection_base import BaseConnection
 
 _struct_timetz = struct.Struct("!qi")  # microseconds, sec tz offset
 _pack_timetz = cast(Callable[[int, int], bytes], _struct_timetz.pack)
index a1eb62ebe78a6ddc7bc008325fbcb6b0798d81cd..3035214ba36c5899b9e1e18886b003254815beec 100644 (file)
@@ -17,7 +17,7 @@ from .._encodings import conn_encoding
 from .._typeinfo import TypeInfo
 
 if TYPE_CHECKING:
-    from ..connection import BaseConnection
+    from .._connection_base import BaseConnection
 
 E = TypeVar("E", bound=Enum)
 
index 8af08de207b80660d33b21ca28433f565fa1f1b3..d672f6be8d4fb7d1697ddc1402c0419b9d951475 100644 (file)
@@ -25,7 +25,7 @@ from .range import Range, T, load_range_text, load_range_binary
 from .range import dump_range_text, dump_range_binary, fail_dump
 
 if TYPE_CHECKING:
-    from ..connection import BaseConnection
+    from .._connection_base import BaseConnection
 
 
 class MultirangeInfo(TypeInfo):
index 71127919e00a90956747b0dcba43da3cf303ee86..6290ca080d84efef81c020b0e0793d539a57aa64 100644 (file)
@@ -23,7 +23,7 @@ from .._struct import pack_len, unpack_len
 from .._typeinfo import TypeInfo, TypesRegistry
 
 if TYPE_CHECKING:
-    from ..connection import BaseConnection
+    from .._connection_base import BaseConnection
 
 RANGE_EMPTY = 0x01  # range is empty
 RANGE_LB_INC = 0x02  # lower bound is inclusive
index bd7c63d91e4b7412b4a250bda55fec0d273eff25..7d456ba538d8cbcffcd86108d2ff70638e8255b6 100644 (file)
@@ -9,12 +9,10 @@ information. Will submit a bug.
 
 from typing import Any, Iterable, List, Optional, Sequence, Tuple
 
-from psycopg import pq
-from psycopg import abc
+from psycopg import pq, abc, BaseConnection
 from psycopg.rows import Row, RowMaker
 from psycopg.adapt import AdaptersMap, PyFormat
 from psycopg.pq.abc import PGconn, PGresult
-from psycopg.connection import BaseConnection
 from psycopg._compat import Deque
 
 class Transformer(abc.AdaptContext):
index 13823bccd2ac9ae8da4965b6503bafb461a6c262..b0e3081364ac501685371494c1175b556414e976 100644 (file)
@@ -14,7 +14,7 @@ from .errors import PoolClosed
 from ._compat import Counter, Deque
 
 if TYPE_CHECKING:
-    from psycopg.connection import BaseConnection
+    from psycopg._connection_base import BaseConnection
 
 
 class BasePool:
index 73ec2fba847001c9f3d614e8c2187e7fddda8737..914bfa647cdbea1dda98104d83d46e39497dc74c 100644 (file)
@@ -396,7 +396,7 @@ def test_connect_args(conn_cls, monkeypatch, setpgenv, pgconn, args, kwargs, wan
         yield
 
     setpgenv({})
-    monkeypatch.setattr(psycopg.connection, "connect", fake_connect)
+    monkeypatch.setattr(psycopg.generators, "connect", fake_connect)
     conn = conn_cls.connect(*args, **kwargs)
     assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want)
     conn.close()
@@ -415,7 +415,7 @@ def test_connect_badargs(conn_cls, monkeypatch, pgconn, args, kwargs, exctype):
         return pgconn
         yield
 
-    monkeypatch.setattr(psycopg.connection, "connect", fake_connect)
+    monkeypatch.setattr(psycopg.generators, "connect", fake_connect)
     with pytest.raises(exctype):
         conn_cls.connect(*args, **kwargs)
 
index d336c19de0c4285b5773f41c7abd3ccdf48a0a0c..b9a0c005e239216a0c11427afb33bcf2c0dff66f 100644 (file)
@@ -395,7 +395,7 @@ async def test_connect_args(
         yield
 
     setpgenv({})
-    monkeypatch.setattr(psycopg.connection, "connect", fake_connect)
+    monkeypatch.setattr(psycopg.generators, "connect", fake_connect)
     conn = await aconn_cls.connect(*args, **kwargs)
     assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want)
     await conn.close()
@@ -414,7 +414,7 @@ async def test_connect_badargs(aconn_cls, monkeypatch, pgconn, args, kwargs, exc
         return pgconn
         yield
 
-    monkeypatch.setattr(psycopg.connection, "connect", fake_connect)
+    monkeypatch.setattr(psycopg.generators, "connect", fake_connect)
     with pytest.raises(exctype):
         await aconn_cls.connect(*args, **kwargs)
 
index 794ef0f89ec6ca2db34de070ff23a6561b7a314b..c6b3e08e312cfd16950166af255fb62943f087db 100644 (file)
@@ -17,7 +17,7 @@ def test_connect(monkeypatch, dsn, args, kwargs, want_conninfo):
     # Details of the params manipulation are in test_conninfo.
     import psycopg.connection
 
-    orig_connect = psycopg.connection.connect  # type: ignore
+    orig_connect = psycopg.generators.connect
 
     got_conninfo = None
 
@@ -26,7 +26,7 @@ def test_connect(monkeypatch, dsn, args, kwargs, want_conninfo):
         got_conninfo = conninfo
         return orig_connect(dsn)
 
-    monkeypatch.setattr(psycopg.connection, "connect", mock_connect)
+    monkeypatch.setattr(psycopg.generators, "connect", mock_connect)
 
     conn = psycopg.connect(*args, **kwargs)
     assert got_conninfo == want_conninfo
index 82a5d730c7cc21b2e0f205ff588bb153e3337ad3..69d4e8d8aa1004d0c50358fe905fd16895e6e027 100644 (file)
@@ -141,7 +141,7 @@ def test_connect_args(monkeypatch, pgconn, args, kwargs, want):
         return pgconn
         yield
 
-    monkeypatch.setattr(psycopg.connection, "connect", fake_connect)
+    monkeypatch.setattr(psycopg.generators, "connect", fake_connect)
     conn = psycopg.connect(*args, **kwargs)
     assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want)
     conn.close()