From: Daniele Varrazzo Date: Sat, 2 Sep 2023 16:31:18 +0000 (+0100) Subject: refactor: move BaseConnetion in its own module X-Git-Tag: pool-3.2.0~12^2~44 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=47a4b8146e9f1160bf46780ecb0ff3caa07a6aa1;p=thirdparty%2Fpsycopg.git refactor: move BaseConnetion in its own module This commit is brought you by: flight delayed 6 hours for technical failure. --- diff --git a/psycopg/psycopg/__init__.py b/psycopg/psycopg/__init__.py index ff7d398a1..abeb07930 100644 --- a/psycopg/psycopg/__init__.py +++ b/psycopg/psycopg/__init__.py @@ -19,12 +19,13 @@ from .errors import InternalError, ProgrammingError, NotSupportedError from ._column import Column from .conninfo import ConnectionInfo from ._pipeline import Pipeline, AsyncPipeline -from .connection import BaseConnection, Connection, Notify +from .connection import Connection from .transaction import Rollback, Transaction, AsyncTransaction from .cursor_async import AsyncCursor from .server_cursor import AsyncServerCursor, ServerCursor from .client_cursor import AsyncClientCursor, ClientCursor from .raw_cursor import AsyncRawCursor, RawCursor +from ._connection_base import BaseConnection, Notify from .connection_async import AsyncConnection from . import dbapi20 diff --git a/psycopg/psycopg/_adapters_map.py b/psycopg/psycopg/_adapters_map.py index 1c8981f58..fae5cb545 100644 --- a/psycopg/psycopg/_adapters_map.py +++ b/psycopg/psycopg/_adapters_map.py @@ -15,7 +15,7 @@ from ._cmodule import _psycopg from ._typeinfo import TypesRegistry if TYPE_CHECKING: - from .connection import BaseConnection + from ._connection_base import BaseConnection RV = TypeVar("RV") diff --git a/psycopg/psycopg/_connection_base.py b/psycopg/psycopg/_connection_base.py new file mode 100644 index 000000000..fb8db5d91 --- /dev/null +++ b/psycopg/psycopg/_connection_base.py @@ -0,0 +1,648 @@ +""" +psycopg connection objects +""" + +# Copyright (C) 2020 The Psycopg Team + +import logging +from typing import Callable, Generic +from typing import List, NamedTuple, Optional, Type, TypeVar, Tuple, Union +from typing import TYPE_CHECKING +from weakref import ref, ReferenceType +from warnings import warn +from functools import partial +from typing_extensions import TypeAlias + +from . import pq +from . import errors as e +from . import postgres +from . import generators +from .abc import ConnectionType, PQGen, PQGenConn, Query +from .sql import Composable, SQL +from ._tpc import Xid +from .rows import Row +from .adapt import AdaptersMap +from ._enums import IsolationLevel +from ._compat import LiteralString +from .pq.misc import connection_summary +from .conninfo import ConnectionInfo +from ._pipeline import BasePipeline +from ._encodings import pgconn_encoding +from ._preparing import PrepareManager + +if TYPE_CHECKING: + from .pq.abc import PGconn, PGresult + from psycopg_pool.base import BasePool + +# Row Type variable for Cursor (when it needs to be distinguished from the +# connection's one) +CursorRow = TypeVar("CursorRow") + +TEXT = pq.Format.TEXT +BINARY = pq.Format.BINARY + +OK = pq.ConnStatus.OK +BAD = pq.ConnStatus.BAD + +COMMAND_OK = pq.ExecStatus.COMMAND_OK +TUPLES_OK = pq.ExecStatus.TUPLES_OK +FATAL_ERROR = pq.ExecStatus.FATAL_ERROR + +IDLE = pq.TransactionStatus.IDLE +INTRANS = pq.TransactionStatus.INTRANS + +logger = logging.getLogger("psycopg") + + +class Notify(NamedTuple): + """An asynchronous notification received from the database.""" + + channel: str + """The name of the channel on which the notification was received.""" + + payload: str + """The message attached to the notification.""" + + pid: int + """The PID of the backend process which sent the notification.""" + + +Notify.__module__ = "psycopg" + +NoticeHandler: TypeAlias = Callable[[e.Diagnostic], None] +NotifyHandler: TypeAlias = Callable[[Notify], None] + + +class BaseConnection(Generic[Row]): + """ + Base class for different types of connections. + + Share common functionalities such as access to the wrapped PGconn, but + allow different interfaces (sync/async). + """ + + # DBAPI2 exposed exceptions + Warning = e.Warning + Error = e.Error + InterfaceError = e.InterfaceError + DatabaseError = e.DatabaseError + DataError = e.DataError + OperationalError = e.OperationalError + IntegrityError = e.IntegrityError + InternalError = e.InternalError + ProgrammingError = e.ProgrammingError + NotSupportedError = e.NotSupportedError + + # Enums useful for the connection + ConnStatus = pq.ConnStatus + TransactionStatus = pq.TransactionStatus + + def __init__(self, pgconn: "PGconn"): + self.pgconn = pgconn + self._autocommit = False + + # None, but set to a copy of the global adapters map as soon as requested. + self._adapters: Optional[AdaptersMap] = None + + self._notice_handlers: List[NoticeHandler] = [] + self._notify_handlers: List[NotifyHandler] = [] + + # Number of transaction blocks currently entered + self._num_transactions = 0 + + self._closed = False # closed by an explicit close() + self._prepared: PrepareManager = PrepareManager() + self._tpc: Optional[Tuple[Xid, bool]] = None # xid, prepared + + wself = ref(self) + pgconn.notice_handler = partial(BaseConnection._notice_handler, wself) + pgconn.notify_handler = partial(BaseConnection._notify_handler, wself) + + # Attribute is only set if the connection is from a pool so we can tell + # apart a connection in the pool too (when _pool = None) + self._pool: Optional["BasePool"] + + self._pipeline: Optional[BasePipeline] = None + + # Time after which the connection should be closed + self._expire_at: float + + self._isolation_level: Optional[IsolationLevel] = None + self._read_only: Optional[bool] = None + self._deferrable: Optional[bool] = None + self._begin_statement = b"" + + def __del__(self) -> None: + # If fails on connection we might not have this attribute yet + if not hasattr(self, "pgconn"): + return + + # Connection correctly closed + if self.closed: + return + + # Connection in a pool so terminating with the program is normal + if hasattr(self, "_pool"): + return + + warn( + f"connection {self} was deleted while still open." + " Please use 'with' or '.close()' to close the connection", + ResourceWarning, + ) + + def __repr__(self) -> str: + cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}" + info = connection_summary(self.pgconn) + return f"<{cls} {info} at 0x{id(self):x}>" + + @property + def closed(self) -> bool: + """`!True` if the connection is closed.""" + return self.pgconn.status == BAD + + @property + def broken(self) -> bool: + """ + `!True` if the connection was interrupted. + + A broken connection is always `closed`, but wasn't closed in a clean + way, such as using `close()` or a `!with` block. + """ + return self.pgconn.status == BAD and not self._closed + + @property + def autocommit(self) -> bool: + """The autocommit state of the connection.""" + return self._autocommit + + @autocommit.setter + def autocommit(self, value: bool) -> None: + self._set_autocommit(value) + + def _set_autocommit(self, value: bool) -> None: + raise NotImplementedError + + def _set_autocommit_gen(self, value: bool) -> PQGen[None]: + yield from self._check_intrans_gen("autocommit") + self._autocommit = bool(value) + + @property + def isolation_level(self) -> Optional[IsolationLevel]: + """ + The isolation level of the new transactions started on the connection. + """ + return self._isolation_level + + @isolation_level.setter + def isolation_level(self, value: Optional[IsolationLevel]) -> None: + self._set_isolation_level(value) + + def _set_isolation_level(self, value: Optional[IsolationLevel]) -> None: + raise NotImplementedError + + def _set_isolation_level_gen(self, value: Optional[IsolationLevel]) -> PQGen[None]: + yield from self._check_intrans_gen("isolation_level") + self._isolation_level = IsolationLevel(value) if value is not None else None + self._begin_statement = b"" + + @property + def read_only(self) -> Optional[bool]: + """ + The read-only state of the new transactions started on the connection. + """ + return self._read_only + + @read_only.setter + def read_only(self, value: Optional[bool]) -> None: + self._set_read_only(value) + + def _set_read_only(self, value: Optional[bool]) -> None: + raise NotImplementedError + + def _set_read_only_gen(self, value: Optional[bool]) -> PQGen[None]: + yield from self._check_intrans_gen("read_only") + self._read_only = bool(value) if value is not None else None + self._begin_statement = b"" + + @property + def deferrable(self) -> Optional[bool]: + """ + The deferrable state of the new transactions started on the connection. + """ + return self._deferrable + + @deferrable.setter + def deferrable(self, value: Optional[bool]) -> None: + self._set_deferrable(value) + + def _set_deferrable(self, value: Optional[bool]) -> None: + raise NotImplementedError + + def _set_deferrable_gen(self, value: Optional[bool]) -> PQGen[None]: + yield from self._check_intrans_gen("deferrable") + self._deferrable = bool(value) if value is not None else None + self._begin_statement = b"" + + def _check_intrans_gen(self, attribute: str) -> PQGen[None]: + # Raise an exception if we are in a transaction + status = self.pgconn.transaction_status + if status == IDLE and self._pipeline: + yield from self._pipeline._sync_gen() + status = self.pgconn.transaction_status + if status != IDLE: + if self._num_transactions: + raise e.ProgrammingError( + f"can't change {attribute!r} now: " + "connection.transaction() context in progress" + ) + else: + raise e.ProgrammingError( + f"can't change {attribute!r} now: " + "connection in transaction status " + f"{pq.TransactionStatus(status).name}" + ) + + @property + def info(self) -> ConnectionInfo: + """A `ConnectionInfo` attribute to inspect connection properties.""" + return ConnectionInfo(self.pgconn) + + @property + def adapters(self) -> AdaptersMap: + if not self._adapters: + self._adapters = AdaptersMap(postgres.adapters) + + return self._adapters + + @property + def connection(self) -> "BaseConnection[Row]": + # implement the AdaptContext protocol + return self + + def fileno(self) -> int: + """Return the file descriptor of the connection. + + This function allows to use the connection as file-like object in + functions waiting for readiness, such as the ones defined in the + `selectors` module. + """ + return self.pgconn.socket + + def cancel(self) -> None: + """Cancel the current operation on the connection.""" + # No-op if the connection is closed + # this allows to use the method as callback handler without caring + # about its life. + if self.closed: + return + + if self._tpc and self._tpc[1]: + raise e.ProgrammingError( + "cancel() cannot be used with a prepared two-phase transaction" + ) + + self._try_cancel(self.pgconn) + + @classmethod + def _try_cancel(cls, pgconn: "PGconn") -> None: + try: + # Can fail if the connection is closed + c = pgconn.get_cancel() + except Exception as ex: + logger.warning("couldn't try to cancel query: %s", ex) + else: + c.cancel() + + def add_notice_handler(self, callback: NoticeHandler) -> None: + """ + Register a callable to be invoked when a notice message is received. + + :param callback: the callback to call upon message received. + :type callback: Callable[[~psycopg.errors.Diagnostic], None] + """ + self._notice_handlers.append(callback) + + def remove_notice_handler(self, callback: NoticeHandler) -> None: + """ + Unregister a notice message callable previously registered. + + :param callback: the callback to remove. + :type callback: Callable[[~psycopg.errors.Diagnostic], None] + """ + self._notice_handlers.remove(callback) + + @staticmethod + def _notice_handler( + wself: "ReferenceType[BaseConnection[Row]]", res: "PGresult" + ) -> None: + self = wself() + if not (self and self._notice_handlers): + return + + diag = e.Diagnostic(res, pgconn_encoding(self.pgconn)) + for cb in self._notice_handlers: + try: + cb(diag) + except Exception as ex: + logger.exception("error processing notice callback '%s': %s", cb, ex) + + def add_notify_handler(self, callback: NotifyHandler) -> None: + """ + Register a callable to be invoked whenever a notification is received. + + :param callback: the callback to call upon notification received. + :type callback: Callable[[~psycopg.Notify], None] + """ + self._notify_handlers.append(callback) + + def remove_notify_handler(self, callback: NotifyHandler) -> None: + """ + Unregister a notification callable previously registered. + + :param callback: the callback to remove. + :type callback: Callable[[~psycopg.Notify], None] + """ + self._notify_handlers.remove(callback) + + @staticmethod + def _notify_handler( + wself: "ReferenceType[BaseConnection[Row]]", pgn: pq.PGnotify + ) -> None: + self = wself() + if not (self and self._notify_handlers): + return + + enc = pgconn_encoding(self.pgconn) + n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid) + for cb in self._notify_handlers: + cb(n) + + @property + def prepare_threshold(self) -> Optional[int]: + """ + Number of times a query is executed before it is prepared. + + - If it is set to 0, every query is prepared the first time it is + executed. + - If it is set to `!None`, prepared statements are disabled on the + connection. + + Default value: 5 + """ + return self._prepared.prepare_threshold + + @prepare_threshold.setter + def prepare_threshold(self, value: Optional[int]) -> None: + self._prepared.prepare_threshold = value + + @property + def prepared_max(self) -> int: + """ + Maximum number of prepared statements on the connection. + + Default value: 100 + """ + return self._prepared.prepared_max + + @prepared_max.setter + def prepared_max(self, value: int) -> None: + self._prepared.prepared_max = value + + # Generators to perform high-level operations on the connection + # + # These operations are expressed in terms of non-blocking generators + # and the task of waiting when needed (when the generators yield) is left + # to the connections subclass, which might wait either in blocking mode + # or through asyncio. + # + # All these generators assume exclusive access to the connection: subclasses + # should have a lock and hold it before calling and consuming them. + + @classmethod + def _connect_gen( + cls: Type[ConnectionType], + conninfo: str = "", + *, + autocommit: bool = False, + ) -> PQGenConn[ConnectionType]: + """Generator to connect to the database and create a new instance.""" + pgconn = yield from generators.connect(conninfo) + conn = cls(pgconn) + conn._autocommit = bool(autocommit) + return conn + + def _exec_command( + self, command: Query, result_format: pq.Format = TEXT + ) -> PQGen[Optional["PGresult"]]: + """ + Generator to send a command and receive the result to the backend. + + Only used to implement internal commands such as "commit", with eventual + arguments bound client-side. The cursor can do more complex stuff. + """ + self._check_connection_ok() + + if isinstance(command, str): + command = command.encode(pgconn_encoding(self.pgconn)) + elif isinstance(command, Composable): + command = command.as_bytes(self) + + if self._pipeline: + cmd = partial( + self.pgconn.send_query_params, + command, + None, + result_format=result_format, + ) + self._pipeline.command_queue.append(cmd) + self._pipeline.result_queue.append(None) + return None + + self.pgconn.send_query_params(command, None, result_format=result_format) + + result = (yield from generators.execute(self.pgconn))[-1] + if result.status != COMMAND_OK and result.status != TUPLES_OK: + if result.status == FATAL_ERROR: + raise e.error_from_result(result, encoding=pgconn_encoding(self.pgconn)) + else: + raise e.InterfaceError( + f"unexpected result {pq.ExecStatus(result.status).name}" + f" from command {command.decode()!r}" + ) + return result + + def _check_connection_ok(self) -> None: + if self.pgconn.status == OK: + return + + if self.pgconn.status == BAD: + raise e.OperationalError("the connection is closed") + raise e.InterfaceError( + "cannot execute operations: the connection is" + f" in status {self.pgconn.status}" + ) + + def _start_query(self) -> PQGen[None]: + """Generator to start a transaction if necessary.""" + if self._autocommit: + return + + if self.pgconn.transaction_status != IDLE: + return + + yield from self._exec_command(self._get_tx_start_command()) + if self._pipeline: + yield from self._pipeline._sync_gen() + + def _get_tx_start_command(self) -> bytes: + if self._begin_statement: + return self._begin_statement + + parts = [b"BEGIN"] + + if self.isolation_level is not None: + val = IsolationLevel(self.isolation_level) + parts.append(b"ISOLATION LEVEL") + parts.append(val.name.replace("_", " ").encode()) + + if self.read_only is not None: + parts.append(b"READ ONLY" if self.read_only else b"READ WRITE") + + if self.deferrable is not None: + parts.append(b"DEFERRABLE" if self.deferrable else b"NOT DEFERRABLE") + + self._begin_statement = b" ".join(parts) + return self._begin_statement + + def _commit_gen(self) -> PQGen[None]: + """Generator implementing `Connection.commit()`.""" + if self._num_transactions: + raise e.ProgrammingError( + "Explicit commit() forbidden within a Transaction " + "context. (Transaction will be automatically committed " + "on successful exit from context.)" + ) + if self._tpc: + raise e.ProgrammingError( + "commit() cannot be used during a two-phase transaction" + ) + if self.pgconn.transaction_status == IDLE: + return + + yield from self._exec_command(b"COMMIT") + + if self._pipeline: + yield from self._pipeline._sync_gen() + + def _rollback_gen(self) -> PQGen[None]: + """Generator implementing `Connection.rollback()`.""" + if self._num_transactions: + raise e.ProgrammingError( + "Explicit rollback() forbidden within a Transaction " + "context. (Either raise Rollback() or allow " + "an exception to propagate out of the context.)" + ) + if self._tpc: + raise e.ProgrammingError( + "rollback() cannot be used during a two-phase transaction" + ) + + # Get out of a "pipeline aborted" state + if self._pipeline: + yield from self._pipeline._sync_gen() + + if self.pgconn.transaction_status == IDLE: + return + + yield from self._exec_command(b"ROLLBACK") + self._prepared.clear() + for cmd in self._prepared.get_maintenance_commands(): + yield from self._exec_command(cmd) + + if self._pipeline: + yield from self._pipeline._sync_gen() + + def xid(self, format_id: int, gtrid: str, bqual: str) -> Xid: + """ + Returns a `Xid` to pass to the `!tpc_*()` methods of this connection. + + The argument types and constraints are explained in + :ref:`two-phase-commit`. + + The values passed to the method will be available on the returned + object as the members `~Xid.format_id`, `~Xid.gtrid`, `~Xid.bqual`. + """ + self._check_tpc() + return Xid.from_parts(format_id, gtrid, bqual) + + def _tpc_begin_gen(self, xid: Union[Xid, str]) -> PQGen[None]: + self._check_tpc() + + if not isinstance(xid, Xid): + xid = Xid.from_string(xid) + + if self.pgconn.transaction_status != IDLE: + raise e.ProgrammingError( + "can't start two-phase transaction: connection in status" + f" {pq.TransactionStatus(self.pgconn.transaction_status).name}" + ) + + if self._autocommit: + raise e.ProgrammingError( + "can't use two-phase transactions in autocommit mode" + ) + + self._tpc = (xid, False) + yield from self._exec_command(self._get_tx_start_command()) + + def _tpc_prepare_gen(self) -> PQGen[None]: + if not self._tpc: + raise e.ProgrammingError( + "'tpc_prepare()' must be called inside a two-phase transaction" + ) + if self._tpc[1]: + raise e.ProgrammingError( + "'tpc_prepare()' cannot be used during a prepared two-phase transaction" + ) + xid = self._tpc[0] + self._tpc = (xid, True) + yield from self._exec_command(SQL("PREPARE TRANSACTION {}").format(str(xid))) + if self._pipeline: + yield from self._pipeline._sync_gen() + + def _tpc_finish_gen( + self, action: LiteralString, xid: Union[Xid, str, None] + ) -> PQGen[None]: + fname = f"tpc_{action.lower()}()" + if xid is None: + if not self._tpc: + raise e.ProgrammingError( + f"{fname} without xid must must be" + " called inside a two-phase transaction" + ) + xid = self._tpc[0] + else: + if self._tpc: + raise e.ProgrammingError( + f"{fname} with xid must must be called" + " outside a two-phase transaction" + ) + if not isinstance(xid, Xid): + xid = Xid.from_string(xid) + + if self._tpc and not self._tpc[1]: + meth: Callable[[], PQGen[None]] + meth = getattr(self, f"_{action.lower()}_gen") + self._tpc = None + yield from meth() + else: + yield from self._exec_command( + SQL("{} PREPARED {}").format(SQL(action), str(xid)) + ) + self._tpc = None + + def _check_tpc(self) -> None: + """Raise NotSupportedError if TPC is not supported.""" + # TPC supported on every supported PostgreSQL version. + pass diff --git a/psycopg/psycopg/_encodings.py b/psycopg/psycopg/_encodings.py index 876acb975..3f14c2608 100644 --- a/psycopg/psycopg/_encodings.py +++ b/psycopg/psycopg/_encodings.py @@ -15,7 +15,7 @@ from ._compat import cache if TYPE_CHECKING: from .pq.abc import PGconn - from .connection import BaseConnection + from ._connection_base import BaseConnection OK = ConnStatus.OK diff --git a/psycopg/psycopg/_pipeline.py b/psycopg/psycopg/_pipeline.py index 2223f491a..ff7228eee 100644 --- a/psycopg/psycopg/_pipeline.py +++ b/psycopg/psycopg/_pipeline.py @@ -21,7 +21,8 @@ from .generators import pipeline_communicate, fetch_many, send if TYPE_CHECKING: from .pq.abc import PGresult from ._cursor_base import BaseCursor - from .connection import BaseConnection, Connection + from .connection import Connection + from ._connection_base import BaseConnection from .connection_async import AsyncConnection diff --git a/psycopg/psycopg/_py_transformer.py b/psycopg/psycopg/_py_transformer.py index 0438725c3..7174893c5 100644 --- a/psycopg/psycopg/_py_transformer.py +++ b/psycopg/psycopg/_py_transformer.py @@ -25,7 +25,7 @@ from ._encodings import pgconn_encoding if TYPE_CHECKING: from .adapt import AdaptersMap from .pq.abc import PGresult - from .connection import BaseConnection + from ._connection_base import BaseConnection DumperCache: TypeAlias = Dict[DumperKey, abc.Dumper] OidDumperCache: TypeAlias = Dict[int, abc.Dumper] diff --git a/psycopg/psycopg/_typeinfo.py b/psycopg/psycopg/_typeinfo.py index dcbb2c095..141ed0427 100644 --- a/psycopg/psycopg/_typeinfo.py +++ b/psycopg/psycopg/_typeinfo.py @@ -18,8 +18,9 @@ from .rows import dict_row from ._encodings import conn_encoding if TYPE_CHECKING: - from .connection import BaseConnection, Connection + from .connection import Connection from .connection_async import AsyncConnection + from ._connection_base import BaseConnection T = TypeVar("T", bound="TypeInfo") RegistryKey: TypeAlias = Union[str, int, Tuple[type, int]] diff --git a/psycopg/psycopg/abc.py b/psycopg/psycopg/abc.py index 0ee2037ac..8edcdc107 100644 --- a/psycopg/psycopg/abc.py +++ b/psycopg/psycopg/abc.py @@ -18,8 +18,8 @@ if TYPE_CHECKING: from .rows import Row, RowMaker from .pq.abc import PGresult from .waiting import Wait, Ready - from .connection import BaseConnection from ._adapters_map import AdaptersMap + from ._connection_base import BaseConnection NoneType: type = type(None) diff --git a/psycopg/psycopg/adapt.py b/psycopg/psycopg/adapt.py index f09bd6a84..31a710429 100644 --- a/psycopg/psycopg/adapt.py +++ b/psycopg/psycopg/adapt.py @@ -15,7 +15,7 @@ from ._transformer import Transformer as Transformer from ._adapters_map import AdaptersMap as AdaptersMap # noqa: F401 if TYPE_CHECKING: - from .connection import BaseConnection + from ._connection_base import BaseConnection Buffer = abc.Buffer diff --git a/psycopg/psycopg/connection.py b/psycopg/psycopg/connection.py index a6571d5d7..73724f616 100644 --- a/psycopg/psycopg/connection.py +++ b/psycopg/psycopg/connection.py @@ -7,656 +7,41 @@ psycopg connection objects import logging import threading from types import TracebackType -from typing import Any, Callable, cast, Dict, Generator, Generic, Iterator -from typing import List, NamedTuple, Optional, Type, TypeVar, Tuple, Union +from typing import Any, cast, Dict, Generator, Iterator +from typing import List, Optional, Type, TypeVar, Union from typing import overload, TYPE_CHECKING -from weakref import ref, ReferenceType -from warnings import warn -from functools import partial from contextlib import contextmanager -from typing_extensions import TypeAlias from . import pq from . import errors as e from . import waiting -from . import postgres -from .abc import AdaptContext, ConnectionType, Params, Query, RV +from .abc import AdaptContext, Params, Query, RV from .abc import PQGen, PQGenConn -from .sql import Composable, SQL from ._tpc import Xid from .rows import Row, RowFactory, tuple_row, TupleRow, args_row from .adapt import AdaptersMap from ._enums import IsolationLevel from .cursor import Cursor -from ._compat import LiteralString -from .pq.misc import connection_summary -from .conninfo import make_conninfo, conninfo_to_dict, ConnectionInfo -from ._pipeline import BasePipeline, Pipeline -from .generators import notifies, connect, execute +from .conninfo import make_conninfo, conninfo_to_dict +from ._pipeline import Pipeline +from .generators import notifies from ._encodings import pgconn_encoding -from ._preparing import PrepareManager from .transaction import Transaction from .server_cursor import ServerCursor +from ._connection_base import BaseConnection, CursorRow, Notify if TYPE_CHECKING: - from .pq.abc import PGconn, PGresult - from psycopg_pool.base import BasePool - - -# Row Type variable for Cursor (when it needs to be distinguished from the -# connection's one) -CursorRow = TypeVar("CursorRow") + from .pq.abc import PGconn TEXT = pq.Format.TEXT BINARY = pq.Format.BINARY -OK = pq.ConnStatus.OK -BAD = pq.ConnStatus.BAD - -COMMAND_OK = pq.ExecStatus.COMMAND_OK -TUPLES_OK = pq.ExecStatus.TUPLES_OK -FATAL_ERROR = pq.ExecStatus.FATAL_ERROR - IDLE = pq.TransactionStatus.IDLE INTRANS = pq.TransactionStatus.INTRANS logger = logging.getLogger("psycopg") -class Notify(NamedTuple): - """An asynchronous notification received from the database.""" - - channel: str - """The name of the channel on which the notification was received.""" - - payload: str - """The message attached to the notification.""" - - pid: int - """The PID of the backend process which sent the notification.""" - - -Notify.__module__ = "psycopg" - -NoticeHandler: TypeAlias = Callable[[e.Diagnostic], None] -NotifyHandler: TypeAlias = Callable[[Notify], None] - - -class BaseConnection(Generic[Row]): - """ - Base class for different types of connections. - - Share common functionalities such as access to the wrapped PGconn, but - allow different interfaces (sync/async). - """ - - # DBAPI2 exposed exceptions - Warning = e.Warning - Error = e.Error - InterfaceError = e.InterfaceError - DatabaseError = e.DatabaseError - DataError = e.DataError - OperationalError = e.OperationalError - IntegrityError = e.IntegrityError - InternalError = e.InternalError - ProgrammingError = e.ProgrammingError - NotSupportedError = e.NotSupportedError - - # Enums useful for the connection - ConnStatus = pq.ConnStatus - TransactionStatus = pq.TransactionStatus - - def __init__(self, pgconn: "PGconn"): - self.pgconn = pgconn - self._autocommit = False - - # None, but set to a copy of the global adapters map as soon as requested. - self._adapters: Optional[AdaptersMap] = None - - self._notice_handlers: List[NoticeHandler] = [] - self._notify_handlers: List[NotifyHandler] = [] - - # Number of transaction blocks currently entered - self._num_transactions = 0 - - self._closed = False # closed by an explicit close() - self._prepared: PrepareManager = PrepareManager() - self._tpc: Optional[Tuple[Xid, bool]] = None # xid, prepared - - wself = ref(self) - pgconn.notice_handler = partial(BaseConnection._notice_handler, wself) - pgconn.notify_handler = partial(BaseConnection._notify_handler, wself) - - # Attribute is only set if the connection is from a pool so we can tell - # apart a connection in the pool too (when _pool = None) - self._pool: Optional["BasePool"] - - self._pipeline: Optional[BasePipeline] = None - - # Time after which the connection should be closed - self._expire_at: float - - self._isolation_level: Optional[IsolationLevel] = None - self._read_only: Optional[bool] = None - self._deferrable: Optional[bool] = None - self._begin_statement = b"" - - def __del__(self) -> None: - # If fails on connection we might not have this attribute yet - if not hasattr(self, "pgconn"): - return - - # Connection correctly closed - if self.closed: - return - - # Connection in a pool so terminating with the program is normal - if hasattr(self, "_pool"): - return - - warn( - f"connection {self} was deleted while still open." - " Please use 'with' or '.close()' to close the connection", - ResourceWarning, - ) - - def __repr__(self) -> str: - cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}" - info = connection_summary(self.pgconn) - return f"<{cls} {info} at 0x{id(self):x}>" - - @property - def closed(self) -> bool: - """`!True` if the connection is closed.""" - return self.pgconn.status == BAD - - @property - def broken(self) -> bool: - """ - `!True` if the connection was interrupted. - - A broken connection is always `closed`, but wasn't closed in a clean - way, such as using `close()` or a `!with` block. - """ - return self.pgconn.status == BAD and not self._closed - - @property - def autocommit(self) -> bool: - """The autocommit state of the connection.""" - return self._autocommit - - @autocommit.setter - def autocommit(self, value: bool) -> None: - self._set_autocommit(value) - - def _set_autocommit(self, value: bool) -> None: - raise NotImplementedError - - def _set_autocommit_gen(self, value: bool) -> PQGen[None]: - yield from self._check_intrans_gen("autocommit") - self._autocommit = bool(value) - - @property - def isolation_level(self) -> Optional[IsolationLevel]: - """ - The isolation level of the new transactions started on the connection. - """ - return self._isolation_level - - @isolation_level.setter - def isolation_level(self, value: Optional[IsolationLevel]) -> None: - self._set_isolation_level(value) - - def _set_isolation_level(self, value: Optional[IsolationLevel]) -> None: - raise NotImplementedError - - def _set_isolation_level_gen(self, value: Optional[IsolationLevel]) -> PQGen[None]: - yield from self._check_intrans_gen("isolation_level") - self._isolation_level = IsolationLevel(value) if value is not None else None - self._begin_statement = b"" - - @property - def read_only(self) -> Optional[bool]: - """ - The read-only state of the new transactions started on the connection. - """ - return self._read_only - - @read_only.setter - def read_only(self, value: Optional[bool]) -> None: - self._set_read_only(value) - - def _set_read_only(self, value: Optional[bool]) -> None: - raise NotImplementedError - - def _set_read_only_gen(self, value: Optional[bool]) -> PQGen[None]: - yield from self._check_intrans_gen("read_only") - self._read_only = bool(value) if value is not None else None - self._begin_statement = b"" - - @property - def deferrable(self) -> Optional[bool]: - """ - The deferrable state of the new transactions started on the connection. - """ - return self._deferrable - - @deferrable.setter - def deferrable(self, value: Optional[bool]) -> None: - self._set_deferrable(value) - - def _set_deferrable(self, value: Optional[bool]) -> None: - raise NotImplementedError - - def _set_deferrable_gen(self, value: Optional[bool]) -> PQGen[None]: - yield from self._check_intrans_gen("deferrable") - self._deferrable = bool(value) if value is not None else None - self._begin_statement = b"" - - def _check_intrans_gen(self, attribute: str) -> PQGen[None]: - # Raise an exception if we are in a transaction - status = self.pgconn.transaction_status - if status == IDLE and self._pipeline: - yield from self._pipeline._sync_gen() - status = self.pgconn.transaction_status - if status != IDLE: - if self._num_transactions: - raise e.ProgrammingError( - f"can't change {attribute!r} now: " - "connection.transaction() context in progress" - ) - else: - raise e.ProgrammingError( - f"can't change {attribute!r} now: " - "connection in transaction status " - f"{pq.TransactionStatus(status).name}" - ) - - @property - def info(self) -> ConnectionInfo: - """A `ConnectionInfo` attribute to inspect connection properties.""" - return ConnectionInfo(self.pgconn) - - @property - def adapters(self) -> AdaptersMap: - if not self._adapters: - self._adapters = AdaptersMap(postgres.adapters) - - return self._adapters - - @property - def connection(self) -> "BaseConnection[Row]": - # implement the AdaptContext protocol - return self - - def fileno(self) -> int: - """Return the file descriptor of the connection. - - This function allows to use the connection as file-like object in - functions waiting for readiness, such as the ones defined in the - `selectors` module. - """ - return self.pgconn.socket - - def cancel(self) -> None: - """Cancel the current operation on the connection.""" - # No-op if the connection is closed - # this allows to use the method as callback handler without caring - # about its life. - if self.closed: - return - - if self._tpc and self._tpc[1]: - raise e.ProgrammingError( - "cancel() cannot be used with a prepared two-phase transaction" - ) - - self._try_cancel(self.pgconn) - - @classmethod - def _try_cancel(cls, pgconn: "PGconn") -> None: - try: - # Can fail if the connection is closed - c = pgconn.get_cancel() - except Exception as ex: - logger.warning("couldn't try to cancel query: %s", ex) - else: - c.cancel() - - def add_notice_handler(self, callback: NoticeHandler) -> None: - """ - Register a callable to be invoked when a notice message is received. - - :param callback: the callback to call upon message received. - :type callback: Callable[[~psycopg.errors.Diagnostic], None] - """ - self._notice_handlers.append(callback) - - def remove_notice_handler(self, callback: NoticeHandler) -> None: - """ - Unregister a notice message callable previously registered. - - :param callback: the callback to remove. - :type callback: Callable[[~psycopg.errors.Diagnostic], None] - """ - self._notice_handlers.remove(callback) - - @staticmethod - def _notice_handler( - wself: "ReferenceType[BaseConnection[Row]]", res: "PGresult" - ) -> None: - self = wself() - if not (self and self._notice_handlers): - return - - diag = e.Diagnostic(res, pgconn_encoding(self.pgconn)) - for cb in self._notice_handlers: - try: - cb(diag) - except Exception as ex: - logger.exception("error processing notice callback '%s': %s", cb, ex) - - def add_notify_handler(self, callback: NotifyHandler) -> None: - """ - Register a callable to be invoked whenever a notification is received. - - :param callback: the callback to call upon notification received. - :type callback: Callable[[~psycopg.Notify], None] - """ - self._notify_handlers.append(callback) - - def remove_notify_handler(self, callback: NotifyHandler) -> None: - """ - Unregister a notification callable previously registered. - - :param callback: the callback to remove. - :type callback: Callable[[~psycopg.Notify], None] - """ - self._notify_handlers.remove(callback) - - @staticmethod - def _notify_handler( - wself: "ReferenceType[BaseConnection[Row]]", pgn: pq.PGnotify - ) -> None: - self = wself() - if not (self and self._notify_handlers): - return - - enc = pgconn_encoding(self.pgconn) - n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid) - for cb in self._notify_handlers: - cb(n) - - @property - def prepare_threshold(self) -> Optional[int]: - """ - Number of times a query is executed before it is prepared. - - - If it is set to 0, every query is prepared the first time it is - executed. - - If it is set to `!None`, prepared statements are disabled on the - connection. - - Default value: 5 - """ - return self._prepared.prepare_threshold - - @prepare_threshold.setter - def prepare_threshold(self, value: Optional[int]) -> None: - self._prepared.prepare_threshold = value - - @property - def prepared_max(self) -> int: - """ - Maximum number of prepared statements on the connection. - - Default value: 100 - """ - return self._prepared.prepared_max - - @prepared_max.setter - def prepared_max(self, value: int) -> None: - self._prepared.prepared_max = value - - # Generators to perform high-level operations on the connection - # - # These operations are expressed in terms of non-blocking generators - # and the task of waiting when needed (when the generators yield) is left - # to the connections subclass, which might wait either in blocking mode - # or through asyncio. - # - # All these generators assume exclusive access to the connection: subclasses - # should have a lock and hold it before calling and consuming them. - - @classmethod - def _connect_gen( - cls: Type[ConnectionType], - conninfo: str = "", - *, - autocommit: bool = False, - ) -> PQGenConn[ConnectionType]: - """Generator to connect to the database and create a new instance.""" - pgconn = yield from connect(conninfo) - conn = cls(pgconn) - conn._autocommit = bool(autocommit) - return conn - - def _exec_command( - self, command: Query, result_format: pq.Format = TEXT - ) -> PQGen[Optional["PGresult"]]: - """ - Generator to send a command and receive the result to the backend. - - Only used to implement internal commands such as "commit", with eventual - arguments bound client-side. The cursor can do more complex stuff. - """ - self._check_connection_ok() - - if isinstance(command, str): - command = command.encode(pgconn_encoding(self.pgconn)) - elif isinstance(command, Composable): - command = command.as_bytes(self) - - if self._pipeline: - cmd = partial( - self.pgconn.send_query_params, - command, - None, - result_format=result_format, - ) - self._pipeline.command_queue.append(cmd) - self._pipeline.result_queue.append(None) - return None - - self.pgconn.send_query_params(command, None, result_format=result_format) - - result = (yield from execute(self.pgconn))[-1] - if result.status != COMMAND_OK and result.status != TUPLES_OK: - if result.status == FATAL_ERROR: - raise e.error_from_result(result, encoding=pgconn_encoding(self.pgconn)) - else: - raise e.InterfaceError( - f"unexpected result {pq.ExecStatus(result.status).name}" - f" from command {command.decode()!r}" - ) - return result - - def _check_connection_ok(self) -> None: - if self.pgconn.status == OK: - return - - if self.pgconn.status == BAD: - raise e.OperationalError("the connection is closed") - raise e.InterfaceError( - "cannot execute operations: the connection is" - f" in status {self.pgconn.status}" - ) - - def _start_query(self) -> PQGen[None]: - """Generator to start a transaction if necessary.""" - if self._autocommit: - return - - if self.pgconn.transaction_status != IDLE: - return - - yield from self._exec_command(self._get_tx_start_command()) - if self._pipeline: - yield from self._pipeline._sync_gen() - - def _get_tx_start_command(self) -> bytes: - if self._begin_statement: - return self._begin_statement - - parts = [b"BEGIN"] - - if self.isolation_level is not None: - val = IsolationLevel(self.isolation_level) - parts.append(b"ISOLATION LEVEL") - parts.append(val.name.replace("_", " ").encode()) - - if self.read_only is not None: - parts.append(b"READ ONLY" if self.read_only else b"READ WRITE") - - if self.deferrable is not None: - parts.append(b"DEFERRABLE" if self.deferrable else b"NOT DEFERRABLE") - - self._begin_statement = b" ".join(parts) - return self._begin_statement - - def _commit_gen(self) -> PQGen[None]: - """Generator implementing `Connection.commit()`.""" - if self._num_transactions: - raise e.ProgrammingError( - "Explicit commit() forbidden within a Transaction " - "context. (Transaction will be automatically committed " - "on successful exit from context.)" - ) - if self._tpc: - raise e.ProgrammingError( - "commit() cannot be used during a two-phase transaction" - ) - if self.pgconn.transaction_status == IDLE: - return - - yield from self._exec_command(b"COMMIT") - - if self._pipeline: - yield from self._pipeline._sync_gen() - - def _rollback_gen(self) -> PQGen[None]: - """Generator implementing `Connection.rollback()`.""" - if self._num_transactions: - raise e.ProgrammingError( - "Explicit rollback() forbidden within a Transaction " - "context. (Either raise Rollback() or allow " - "an exception to propagate out of the context.)" - ) - if self._tpc: - raise e.ProgrammingError( - "rollback() cannot be used during a two-phase transaction" - ) - - # Get out of a "pipeline aborted" state - if self._pipeline: - yield from self._pipeline._sync_gen() - - if self.pgconn.transaction_status == IDLE: - return - - yield from self._exec_command(b"ROLLBACK") - self._prepared.clear() - for cmd in self._prepared.get_maintenance_commands(): - yield from self._exec_command(cmd) - - if self._pipeline: - yield from self._pipeline._sync_gen() - - def xid(self, format_id: int, gtrid: str, bqual: str) -> Xid: - """ - Returns a `Xid` to pass to the `!tpc_*()` methods of this connection. - - The argument types and constraints are explained in - :ref:`two-phase-commit`. - - The values passed to the method will be available on the returned - object as the members `~Xid.format_id`, `~Xid.gtrid`, `~Xid.bqual`. - """ - self._check_tpc() - return Xid.from_parts(format_id, gtrid, bqual) - - def _tpc_begin_gen(self, xid: Union[Xid, str]) -> PQGen[None]: - self._check_tpc() - - if not isinstance(xid, Xid): - xid = Xid.from_string(xid) - - if self.pgconn.transaction_status != IDLE: - raise e.ProgrammingError( - "can't start two-phase transaction: connection in status" - f" {pq.TransactionStatus(self.pgconn.transaction_status).name}" - ) - - if self._autocommit: - raise e.ProgrammingError( - "can't use two-phase transactions in autocommit mode" - ) - - self._tpc = (xid, False) - yield from self._exec_command(self._get_tx_start_command()) - - def _tpc_prepare_gen(self) -> PQGen[None]: - if not self._tpc: - raise e.ProgrammingError( - "'tpc_prepare()' must be called inside a two-phase transaction" - ) - if self._tpc[1]: - raise e.ProgrammingError( - "'tpc_prepare()' cannot be used during a prepared two-phase transaction" - ) - xid = self._tpc[0] - self._tpc = (xid, True) - yield from self._exec_command(SQL("PREPARE TRANSACTION {}").format(str(xid))) - if self._pipeline: - yield from self._pipeline._sync_gen() - - def _tpc_finish_gen( - self, action: LiteralString, xid: Union[Xid, str, None] - ) -> PQGen[None]: - fname = f"tpc_{action.lower()}()" - if xid is None: - if not self._tpc: - raise e.ProgrammingError( - f"{fname} without xid must must be" - " called inside a two-phase transaction" - ) - xid = self._tpc[0] - else: - if self._tpc: - raise e.ProgrammingError( - f"{fname} with xid must must be called" - " outside a two-phase transaction" - ) - if not isinstance(xid, Xid): - xid = Xid.from_string(xid) - - if self._tpc and not self._tpc[1]: - meth: Callable[[], PQGen[None]] - meth = getattr(self, f"_{action.lower()}_gen") - self._tpc = None - yield from meth() - else: - yield from self._exec_command( - SQL("{} PREPARED {}").format(SQL(action), str(xid)) - ) - self._tpc = None - - def _check_tpc(self) -> None: - """Raise NotSupportedError if TPC is not supported.""" - # TPC supported on every supported PostgreSQL version. - pass - - class Connection(BaseConnection[Row]): """ Wrapper for a connection to the database. diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py index 416d00cee..bdae30d45 100644 --- a/psycopg/psycopg/connection_async.py +++ b/psycopg/psycopg/connection_async.py @@ -23,11 +23,11 @@ from ._enums import IsolationLevel from .conninfo import make_conninfo, conninfo_to_dict, resolve_hostaddr_async from ._pipeline import AsyncPipeline from ._encodings import pgconn_encoding -from .connection import BaseConnection, CursorRow, Notify from .generators import notifies from .transaction import AsyncTransaction from .cursor_async import AsyncCursor from .server_cursor import AsyncServerCursor +from ._connection_base import BaseConnection, CursorRow, Notify if TYPE_CHECKING: from .pq.abc import PGconn diff --git a/psycopg/psycopg/types/composite.py b/psycopg/psycopg/types/composite.py index 2824287ad..d116273c4 100644 --- a/psycopg/psycopg/types/composite.py +++ b/psycopg/psycopg/types/composite.py @@ -22,7 +22,7 @@ from .._typeinfo import TypeInfo from .._encodings import _as_python_identifier if TYPE_CHECKING: - from ..connection import BaseConnection + from .._connection_base import BaseConnection _struct_oidlen = struct.Struct("!Ii") _pack_oidlen = cast(Callable[[int, int], bytes], _struct_oidlen.pack) diff --git a/psycopg/psycopg/types/datetime.py b/psycopg/psycopg/types/datetime.py index 401b0b4c8..614bfca32 100644 --- a/psycopg/psycopg/types/datetime.py +++ b/psycopg/psycopg/types/datetime.py @@ -18,7 +18,7 @@ from ..errors import InterfaceError, DataError from .._struct import pack_int4, pack_int8, unpack_int4, unpack_int8 if TYPE_CHECKING: - from ..connection import BaseConnection + from .._connection_base import BaseConnection _struct_timetz = struct.Struct("!qi") # microseconds, sec tz offset _pack_timetz = cast(Callable[[int, int], bytes], _struct_timetz.pack) diff --git a/psycopg/psycopg/types/enum.py b/psycopg/psycopg/types/enum.py index a1eb62ebe..3035214ba 100644 --- a/psycopg/psycopg/types/enum.py +++ b/psycopg/psycopg/types/enum.py @@ -17,7 +17,7 @@ from .._encodings import conn_encoding from .._typeinfo import TypeInfo if TYPE_CHECKING: - from ..connection import BaseConnection + from .._connection_base import BaseConnection E = TypeVar("E", bound=Enum) diff --git a/psycopg/psycopg/types/multirange.py b/psycopg/psycopg/types/multirange.py index 8af08de20..d672f6be8 100644 --- a/psycopg/psycopg/types/multirange.py +++ b/psycopg/psycopg/types/multirange.py @@ -25,7 +25,7 @@ from .range import Range, T, load_range_text, load_range_binary from .range import dump_range_text, dump_range_binary, fail_dump if TYPE_CHECKING: - from ..connection import BaseConnection + from .._connection_base import BaseConnection class MultirangeInfo(TypeInfo): diff --git a/psycopg/psycopg/types/range.py b/psycopg/psycopg/types/range.py index 71127919e..6290ca080 100644 --- a/psycopg/psycopg/types/range.py +++ b/psycopg/psycopg/types/range.py @@ -23,7 +23,7 @@ from .._struct import pack_len, unpack_len from .._typeinfo import TypeInfo, TypesRegistry if TYPE_CHECKING: - from ..connection import BaseConnection + from .._connection_base import BaseConnection RANGE_EMPTY = 0x01 # range is empty RANGE_LB_INC = 0x02 # lower bound is inclusive diff --git a/psycopg_c/psycopg_c/_psycopg.pyi b/psycopg_c/psycopg_c/_psycopg.pyi index bd7c63d91..7d456ba53 100644 --- a/psycopg_c/psycopg_c/_psycopg.pyi +++ b/psycopg_c/psycopg_c/_psycopg.pyi @@ -9,12 +9,10 @@ information. Will submit a bug. from typing import Any, Iterable, List, Optional, Sequence, Tuple -from psycopg import pq -from psycopg import abc +from psycopg import pq, abc, BaseConnection from psycopg.rows import Row, RowMaker from psycopg.adapt import AdaptersMap, PyFormat from psycopg.pq.abc import PGconn, PGresult -from psycopg.connection import BaseConnection from psycopg._compat import Deque class Transformer(abc.AdaptContext): diff --git a/psycopg_pool/psycopg_pool/base.py b/psycopg_pool/psycopg_pool/base.py index 13823bccd..b0e308136 100644 --- a/psycopg_pool/psycopg_pool/base.py +++ b/psycopg_pool/psycopg_pool/base.py @@ -14,7 +14,7 @@ from .errors import PoolClosed from ._compat import Counter, Deque if TYPE_CHECKING: - from psycopg.connection import BaseConnection + from psycopg._connection_base import BaseConnection class BasePool: diff --git a/tests/test_connection.py b/tests/test_connection.py index 73ec2fba8..914bfa647 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -396,7 +396,7 @@ def test_connect_args(conn_cls, monkeypatch, setpgenv, pgconn, args, kwargs, wan yield setpgenv({}) - monkeypatch.setattr(psycopg.connection, "connect", fake_connect) + monkeypatch.setattr(psycopg.generators, "connect", fake_connect) conn = conn_cls.connect(*args, **kwargs) assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want) conn.close() @@ -415,7 +415,7 @@ def test_connect_badargs(conn_cls, monkeypatch, pgconn, args, kwargs, exctype): return pgconn yield - monkeypatch.setattr(psycopg.connection, "connect", fake_connect) + monkeypatch.setattr(psycopg.generators, "connect", fake_connect) with pytest.raises(exctype): conn_cls.connect(*args, **kwargs) diff --git a/tests/test_connection_async.py b/tests/test_connection_async.py index d336c19de..b9a0c005e 100644 --- a/tests/test_connection_async.py +++ b/tests/test_connection_async.py @@ -395,7 +395,7 @@ async def test_connect_args( yield setpgenv({}) - monkeypatch.setattr(psycopg.connection, "connect", fake_connect) + monkeypatch.setattr(psycopg.generators, "connect", fake_connect) conn = await aconn_cls.connect(*args, **kwargs) assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want) await conn.close() @@ -414,7 +414,7 @@ async def test_connect_badargs(aconn_cls, monkeypatch, pgconn, args, kwargs, exc return pgconn yield - monkeypatch.setattr(psycopg.connection, "connect", fake_connect) + monkeypatch.setattr(psycopg.generators, "connect", fake_connect) with pytest.raises(exctype): await aconn_cls.connect(*args, **kwargs) diff --git a/tests/test_module.py b/tests/test_module.py index 794ef0f89..c6b3e08e3 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -17,7 +17,7 @@ def test_connect(monkeypatch, dsn, args, kwargs, want_conninfo): # Details of the params manipulation are in test_conninfo. import psycopg.connection - orig_connect = psycopg.connection.connect # type: ignore + orig_connect = psycopg.generators.connect got_conninfo = None @@ -26,7 +26,7 @@ def test_connect(monkeypatch, dsn, args, kwargs, want_conninfo): got_conninfo = conninfo return orig_connect(dsn) - monkeypatch.setattr(psycopg.connection, "connect", mock_connect) + monkeypatch.setattr(psycopg.generators, "connect", mock_connect) conn = psycopg.connect(*args, **kwargs) assert got_conninfo == want_conninfo diff --git a/tests/test_psycopg_dbapi20.py b/tests/test_psycopg_dbapi20.py index 82a5d730c..69d4e8d8a 100644 --- a/tests/test_psycopg_dbapi20.py +++ b/tests/test_psycopg_dbapi20.py @@ -141,7 +141,7 @@ def test_connect_args(monkeypatch, pgconn, args, kwargs, want): return pgconn yield - monkeypatch.setattr(psycopg.connection, "connect", fake_connect) + monkeypatch.setattr(psycopg.generators, "connect", fake_connect) conn = psycopg.connect(*args, **kwargs) assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want) conn.close()