From: Daniele Varrazzo Date: Mon, 29 Jan 2024 15:40:28 +0000 (+0000) Subject: Revert "Merge branch 'conn-generator-refactoring'" X-Git-Tag: 3.2.0~85 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=557137dd914ded2d91c7ee66c9c32e46f892876e;p=thirdparty%2Fpsycopg.git Revert "Merge branch 'conn-generator-refactoring'" This reverts commit 70790e3f19a13c025f1abebeb0c64cb13cf26c8f, reversing changes made to c25b8409221d05d371c51eceebef6970eb33ffa0. --- diff --git a/psycopg/psycopg/_connection_base.py b/psycopg/psycopg/_connection_base.py index 12ae7395f..39e00002b 100644 --- a/psycopg/psycopg/_connection_base.py +++ b/psycopg/psycopg/_connection_base.py @@ -419,11 +419,9 @@ class BaseConnection(Generic[Row]): # should have a lock and hold it before calling and consuming them. @classmethod - def _connect_gen( - cls, conninfo: str = "", *, timeout: float = 0.0 - ) -> PQGenConn[Self]: + def _connect_gen(cls, conninfo: str = "") -> PQGenConn[Self]: """Generator to connect to the database and create a new instance.""" - pgconn = yield from generators.connect(conninfo, timeout=timeout) + pgconn = yield from generators.connect(conninfo) conn = cls(pgconn) return conn diff --git a/psycopg/psycopg/abc.py b/psycopg/psycopg/abc.py index 480fc064c..ad4a96646 100644 --- a/psycopg/psycopg/abc.py +++ b/psycopg/psycopg/abc.py @@ -56,7 +56,7 @@ class WaitFunc(Protocol): """ def __call__( - self, gen: PQGen[RV], fileno: int, interval: Optional[float] = None + self, gen: PQGen[RV], fileno: int, timeout: Optional[float] = None ) -> RV: ... diff --git a/psycopg/psycopg/connection.py b/psycopg/psycopg/connection.py index 65c1b1dcd..cb0244aa5 100644 --- a/psycopg/psycopg/connection.py +++ b/psycopg/psycopg/connection.py @@ -19,7 +19,7 @@ from contextlib import contextmanager from . import pq from . import errors as e from . import waiting -from .abc import AdaptContext, ConnDict, ConnParam, Params, PQGen, Query, RV +from .abc import AdaptContext, ConnDict, ConnParam, Params, PQGen, PQGenConn, Query, RV from ._tpc import Xid from .rows import Row, RowFactory, tuple_row, args_row from .adapt import AdaptersMap @@ -100,8 +100,8 @@ class Connection(BaseConnection[Row]): for attempt in attempts: try: conninfo = make_conninfo("", **attempt) - gen = cls._connect_gen(conninfo, timeout=timeout) - rv = waiting.wait_conn(gen, interval=_WAIT_INTERVAL) + rv = cls._wait_conn(cls._connect_gen(conninfo), timeout=timeout) + break except e._NO_TRACEBACK as ex: if len(attempts) > 1: logger.debug( @@ -112,8 +112,6 @@ class Connection(BaseConnection[Row]): str(ex), ) last_ex = ex - else: - break if not rv: assert last_ex @@ -298,10 +296,10 @@ class Connection(BaseConnection[Row]): # into shorter interval. if timeout is not None: deadline = monotonic() + timeout - interval = min(timeout, _WAIT_INTERVAL) + timeout = min(timeout, _WAIT_INTERVAL) else: deadline = None - interval = _WAIT_INTERVAL + timeout = _WAIT_INTERVAL nreceived = 0 @@ -310,7 +308,7 @@ class Connection(BaseConnection[Row]): # notification is received to makes sure that they are consistent. try: with self.lock: - ns = self.wait(notifies(self.pgconn), interval=interval) + ns = self.wait(notifies(self.pgconn), timeout=timeout) if ns: enc = pgconn_encoding(self.pgconn) except e._NO_TRACEBACK as ex: @@ -329,8 +327,8 @@ class Connection(BaseConnection[Row]): # Check the deadline after the loop to ensure that timeout=0 # polls at least once. if deadline: - interval = min(_WAIT_INTERVAL, deadline - monotonic()) - if interval < 0.0: + timeout = min(_WAIT_INTERVAL, deadline - monotonic()) + if timeout < 0.0: break @contextmanager @@ -353,7 +351,7 @@ class Connection(BaseConnection[Row]): assert pipeline is self._pipeline self._pipeline = None - def wait(self, gen: PQGen[RV], interval: Optional[float] = _WAIT_INTERVAL) -> RV: + def wait(self, gen: PQGen[RV], timeout: Optional[float] = _WAIT_INTERVAL) -> RV: """ Consume a generator operating on the connection. @@ -361,18 +359,23 @@ class Connection(BaseConnection[Row]): fd (i.e. not on connect and reset). """ try: - return waiting.wait(gen, self.pgconn.socket, interval=interval) + return waiting.wait(gen, self.pgconn.socket, timeout=timeout) except _INTERRUPTED: if self.pgconn.transaction_status == ACTIVE: # On Ctrl-C, try to cancel the query in the server, otherwise # the connection will remain stuck in ACTIVE state. self._try_cancel(self.pgconn) try: - waiting.wait(gen, self.pgconn.socket, interval=interval) + waiting.wait(gen, self.pgconn.socket, timeout=timeout) except e.QueryCanceled: pass # as expected raise + @classmethod + def _wait_conn(cls, gen: PQGenConn[RV], timeout: Optional[int]) -> RV: + """Consume a connection generator.""" + return waiting.wait_conn(gen, timeout) + def _set_autocommit(self, value: bool) -> None: self.set_autocommit(value) diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py index 4e67c5ef8..d810d45b2 100644 --- a/psycopg/psycopg/connection_async.py +++ b/psycopg/psycopg/connection_async.py @@ -16,7 +16,7 @@ from contextlib import asynccontextmanager from . import pq from . import errors as e from . import waiting -from .abc import AdaptContext, ConnDict, ConnParam, Params, PQGen, Query, RV +from .abc import AdaptContext, ConnDict, ConnParam, Params, PQGen, PQGenConn, Query, RV from ._tpc import Xid from .rows import Row, AsyncRowFactory, tuple_row, args_row from .adapt import AdaptersMap @@ -115,8 +115,8 @@ class AsyncConnection(BaseConnection[Row]): for attempt in attempts: try: conninfo = make_conninfo("", **attempt) - gen = cls._connect_gen(conninfo, timeout=timeout) - rv = await waiting.wait_conn_async(gen, interval=_WAIT_INTERVAL) + rv = await cls._wait_conn(cls._connect_gen(conninfo), timeout=timeout) + break except e._NO_TRACEBACK as ex: if len(attempts) > 1: logger.debug( @@ -127,8 +127,6 @@ class AsyncConnection(BaseConnection[Row]): str(ex), ) last_ex = ex - else: - break if not rv: assert last_ex @@ -314,10 +312,10 @@ class AsyncConnection(BaseConnection[Row]): # into shorter interval. if timeout is not None: deadline = monotonic() + timeout - interval = min(timeout, _WAIT_INTERVAL) + timeout = min(timeout, _WAIT_INTERVAL) else: deadline = None - interval = _WAIT_INTERVAL + timeout = _WAIT_INTERVAL nreceived = 0 @@ -326,7 +324,7 @@ class AsyncConnection(BaseConnection[Row]): # notification is received to makes sure that they are consistent. try: async with self.lock: - ns = await self.wait(notifies(self.pgconn), interval=interval) + ns = await self.wait(notifies(self.pgconn), timeout=timeout) if ns: enc = pgconn_encoding(self.pgconn) except e._NO_TRACEBACK as ex: @@ -345,8 +343,8 @@ class AsyncConnection(BaseConnection[Row]): # Check the deadline after the loop to ensure that timeout=0 # polls at least once. if deadline: - interval = min(_WAIT_INTERVAL, deadline - monotonic()) - if interval < 0.0: + timeout = min(_WAIT_INTERVAL, deadline - monotonic()) + if timeout < 0.0: break @asynccontextmanager @@ -370,7 +368,7 @@ class AsyncConnection(BaseConnection[Row]): self._pipeline = None async def wait( - self, gen: PQGen[RV], interval: Optional[float] = _WAIT_INTERVAL + self, gen: PQGen[RV], timeout: Optional[float] = _WAIT_INTERVAL ) -> RV: """ Consume a generator operating on the connection. @@ -379,18 +377,23 @@ class AsyncConnection(BaseConnection[Row]): fd (i.e. not on connect and reset). """ try: - return await waiting.wait_async(gen, self.pgconn.socket, interval=interval) + return await waiting.wait_async(gen, self.pgconn.socket, timeout=timeout) except _INTERRUPTED: if self.pgconn.transaction_status == ACTIVE: # On Ctrl-C, try to cancel the query in the server, otherwise # the connection will remain stuck in ACTIVE state. self._try_cancel(self.pgconn) try: - await waiting.wait_async(gen, self.pgconn.socket, interval=interval) + await waiting.wait_async(gen, self.pgconn.socket, timeout=timeout) except e.QueryCanceled: pass # as expected raise + @classmethod + async def _wait_conn(cls, gen: PQGenConn[RV], timeout: Optional[int]) -> RV: + """Consume a connection generator.""" + return await waiting.wait_conn_async(gen, timeout) + def _set_autocommit(self, value: bool) -> None: if True: # ASYNC self._no_set_async("autocommit") diff --git a/psycopg/psycopg/generators.py b/psycopg/psycopg/generators.py index 96143af93..2e463196e 100644 --- a/psycopg/psycopg/generators.py +++ b/psycopg/psycopg/generators.py @@ -21,7 +21,6 @@ generator should probably yield the same value again in order to wait more. # Copyright (C) 2020 The Psycopg Team import logging -from time import monotonic from typing import List, Optional, Union from . import pq @@ -57,12 +56,11 @@ READY_RW = Ready.RW logger = logging.getLogger(__name__) -def _connect(conninfo: str, *, timeout: float = 0.0) -> PQGenConn[PGconn]: +def _connect(conninfo: str) -> PQGenConn[PGconn]: """ Generator to create a database connection without blocking. - """ - deadline = monotonic() + timeout if timeout else 0.0 + """ conn = pq.PGconn.connect_start(conninfo.encode()) while True: if conn.status == BAD: @@ -73,18 +71,12 @@ def _connect(conninfo: str, *, timeout: float = 0.0) -> PQGenConn[PGconn]: ) status = conn.connect_poll() - - if status == POLL_READING or status == POLL_WRITING: - wait = WAIT_R if status == POLL_READING else WAIT_W - while True: - ready = yield conn.socket, wait - if deadline and monotonic() > deadline: - raise e.ConnectionTimeout("connection timeout expired") - if ready: - break - - elif status == POLL_OK: + if status == POLL_OK: break + elif status == POLL_READING: + yield conn.socket, WAIT_R + elif status == POLL_WRITING: + yield conn.socket, WAIT_W elif status == POLL_FAILED: encoding = conninfo_encoding(conninfo) raise e.OperationalError( diff --git a/psycopg/psycopg/waiting.py b/psycopg/psycopg/waiting.py index f01de9234..6315c0ad7 100644 --- a/psycopg/psycopg/waiting.py +++ b/psycopg/psycopg/waiting.py @@ -34,15 +34,16 @@ READY_RW = Ready.RW logger = logging.getLogger(__name__) -def wait_selector(gen: PQGen[RV], fileno: int, interval: Optional[float] = None) -> RV: +def wait_selector(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> RV: """ Wait for a generator using the best strategy available. :param gen: a generator performing database operations and yielding `Ready` values when it would block. :param fileno: the file descriptor to wait on. - :param interval: interval (in seconds) to check for other interrupt, e.g. - to allow Ctrl-C. If zero or None, wait indefinitely. + :param timeout: timeout (in seconds) to check for other interrupt, e.g. + to allow Ctrl-C. + :type timeout: float :return: whatever `!gen` returns on completion. Consume `!gen`, scheduling `fileno` for completion when it is reported to @@ -53,7 +54,7 @@ def wait_selector(gen: PQGen[RV], fileno: int, interval: Optional[float] = None) with DefaultSelector() as sel: sel.register(fileno, s) while True: - rlist = sel.select(timeout=interval) + rlist = sel.select(timeout=timeout) if not rlist: gen.send(READY_NONE) continue @@ -68,14 +69,15 @@ def wait_selector(gen: PQGen[RV], fileno: int, interval: Optional[float] = None) return rv -def wait_conn(gen: PQGenConn[RV], interval: Optional[float] = None) -> RV: +def wait_conn(gen: PQGenConn[RV], timeout: Optional[float] = None) -> RV: """ Wait for a connection generator using the best strategy available. :param gen: a generator performing database operations and yielding (fd, `Ready`) pairs when it would block. - :param interval: interval (in seconds) to check for other interrupt, e.g. + :param timeout: timeout (in seconds) to check for other interrupt, e.g. to allow Ctrl-C. If zero or None, wait indefinitely. + :type timeout: float :return: whatever `!gen` returns on completion. Behave like in `wait()`, but take the fileno to wait from the generator @@ -83,20 +85,17 @@ def wait_conn(gen: PQGenConn[RV], interval: Optional[float] = None) -> RV: """ try: fileno, s = next(gen) - if not interval: - interval = None + if not timeout: + timeout = None with DefaultSelector() as sel: - sel.register(fileno, s) while True: - rlist = sel.select(timeout=interval) - if not rlist: - gen.send(READY_NONE) - continue - + sel.register(fileno, s) + rlist = sel.select(timeout=timeout) sel.unregister(fileno) + if not rlist: + raise e.ConnectionTimeout("connection timeout expired") ready = rlist[0][1] fileno, s = gen.send(ready) - sel.register(fileno, s) except StopIteration as ex: rv: RV = ex.args[0] if ex.args else None @@ -104,7 +103,7 @@ def wait_conn(gen: PQGenConn[RV], interval: Optional[float] = None) -> RV: async def wait_async( - gen: PQGen[RV], fileno: int, interval: Optional[float] = None + gen: PQGen[RV], fileno: int, timeout: Optional[float] = None ) -> RV: """ Coroutine waiting for a generator to complete. @@ -112,7 +111,7 @@ async def wait_async( :param gen: a generator performing database operations and yielding `Ready` values when it would block. :param fileno: the file descriptor to wait on. - :param interval: interval (in seconds) to check for other interrupt, e.g. + :param timeout: timeout (in seconds) to check for other interrupt, e.g. to allow Ctrl-C. If zero or None, wait indefinitely. :return: whatever `!gen` returns on completion. @@ -144,9 +143,9 @@ async def wait_async( if writer: loop.add_writer(fileno, wakeup, READY_W) try: - if interval: + if timeout: try: - await wait_for(ev.wait(), interval) + await wait_for(ev.wait(), timeout) except TimeoutError: pass else: @@ -166,13 +165,13 @@ async def wait_async( return rv -async def wait_conn_async(gen: PQGenConn[RV], interval: Optional[float] = None) -> RV: +async def wait_conn_async(gen: PQGenConn[RV], timeout: Optional[float] = None) -> RV: """ Coroutine waiting for a connection generator to complete. :param gen: a generator performing database operations and yielding (fd, `Ready`) pairs when it would block. - :param interval: interval (in seconds) to check for other interrupt, e.g. + :param timeout: timeout (in seconds) to check for other interrupt, e.g. to allow Ctrl-C. If zero or None, wait indefinitely. :return: whatever `!gen` returns on completion. @@ -205,11 +204,8 @@ async def wait_conn_async(gen: PQGenConn[RV], interval: Optional[float] = None) if writer: loop.add_writer(fileno, wakeup, READY_W) try: - if interval: - try: - await wait_for(ev.wait(), interval) - except TimeoutError: - pass + if timeout: + await wait_for(ev.wait(), timeout) else: await ev.wait() finally: @@ -219,6 +215,9 @@ async def wait_conn_async(gen: PQGenConn[RV], interval: Optional[float] = None) loop.remove_writer(fileno) fileno, s = gen.send(ready) + except TimeoutError: + raise e.ConnectionTimeout("connection timeout expired") + except StopIteration as ex: rv: RV = ex.args[0] if ex.args else None return rv @@ -227,7 +226,7 @@ async def wait_conn_async(gen: PQGenConn[RV], interval: Optional[float] = None) # Specialised implementation of wait functions. -def wait_select(gen: PQGen[RV], fileno: int, interval: Optional[float] = None) -> RV: +def wait_select(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> RV: """ Wait for a generator using select where supported. @@ -243,7 +242,7 @@ def wait_select(gen: PQGen[RV], fileno: int, interval: Optional[float] = None) - fnlist if s & WAIT_R else empty, fnlist if s & WAIT_W else empty, fnlist, - interval, + timeout, ) ready = 0 if rl: @@ -272,7 +271,7 @@ else: _epoll_evmasks = {} -def wait_epoll(gen: PQGen[RV], fileno: int, interval: Optional[float] = None) -> RV: +def wait_epoll(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> RV: """ Wait for a generator using epoll where supported. @@ -291,14 +290,14 @@ def wait_epoll(gen: PQGen[RV], fileno: int, interval: Optional[float] = None) -> try: s = next(gen) - if interval is None or interval < 0: - interval = 0.0 + if timeout is None or timeout < 0: + timeout = 0.0 with select.epoll() as epoll: evmask = _epoll_evmasks[s] epoll.register(fileno, evmask) while True: - fileevs = epoll.poll(interval) + fileevs = epoll.poll(timeout) if not fileevs: gen.send(READY_NONE) continue @@ -327,7 +326,7 @@ else: _poll_evmasks = {} -def wait_poll(gen: PQGen[RV], fileno: int, interval: Optional[float] = None) -> RV: +def wait_poll(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> RV: """ Wait for a generator using poll where supported. @@ -336,16 +335,16 @@ def wait_poll(gen: PQGen[RV], fileno: int, interval: Optional[float] = None) -> try: s = next(gen) - if interval is None or interval < 0: - interval = 0 + if timeout is None or timeout < 0: + timeout = 0 else: - interval = int(interval * 1000.0) + timeout = int(timeout * 1000.0) poll = select.poll() evmask = _poll_evmasks[s] poll.register(fileno, evmask) while True: - fileevs = poll.poll(interval) + fileevs = poll.poll(timeout) if not fileevs: gen.send(READY_NONE) continue diff --git a/psycopg_c/psycopg_c/_psycopg.pyi b/psycopg_c/psycopg_c/_psycopg.pyi index ec976eb5c..7d456ba53 100644 --- a/psycopg_c/psycopg_c/_psycopg.pyi +++ b/psycopg_c/psycopg_c/_psycopg.pyi @@ -51,7 +51,7 @@ class Transformer(abc.AdaptContext): def get_loader(self, oid: int, format: pq.Format) -> abc.Loader: ... # Generators -def connect(conninfo: str, *, timeout: float = 0.0) -> abc.PQGenConn[PGconn]: ... +def connect(conninfo: str) -> abc.PQGenConn[PGconn]: ... def execute(pgconn: PGconn) -> abc.PQGen[List[PGresult]]: ... def send(pgconn: PGconn) -> abc.PQGen[None]: ... def fetch_many(pgconn: PGconn) -> abc.PQGen[List[PGresult]]: ... @@ -60,7 +60,7 @@ def pipeline_communicate( pgconn: PGconn, commands: Deque[abc.PipelineCommand] ) -> abc.PQGen[List[List[PGresult]]]: ... def wait_c( - gen: abc.PQGen[abc.RV], fileno: int, interval: Optional[float] = None + gen: abc.PQGen[abc.RV], fileno: int, timeout: Optional[float] = None ) -> abc.RV: ... # Copy support diff --git a/psycopg_c/psycopg_c/_psycopg/generators.pyx b/psycopg_c/psycopg_c/_psycopg/generators.pyx index b29ad327a..70335cf89 100644 --- a/psycopg_c/psycopg_c/_psycopg/generators.pyx +++ b/psycopg_c/psycopg_c/_psycopg/generators.pyx @@ -7,7 +7,6 @@ C implementation of generators for the communication protocols with the libpq from cpython.object cimport PyObject_CallFunctionObjArgs from typing import List -from time import monotonic from psycopg import errors as e from psycopg.pq import abc, error_message @@ -28,17 +27,15 @@ cdef int READY_R = Ready.R cdef int READY_W = Ready.W cdef int READY_RW = Ready.RW -def connect(conninfo: str, *, timeout: float = 0.0) -> PQGenConn[abc.PGconn]: +def connect(conninfo: str) -> PQGenConn[abc.PGconn]: """ Generator to create a database connection without blocking. - """ - cdef int deadline = monotonic() + timeout if timeout else 0.0 + """ cdef pq.PGconn conn = pq.PGconn.connect_start(conninfo.encode()) cdef libpq.PGconn *pgconn_ptr = conn._pgconn_ptr cdef int conn_status = libpq.PQstatus(pgconn_ptr) cdef int poll_status - cdef object wait, ready while True: if conn_status == libpq.CONNECTION_BAD: @@ -51,18 +48,12 @@ def connect(conninfo: str, *, timeout: float = 0.0) -> PQGenConn[abc.PGconn]: with nogil: poll_status = libpq.PQconnectPoll(pgconn_ptr) - if poll_status == libpq.PGRES_POLLING_READING \ - or poll_status == libpq.PGRES_POLLING_WRITING: - wait = WAIT_R if poll_status == libpq.PGRES_POLLING_READING else WAIT_W - while True: - ready = yield (libpq.PQsocket(pgconn_ptr), wait) - if deadline and monotonic() > deadline: - raise e.ConnectionTimeout("connection timeout expired") - if ready: - break - - elif poll_status == libpq.PGRES_POLLING_OK: + if poll_status == libpq.PGRES_POLLING_OK: break + elif poll_status == libpq.PGRES_POLLING_READING: + yield (libpq.PQsocket(pgconn_ptr), WAIT_R) + elif poll_status == libpq.PGRES_POLLING_WRITING: + yield (libpq.PQsocket(pgconn_ptr), WAIT_W) elif poll_status == libpq.PGRES_POLLING_FAILED: encoding = conninfo_encoding(conninfo) raise e.OperationalError( diff --git a/psycopg_c/psycopg_c/_psycopg/waiting.pyx b/psycopg_c/psycopg_c/_psycopg/waiting.pyx index d11a2d9c0..3a6cc6e25 100644 --- a/psycopg_c/psycopg_c/_psycopg/waiting.pyx +++ b/psycopg_c/psycopg_c/_psycopg/waiting.pyx @@ -177,20 +177,20 @@ finally: cdef int wait_c_impl(int fileno, int wait, float timeout) except -1 -def wait_c(gen: PQGen[RV], int fileno, interval = None) -> RV: +def wait_c(gen: PQGen[RV], int fileno, timeout = None) -> RV: """ Wait for a generator using poll or select. """ - cdef float cinterval + cdef float ctimeout cdef int wait, ready cdef PyObject *pyready - if interval is None: - cinterval = -1.0 + if timeout is None: + ctimeout = -1.0 else: - cinterval = float(interval) - if cinterval < 0.0: - cinterval = -1.0 + ctimeout = float(timeout) + if ctimeout < 0.0: + ctimeout = -1.0 send = gen.send @@ -198,7 +198,7 @@ def wait_c(gen: PQGen[RV], int fileno, interval = None) -> RV: wait = next(gen) while True: - ready = wait_c_impl(fileno, wait, cinterval) + ready = wait_c_impl(fileno, wait, ctimeout) if ready == READY_NONE: pyready = PY_READY_NONE elif ready == READY_R: diff --git a/tests/test_connection.py b/tests/test_connection.py index eb8541238..f17f8d2a0 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -446,7 +446,7 @@ def test_connect_args( ): got_conninfo: str - def fake_connect(conninfo, *, timeout=0.0): + def fake_connect(conninfo): nonlocal got_conninfo got_conninfo = conninfo return pgconn @@ -468,6 +468,12 @@ def test_connect_args( ], ) def test_connect_badargs(conn_cls, monkeypatch, pgconn, args, kwargs, exctype): + + def fake_connect(conninfo): + return pgconn + yield + + monkeypatch.setattr(psycopg.generators, "connect", fake_connect) with pytest.raises(exctype): conn_cls.connect(*args, **kwargs) diff --git a/tests/test_connection_async.py b/tests/test_connection_async.py index cd761b330..d7aa7ca8b 100644 --- a/tests/test_connection_async.py +++ b/tests/test_connection_async.py @@ -443,7 +443,7 @@ async def test_connect_args( ): got_conninfo: str - def fake_connect(conninfo, *, timeout=0.0): + def fake_connect(conninfo): nonlocal got_conninfo got_conninfo = conninfo return pgconn @@ -465,6 +465,11 @@ async def test_connect_args( ], ) async def test_connect_badargs(aconn_cls, monkeypatch, pgconn, args, kwargs, exctype): + def fake_connect(conninfo): + return pgconn + yield + + monkeypatch.setattr(psycopg.generators, "connect", fake_connect) with pytest.raises(exctype): await aconn_cls.connect(*args, **kwargs) diff --git a/tests/test_module.py b/tests/test_module.py index 2b1869e94..9b144d7d6 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -22,10 +22,10 @@ def test_connect(monkeypatch, dsn_env, args, kwargs, want, setpgenv): got_conninfo: str - def mock_connect(conninfo, *, timeout): + def mock_connect(conninfo): nonlocal got_conninfo got_conninfo = conninfo - return orig_connect(dsn_env, timeout=timeout) + return orig_connect(dsn_env) setpgenv({}) monkeypatch.setattr(psycopg.generators, "connect", mock_connect) diff --git a/tests/test_psycopg_dbapi20.py b/tests/test_psycopg_dbapi20.py index df1c981a6..b4feac792 100644 --- a/tests/test_psycopg_dbapi20.py +++ b/tests/test_psycopg_dbapi20.py @@ -135,7 +135,7 @@ def test_time_from_ticks(ticks, want): def test_connect_args(monkeypatch, pgconn, args, kwargs, want, setpgenv, fake_resolve): got_conninfo: str - def fake_connect(conninfo, *, timeout=0.0): + def fake_connect(conninfo): nonlocal got_conninfo got_conninfo = conninfo return pgconn @@ -157,5 +157,9 @@ def test_connect_args(monkeypatch, pgconn, args, kwargs, want, setpgenv, fake_re ], ) def test_connect_badargs(monkeypatch, pgconn, args, kwargs, exctype): + def fake_connect(conninfo): + return pgconn + yield + with pytest.raises(exctype): psycopg.connect(*args, **kwargs) diff --git a/tests/test_waiting.py b/tests/test_waiting.py index 15c331ac3..c4d8915e8 100644 --- a/tests/test_waiting.py +++ b/tests/test_waiting.py @@ -28,11 +28,11 @@ waitfns = [ ] events = ["R", "W", "RW"] -intervals = [pytest.param({}, id="blank")] -intervals += [pytest.param({"interval": x}, id=str(x)) for x in [None, 0, 0.2, 10]] +timeouts = [pytest.param({}, id="blank")] +timeouts += [pytest.param({"timeout": x}, id=str(x)) for x in [None, 0, 0.2, 10]] -@pytest.mark.parametrize("timeout", intervals) +@pytest.mark.parametrize("timeout", timeouts) def test_wait_conn(dsn, timeout): gen = generators.connect(dsn) conn = waiting.wait_conn(gen, **timeout) @@ -63,7 +63,7 @@ def test_wait_ready(waitfn, event): @pytest.mark.parametrize("waitfn", waitfns) -@pytest.mark.parametrize("timeout", intervals) +@pytest.mark.parametrize("timeout", timeouts) def test_wait(pgconn, waitfn, timeout): waitfn = getattr(waiting, waitfn) @@ -104,7 +104,7 @@ def test_wait_timeout(pgconn, waitfn): except StopIteration as ex: return ex.value - (res,) = waitfn(gen_wrapper(), pgconn.socket, interval=0.1) + (res,) = waitfn(gen_wrapper(), pgconn.socket, timeout=0.1) assert res.status == ExecStatus.TUPLES_OK ds = [t1 - t0 for t0, t1 in zip(ts[:-1], ts[1:])] assert len(ds) >= 5 @@ -146,7 +146,7 @@ def test_wait_large_fd(dsn, fname): f.close() -@pytest.mark.parametrize("timeout", intervals) +@pytest.mark.parametrize("timeout", timeouts) @pytest.mark.anyio async def test_wait_conn_async(dsn, timeout): gen = generators.connect(dsn)