From: Daniele Varrazzo Date: Sun, 4 Feb 2024 14:36:54 +0000 (+0000) Subject: chore: update code to master X-Git-Tag: pool-3.2.2~15 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=a82087505db1e1a1bf11ad0e5043c98f5c460b51;p=thirdparty%2Fpsycopg.git chore: update code to master --- diff --git a/.flake8 b/.flake8 index ec4053fb2..d2473a1ae 100644 --- a/.flake8 +++ b/.flake8 @@ -1,6 +1,6 @@ [flake8] max-line-length = 88 -ignore = W503, E203 +ignore = W503, E203, E704 extend-exclude = .venv build per-file-ignores = # Autogenerated section diff --git a/.github/workflows/3rd-party-tests.yml b/.github/workflows/3rd-party-tests.yml index 89f948d0e..26c9f270c 100644 --- a/.github/workflows/3rd-party-tests.yml +++ b/.github/workflows/3rd-party-tests.yml @@ -61,7 +61,7 @@ jobs: --health-retries 5 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: @@ -158,7 +158,7 @@ jobs: --health-retries 5 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index b615fb28f..b86c53e75 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -18,7 +18,7 @@ jobs: if: true steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: diff --git a/.github/workflows/packages-bin.yml b/.github/workflows/packages-bin.yml index 975fe4c5a..ec5e1757a 100644 --- a/.github/workflows/packages-bin.yml +++ b/.github/workflows/packages-bin.yml @@ -23,7 +23,7 @@ jobs: platform: [manylinux, musllinux] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up QEMU for multi-arch build # Check https://github.com/docker/setup-qemu-action for newer versions. 
@@ -43,7 +43,7 @@ jobs: run: python3 ./tools/build/copy_to_binary.py - name: Build wheels - uses: pypa/cibuildwheel@v2.16.0 + uses: pypa/cibuildwheel@v2.16.5 with: package-dir: psycopg_binary env: @@ -104,13 +104,13 @@ jobs: pyver: [cp38, cp39, cp310, cp311, cp312] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Create the binary package source tree run: python3 ./tools/build/copy_to_binary.py - name: Build wheels - uses: pypa/cibuildwheel@v2.16.0 + uses: pypa/cibuildwheel@v2.16.5 with: package-dir: psycopg_binary env: @@ -147,7 +147,7 @@ jobs: pyver: [cp38, cp39, cp310, cp311, cp312] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Start PostgreSQL service for test run: | @@ -159,7 +159,7 @@ jobs: run: python3 ./tools/build/copy_to_binary.py - name: Build wheels - uses: pypa/cibuildwheel@v2.16.0 + uses: pypa/cibuildwheel@v2.16.5 with: package-dir: psycopg_binary env: diff --git a/.github/workflows/packages-pool.yml b/.github/workflows/packages-pool.yml index 08c334825..db79ec133 100644 --- a/.github/workflows/packages-pool.yml +++ b/.github/workflows/packages-pool.yml @@ -18,7 +18,7 @@ jobs: - {package: psycopg_pool, format: wheel} steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: diff --git a/.github/workflows/packages-src.yml b/.github/workflows/packages-src.yml index 52db48589..6b4f911dd 100644 --- a/.github/workflows/packages-src.yml +++ b/.github/workflows/packages-src.yml @@ -20,7 +20,7 @@ jobs: - {package: psycopg_c, format: sdist, impl: c} steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index f97e64ba8..9fb9421a8 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -59,7 +59,7 @@ jobs: MARKERS: "" steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: @@ -155,7 +155,7 @@ jobs: NOT_MARKERS: "timing proxy mypy" steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: @@ -212,7 +212,7 @@ jobs: shell: bash steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: @@ -267,7 +267,7 @@ jobs: PSYCOPG_TEST_DSN: "host=127.0.0.1 port=26257 user=root dbname=defaultdb" steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: diff --git a/docs/advanced/async.rst b/docs/advanced/async.rst index bf7526071..ef1f6c151 100644 --- a/docs/advanced/async.rst +++ b/docs/advanced/async.rst @@ -334,7 +334,8 @@ mode if you wish to receive or send notifications in a timely manner. Notifications are received as instances of `Notify`. If you are reserving a connection only to receive notifications, the simplest way is to consume the `Connection.notifies` generator. The generator can be stopped using -`!close()`. +`!close()`. Starting from Psycopg 3.2, the method supports options to receive +notifications only for a certain time or up to a certain number. .. note:: diff --git a/docs/api/connections.rst b/docs/api/connections.rst index a607f07fa..898a8470d 100644 --- a/docs/api/connections.rst +++ b/docs/api/connections.rst @@ -286,6 +286,10 @@ The `!Connection` class any session in the database generates a :sql:`NOTIFY` on one of the listened channels. + .. versionchanged:: 3.2 + + Added `!timeout` and `!stop_after` parameters. + ..
automethod:: add_notify_handler See :ref:`async-notify` for details. @@ -494,6 +498,11 @@ The `!AsyncConnection` class ... .. automethod:: notifies + + .. versionchanged:: 3.2 + + Added `!timeout` and `!stop_after` parameters. + .. automethod:: set_autocommit .. automethod:: set_isolation_level .. automethod:: set_read_only diff --git a/docs/api/rows.rst b/docs/api/rows.rst index d4c438242..15dfd3cad 100644 --- a/docs/api/rows.rst +++ b/docs/api/rows.rst @@ -14,6 +14,10 @@ Check out :ref:`row-factory-create` for information about how to use these objec .. autofunction:: tuple_row .. autofunction:: dict_row .. autofunction:: namedtuple_row +.. autofunction:: scalar_row + + .. versionadded:: 3.2 + .. autofunction:: class_row This is not a row factory, but rather a factory of row factories. diff --git a/docs/api/sql.rst b/docs/api/sql.rst index 6959fee4d..5e7000b26 100644 --- a/docs/api/sql.rst +++ b/docs/api/sql.rst @@ -108,9 +108,25 @@ The `!sql` objects are in the following inheritance hierarchy: .. autoclass:: Composable() - .. automethod:: as_bytes .. automethod:: as_string + .. versionchanged:: 3.2 + + The `!context` parameter is optional. + + .. warning:: + + If a context is not specified, the results are "generic" and not + tailored for a specific target connection. Details such as the + connection encoding and escaping style will not be taken into + account. + + .. automethod:: as_bytes + + .. versionchanged:: 3.2 + + The `!context` parameter is optional. See `as_string` for details. + .. autoclass:: SQL diff --git a/docs/news.rst b/docs/news.rst index db3234f4c..aa87171a5 100644 --- a/docs/news.rst +++ b/docs/news.rst @@ -15,12 +15,17 @@ Psycopg 3.2 (unreleased) - Add support for integer, floating point, boolean `NumPy scalar types`__ (:ticket:`#332`). +- Add `!timeout` and `!stop_after` parameters to `Connection.notifies()` + (:ticket:`#340`). - Add :ref:`raw-query-cursors` to execute queries using placeholders in PostgreSQL format (`$1`, `$2`...) (:ticket:`#560`). +- Add `~rows.scalar_row` to return scalar values from a query (:ticket:`#723`). - Add `~Connection.set_autocommit()` on sync connections, and similar transaction control methods available on the async connections. - Add support for libpq functions to close prepared statements and portals introduced in libpq v17 (:ticket:`#603`). +- The `!context` parameter of `sql` objects `~sql.Composable.as_string()` and + `~sql.Composable.as_bytes()` methods is now optional (:ticket:`#716`). - Disable receiving more than one result on the same cursor in pipeline mode, to iterate through `~Cursor.nextset()`. The behaviour was different than in non-pipeline mode and not totally reliable (:ticket:`#604`). @@ -30,18 +35,29 @@ Psycopg 3.2 (unreleased) .. __: https://numpy.org/doc/stable/reference/arrays.scalars.html#built-in-scalar-types -Psycopg 3.1.17 (unreleased) +Psycopg 3.1.18 (unreleased) ^^^^^^^^^^^^^^^^^^^^^^^^^^^ -- Fix multiple connection attempts when a host name resolve to multiple - IP addresses (:ticket:`699`). -- Use `typing.Self` as a more correct return value annotation of context - managers and other self-returning methods (see :ticket:`708`). +- Fix possible deadlock on pipeline exit (:ticket:`#685`). +- Fix overflow loading large intervals in C module (:ticket:`#719`). +- Fix compatibility with musl libc distributions affected by `CPython issue + #65821`__ (:ticket:`#725`). + + ..
__: https://github.com/python/cpython/issues/65821 Current release --------------- +Psycopg 3.1.17 +^^^^^^^^^^^^^^ + +- Fix multiple connection attempts when a host name resolves to multiple + IP addresses (:ticket:`#699`). +- Use `typing.Self` as a more correct return value annotation of context + managers and other self-returning methods (see :ticket:`#708`). + + Psycopg 3.1.16 ^^^^^^^^^^^^^^ diff --git a/psycopg/.flake8 b/psycopg/.flake8 index 67fb0245c..33b08d768 100644 --- a/psycopg/.flake8 +++ b/psycopg/.flake8 @@ -1,6 +1,6 @@ [flake8] max-line-length = 88 -ignore = W503, E203 +ignore = W503, E203, E704 per-file-ignores = # Autogenerated section psycopg/errors.py: E125, E128, E302 diff --git a/psycopg/psycopg/_acompat.py b/psycopg/psycopg/_acompat.py index cf106c5ba..d7290889d 100644 --- a/psycopg/psycopg/_acompat.py +++ b/psycopg/psycopg/_acompat.py @@ -15,9 +15,7 @@ import asyncio import threading from typing import Any, Callable, Coroutine, TYPE_CHECKING -from typing_extensions import TypeAlias - -from ._compat import TypeVar +from ._compat import TypeAlias, TypeVar Worker: TypeAlias = threading.Thread AWorker: TypeAlias = "asyncio.Task[None]" diff --git a/psycopg/psycopg/_compat.py b/psycopg/psycopg/_compat.py index 1e1130486..68d689a2d 100644 --- a/psycopg/psycopg/_compat.py +++ b/psycopg/psycopg/_compat.py @@ -18,9 +18,9 @@ else: cache = lru_cache(maxsize=None) if sys.version_info >= (3, 10): - from typing import TypeGuard + from typing import TypeGuard, TypeAlias else: - from typing_extensions import TypeGuard + from typing_extensions import TypeGuard, TypeAlias if sys.version_info >= (3, 11): from typing import LiteralString, Self @@ -37,6 +37,7 @@ __all__ = [ "Deque", "LiteralString", "Self", + "TypeAlias", "TypeGuard", "TypeVar", "ZoneInfo", diff --git a/psycopg/psycopg/_connection_base.py b/psycopg/psycopg/_connection_base.py index 4dc695ce6..39e00002b 100644 --- a/psycopg/psycopg/_connection_base.py +++ b/psycopg/psycopg/_connection_base.py @@ -11,7 +11,6 @@ from typing import TYPE_CHECKING from weakref import ref, ReferenceType from warnings import warn from functools import partial -from typing_extensions import TypeAlias from . import pq from . import errors as e @@ -23,7 +22,7 @@ from ._tpc import Xid from .rows import Row from .adapt import AdaptersMap from ._enums import IsolationLevel -from ._compat import LiteralString, Self, TypeVar +from ._compat import LiteralString, Self, TypeAlias, TypeVar from .pq.misc import connection_summary from ._pipeline import BasePipeline from ._encodings import pgconn_encoding diff --git a/psycopg/psycopg/_conninfo_attempts.py b/psycopg/psycopg/_conninfo_attempts.py index 4fc0f792a..6f64f4ba1 100644 --- a/psycopg/psycopg/_conninfo_attempts.py +++ b/psycopg/psycopg/_conninfo_attempts.py @@ -14,14 +14,15 @@ import logging from random import shuffle from . import errors as e -from ._conninfo_utils import ConnDict, get_param, is_ip_address, get_param_def +from .abc import ConnDict, ConnMapping +from ._conninfo_utils import get_param, is_ip_address, get_param_def from ._conninfo_utils import split_attempts logger = logging.getLogger("psycopg") -def conninfo_attempts(params: ConnDict) -> list[ConnDict]: +def conninfo_attempts(params: ConnMapping) -> list[ConnDict]: """Split a set of connection params into the single attempts to perform.
A connection param can perform more than one attempt if more than one ``host`` diff --git a/psycopg/psycopg/_conninfo_attempts_async.py b/psycopg/psycopg/_conninfo_attempts_async.py index 6aca4ee3a..a549081e9 100644 --- a/psycopg/psycopg/_conninfo_attempts_async.py +++ b/psycopg/psycopg/_conninfo_attempts_async.py @@ -11,7 +11,8 @@ import logging from random import shuffle from . import errors as e -from ._conninfo_utils import ConnDict, get_param, is_ip_address, get_param_def +from .abc import ConnDict, ConnMapping +from ._conninfo_utils import get_param, is_ip_address, get_param_def from ._conninfo_utils import split_attempts if True: # ASYNC: @@ -20,7 +21,7 @@ if True: # ASYNC: logger = logging.getLogger("psycopg") -async def conninfo_attempts_async(params: ConnDict) -> list[ConnDict]: +async def conninfo_attempts_async(params: ConnMapping) -> list[ConnDict]: """Split a set of connection params into the single attempts to perform. A connection param can perform more than one attempt if more than one ``host`` diff --git a/psycopg/psycopg/_conninfo_utils.py b/psycopg/psycopg/_conninfo_utils.py index 8940c937b..a342987a0 100644 --- a/psycopg/psycopg/_conninfo_utils.py +++ b/psycopg/psycopg/_conninfo_utils.py @@ -7,19 +7,20 @@ Internal utilities to manipulate connection strings from __future__ import annotations import os -from typing import Any +from typing import TYPE_CHECKING from functools import lru_cache from ipaddress import ip_address from dataclasses import dataclass -from typing_extensions import TypeAlias from . import pq +from .abc import ConnDict, ConnMapping from . import errors as e -ConnDict: TypeAlias = "dict[str, Any]" +if TYPE_CHECKING: + from typing import Any # noqa: F401 -def split_attempts(params: ConnDict) -> list[ConnDict]: +def split_attempts(params: ConnMapping) -> list[ConnDict]: """ Split connection parameters with a sequence of hosts into separate attempts. """ @@ -47,7 +48,7 @@ def split_attempts(params: ConnDict) -> list[ConnDict]: # A single attempt to make. Don't mangle the conninfo string. if nhosts <= 1: - return [params] + return [{**params}] if len(ports) == 1: ports *= nhosts @@ -55,7 +56,7 @@ def split_attempts(params: ConnDict) -> list[ConnDict]: # Now all lists are either empty or have the same length rv = [] for i in range(nhosts): - attempt = params.copy() + attempt = {**params} if hosts: attempt["host"] = hosts[i] if hostaddrs: @@ -67,7 +68,7 @@ def split_attempts(params: ConnDict) -> list[ConnDict]: return rv -def get_param(params: ConnDict, name: str) -> str | None: +def get_param(params: ConnMapping, name: str) -> str | None: """ Return a value from a connection string. diff --git a/psycopg/psycopg/_copy_base.py b/psycopg/psycopg/_copy_base.py index 140744ff1..9194b266b 100644 --- a/psycopg/psycopg/_copy_base.py +++ b/psycopg/psycopg/_copy_base.py @@ -210,20 +210,16 @@ class Formatter(ABC): self._row_mode = False # true if the user is using write_row() @abstractmethod - def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]: - ... + def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]: ... @abstractmethod - def write(self, buffer: Union[Buffer, str]) -> Buffer: - ... + def write(self, buffer: Union[Buffer, str]) -> Buffer: ... @abstractmethod - def write_row(self, row: Sequence[Any]) -> Buffer: - ... + def write_row(self, row: Sequence[Any]) -> Buffer: ... @abstractmethod - def end(self) -> Buffer: - ... + def end(self) -> Buffer: ...
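The hunk above, like many later in this commit, collapses `...` stub bodies onto the `def` line. This matches the "dummy implementations" formatting introduced by Black's 2024 stable style; flake8 reports that layout as E704 ("statement on same line as def"), which is likely why E704 is added to the ignore lists of both `.flake8` files in this commit. A minimal sketch of the two equivalent spellings (illustrative only, not part of the patch):

    from abc import ABC, abstractmethod

    class Sketch(ABC):
        # Old style: the ellipsis body on its own line.
        @abstractmethod
        def write(self, buffer: bytes) -> bytes:
            ...

        # New style: same meaning, one line; flagged as E704 by flake8.
        @abstractmethod
        def end(self) -> bytes: ...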
class TextFormatter(Formatter): diff --git a/psycopg/psycopg/_encodings.py b/psycopg/psycopg/_encodings.py index 238271093..4949e26c6 100644 --- a/psycopg/psycopg/_encodings.py +++ b/psycopg/psycopg/_encodings.py @@ -116,7 +116,7 @@ def conninfo_encoding(conninfo: str) -> str: pgenc = params.get("client_encoding") if pgenc: try: - return pg2pyenc(pgenc.encode()) + return pg2pyenc(str(pgenc).encode()) except NotSupportedError: pass diff --git a/psycopg/psycopg/_enums.py b/psycopg/psycopg/_enums.py index a7cb78df4..1975650c6 100644 --- a/psycopg/psycopg/_enums.py +++ b/psycopg/psycopg/_enums.py @@ -20,6 +20,7 @@ class Wait(IntEnum): class Ready(IntEnum): + NONE = 0 R = EVENT_READ W = EVENT_WRITE RW = EVENT_READ | EVENT_WRITE diff --git a/psycopg/psycopg/_pipeline.py b/psycopg/psycopg/_pipeline.py index 72ac97ddd..05d0beb64 100644 --- a/psycopg/psycopg/_pipeline.py +++ b/psycopg/psycopg/_pipeline.py @@ -7,12 +7,11 @@ commands pipeline management import logging from types import TracebackType from typing import Any, List, Optional, Union, Tuple, Type, TYPE_CHECKING -from typing_extensions import TypeAlias from . import pq from . import errors as e from .abc import PipelineCommand, PQGen -from ._compat import Deque, Self +from ._compat import Deque, Self, TypeAlias from .pq.misc import connection_summary from ._encodings import pgconn_encoding from ._preparing import Key, Prepare @@ -133,8 +132,7 @@ class BasePipeline: self._enqueue_sync() yield from self._communicate_gen() finally: - # No need to force flush since we emitted a sync just before. - yield from self._fetch_gen(flush=False) + yield from self._fetch_gen(flush=True) def _communicate_gen(self) -> PQGen[None]: """Communicate with pipeline to send commands and possibly fetch diff --git a/psycopg/psycopg/_preparing.py b/psycopg/psycopg/_preparing.py index 158552ba5..465de53a4 100644 --- a/psycopg/psycopg/_preparing.py +++ b/psycopg/psycopg/_preparing.py @@ -7,10 +7,9 @@ Support for prepared statements from enum import IntEnum, auto from typing import Iterator, Optional, Sequence, Tuple, TYPE_CHECKING from collections import OrderedDict -from typing_extensions import TypeAlias from . import pq -from ._compat import Deque +from ._compat import Deque, TypeAlias from ._queries import PostgresQuery if TYPE_CHECKING: diff --git a/psycopg/psycopg/_py_transformer.py b/psycopg/psycopg/_py_transformer.py index 17f21c079..dd7f54759 100644 --- a/psycopg/psycopg/_py_transformer.py +++ b/psycopg/psycopg/_py_transformer.py @@ -12,7 +12,6 @@ dependencies problems). from typing import Any, Dict, List, Optional, Sequence, Tuple from typing import DefaultDict, TYPE_CHECKING from collections import defaultdict -from typing_extensions import TypeAlias from . import pq from . import abc @@ -20,6 +19,7 @@ from . import errors as e from .abc import Buffer, LoadFunc, AdaptContext, PyFormat, DumperKey, NoneType from .rows import Row, RowMaker from ._oids import INVALID_OID, TEXT_OID +from ._compat import TypeAlias from ._encodings import conn_encoding if TYPE_CHECKING: diff --git a/psycopg/psycopg/_queries.py b/psycopg/psycopg/_queries.py index 376012aec..dc2e5a67e 100644 --- a/psycopg/psycopg/_queries.py +++ b/psycopg/psycopg/_queries.py @@ -8,14 +8,13 @@ import re from typing import Any, Callable, Dict, List, Mapping, Match, NamedTuple, Optional from typing import Sequence, Tuple, Union, TYPE_CHECKING from functools import lru_cache -from typing_extensions import TypeAlias from . import pq from . 
import errors as e from .sql import Composable from .abc import Buffer, Query, Params from ._enums import PyFormat -from ._compat import TypeGuard +from ._compat import TypeAlias, TypeGuard from ._encodings import conn_encoding if TYPE_CHECKING: diff --git a/psycopg/psycopg/_struct.py b/psycopg/psycopg/_struct.py index bce427c80..7232a20bd 100644 --- a/psycopg/psycopg/_struct.py +++ b/psycopg/psycopg/_struct.py @@ -6,10 +6,10 @@ Utility functions to deal with binary structs. import struct from typing import Callable, cast, Optional, Protocol, Tuple -from typing_extensions import TypeAlias -from .abc import Buffer from . import errors as e +from .abc import Buffer +from ._compat import TypeAlias PackInt: TypeAlias = Callable[[int], bytes] UnpackInt: TypeAlias = Callable[[Buffer], Tuple[int]] @@ -18,8 +18,7 @@ UnpackFloat: TypeAlias = Callable[[Buffer], Tuple[float]] class UnpackLen(Protocol): - def __call__(self, data: Buffer, start: Optional[int]) -> Tuple[int]: - ... + def __call__(self, data: Buffer, start: Optional[int]) -> Tuple[int]: ... pack_int2 = cast(PackInt, struct.Struct("!h").pack) diff --git a/psycopg/psycopg/_typeinfo.py b/psycopg/psycopg/_typeinfo.py index bfa740ff9..fc170492a 100644 --- a/psycopg/psycopg/_typeinfo.py +++ b/psycopg/psycopg/_typeinfo.py @@ -9,13 +9,12 @@ information to the adapters if needed. from typing import Any, Dict, Iterator, Optional, overload from typing import Sequence, Tuple, Type, Union, TYPE_CHECKING -from typing_extensions import TypeAlias from . import sql from . import errors as e from .abc import AdaptContext, Query from .rows import dict_row -from ._compat import TypeVar +from ._compat import TypeAlias, TypeVar from ._encodings import conn_encoding if TYPE_CHECKING: @@ -59,15 +58,13 @@ class TypeInfo: @classmethod def fetch( cls: Type[T], conn: "Connection[Any]", name: Union[str, sql.Identifier] - ) -> Optional[T]: - ... + ) -> Optional[T]: ... @overload @classmethod async def fetch( cls: Type[T], conn: "AsyncConnection[Any]", name: Union[str, sql.Identifier] - ) -> Optional[T]: - ... + ) -> Optional[T]: ... @classmethod def fetch( @@ -239,12 +236,10 @@ class TypesRegistry: yield t @overload - def __getitem__(self, key: Union[str, int]) -> TypeInfo: - ... + def __getitem__(self, key: Union[str, int]) -> TypeInfo: ... @overload - def __getitem__(self, key: Tuple[Type[T], int]) -> T: - ... + def __getitem__(self, key: Tuple[Type[T], int]) -> T: ... def __getitem__(self, key: RegistryKey) -> TypeInfo: """ @@ -265,12 +260,10 @@ class TypesRegistry: raise KeyError(f"couldn't find the type {key!r} in the types registry") @overload - def get(self, key: Union[str, int]) -> Optional[TypeInfo]: - ... + def get(self, key: Union[str, int]) -> Optional[TypeInfo]: ... @overload - def get(self, key: Tuple[Type[T], int]) -> Optional[T]: - ... + def get(self, key: Tuple[Type[T], int]) -> Optional[T]: ... def get(self, key: RegistryKey) -> Optional[TypeInfo]: """ diff --git a/psycopg/psycopg/abc.py b/psycopg/psycopg/abc.py index 0952e8d0b..ad4a96646 100644 --- a/psycopg/psycopg/abc.py +++ b/psycopg/psycopg/abc.py @@ -4,14 +4,13 @@ Protocol objects representing different implementations of the same classes. # Copyright (C) 2020 The Psycopg Team -from typing import Any, Callable, Generator, Mapping +from typing import Any, Dict, Callable, Generator, Mapping from typing import List, Optional, Protocol, Sequence, Tuple, Union from typing import TYPE_CHECKING -from typing_extensions import TypeAlias from . 
import pq from ._enums import PyFormat as PyFormat -from ._compat import LiteralString, TypeVar +from ._compat import LiteralString, TypeAlias, TypeVar if TYPE_CHECKING: from . import sql @@ -31,18 +30,22 @@ Params: TypeAlias = Union[Sequence[Any], Mapping[str, Any]] ConnectionType = TypeVar("ConnectionType", bound="BaseConnection[Any]") PipelineCommand: TypeAlias = Callable[[], None] DumperKey: TypeAlias = Union[type, Tuple["DumperKey", ...]] +ConnParam: TypeAlias = Union[str, int, None] +ConnDict: TypeAlias = Dict[str, ConnParam] +ConnMapping: TypeAlias = Mapping[str, ConnParam] + # Waiting protocol types RV = TypeVar("RV") -PQGenConn: TypeAlias = Generator[Tuple[int, "Wait"], "Ready", RV] +PQGenConn: TypeAlias = Generator[Tuple[int, "Wait"], Union["Ready", int], RV] """Generator for processes where the connection file number can change. This can happen in connection and reset, but not in normal querying. """ -PQGen: TypeAlias = Generator["Wait", "Ready", RV] +PQGen: TypeAlias = Generator["Wait", Union["Ready", int], RV] """Generator for processes where the connection file number won't change. """ @@ -54,8 +57,7 @@ class WaitFunc(Protocol): def __call__( self, gen: PQGen[RV], fileno: int, timeout: Optional[float] = None - ) -> RV: - ... + ) -> RV: ... # Adaptation types @@ -106,8 +108,7 @@ class Dumper(Protocol): oid: int """The oid to pass to the server, if known; 0 otherwise (class attribute).""" - def __init__(self, cls: type, context: Optional[AdaptContext] = None): - ... + def __init__(self, cls: type, context: Optional[AdaptContext] = None): ... def dump(self, obj: Any) -> Buffer: """Convert the object `!obj` to PostgreSQL representation. @@ -187,8 +188,7 @@ class Loader(Protocol): This is a class attribute. """ - def __init__(self, oid: int, context: Optional[AdaptContext] = None): - ... + def __init__(self, oid: int, context: Optional[AdaptContext] = None): ... def load(self, data: Buffer) -> Any: """ @@ -203,28 +203,22 @@ class Transformer(Protocol): types: Optional[Tuple[int, ...]] formats: Optional[List[pq.Format]] - def __init__(self, context: Optional[AdaptContext] = None): - ... + def __init__(self, context: Optional[AdaptContext] = None): ... @classmethod - def from_context(cls, context: Optional[AdaptContext]) -> "Transformer": - ... + def from_context(cls, context: Optional[AdaptContext]) -> "Transformer": ... @property - def connection(self) -> Optional["BaseConnection[Any]"]: - ... + def connection(self) -> Optional["BaseConnection[Any]"]: ... @property - def encoding(self) -> str: - ... + def encoding(self) -> str: ... @property - def adapters(self) -> "AdaptersMap": - ... + def adapters(self) -> "AdaptersMap": ... @property - def pgresult(self) -> Optional["PGresult"]: - ... + def pgresult(self) -> Optional["PGresult"]: ... def set_pgresult( self, @@ -232,34 +226,26 @@ class Transformer(Protocol): *, set_loaders: bool = True, format: Optional[pq.Format] = None - ) -> None: - ... + ) -> None: ... - def set_dumper_types(self, types: Sequence[int], format: pq.Format) -> None: - ... + def set_dumper_types(self, types: Sequence[int], format: pq.Format) -> None: ... - def set_loader_types(self, types: Sequence[int], format: pq.Format) -> None: - ... + def set_loader_types(self, types: Sequence[int], format: pq.Format) -> None: ... def dump_sequence( self, params: Sequence[Any], formats: Sequence[PyFormat] - ) -> Sequence[Optional[Buffer]]: - ... + ) -> Sequence[Optional[Buffer]]: ... - def as_literal(self, obj: Any) -> bytes: - ... 
+ def as_literal(self, obj: Any) -> bytes: ... - def get_dumper(self, obj: Any, format: PyFormat) -> Dumper: - ... + def get_dumper(self, obj: Any, format: PyFormat) -> Dumper: ... - def load_rows(self, row0: int, row1: int, make_row: "RowMaker[Row]") -> List["Row"]: - ... + def load_rows( + self, row0: int, row1: int, make_row: "RowMaker[Row]" + ) -> List["Row"]: ... - def load_row(self, row: int, make_row: "RowMaker[Row]") -> Optional["Row"]: - ... + def load_row(self, row: int, make_row: "RowMaker[Row]") -> Optional["Row"]: ... - def load_sequence(self, record: Sequence[Optional[Buffer]]) -> Tuple[Any, ...]: - ... + def load_sequence(self, record: Sequence[Optional[Buffer]]) -> Tuple[Any, ...]: ... - def get_loader(self, oid: int, format: pq.Format) -> Loader: - ... + def get_loader(self, oid: int, format: pq.Format) -> Loader: ... diff --git a/psycopg/psycopg/adapt.py b/psycopg/psycopg/adapt.py index 31a710429..7d6a191d8 100644 --- a/psycopg/psycopg/adapt.py +++ b/psycopg/psycopg/adapt.py @@ -46,8 +46,7 @@ class Dumper(abc.Dumper, ABC): ) @abstractmethod - def dump(self, obj: Any) -> Buffer: - ... + def dump(self, obj: Any) -> Buffer: ... def quote(self, obj: Any) -> Buffer: """ diff --git a/psycopg/psycopg/connection.py b/psycopg/psycopg/connection.py index dc02ce381..cb0244aa5 100644 --- a/psycopg/psycopg/connection.py +++ b/psycopg/psycopg/connection.py @@ -10,6 +10,7 @@ Psycopg connection object (sync version) from __future__ import annotations import logging +from time import monotonic from types import TracebackType from typing import Any, Generator, Iterator, List, Optional from typing import Type, Union, cast, overload, TYPE_CHECKING @@ -18,13 +19,13 @@ from contextlib import contextmanager from . import pq from . import errors as e from . import waiting -from .abc import AdaptContext, Params, PQGen, PQGenConn, Query, RV +from .abc import AdaptContext, ConnDict, ConnParam, Params, PQGen, PQGenConn, Query, RV from ._tpc import Xid from .rows import Row, RowFactory, tuple_row, args_row from .adapt import AdaptersMap from ._enums import IsolationLevel from ._compat import Self -from .conninfo import ConnDict, make_conninfo, conninfo_to_dict +from .conninfo import make_conninfo, conninfo_to_dict from .conninfo import conninfo_attempts, timeout_from_conninfo from ._pipeline import Pipeline from ._encodings import pgconn_encoding @@ -39,10 +40,13 @@ from threading import Lock if TYPE_CHECKING: from .pq.abc import PGconn +_WAIT_INTERVAL = 0.1 + TEXT = pq.Format.TEXT BINARY = pq.Format.BINARY IDLE = pq.TransactionStatus.IDLE +ACTIVE = pq.TransactionStatus.ACTIVE INTRANS = pq.TransactionStatus.INTRANS _INTERRUPTED = KeyboardInterrupt @@ -83,7 +87,7 @@ class Connection(BaseConnection[Row]): context: Optional[AdaptContext] = None, row_factory: Optional[RowFactory[Row]] = None, cursor_factory: Optional[Type[Cursor[Row]]] = None, - **kwargs: Any, + **kwargs: ConnParam, ) -> Self: """ Connect to a database server and return a new `Connection` instance. @@ -95,7 +99,7 @@ class Connection(BaseConnection[Row]): attempts = conninfo_attempts(params) for attempt in attempts: try: - conninfo = make_conninfo(**attempt) + conninfo = make_conninfo("", **attempt) rv = cls._wait_conn(cls._connect_gen(conninfo), timeout=timeout) break except e._NO_TRACEBACK as ex: @@ -165,14 +169,12 @@ class Connection(BaseConnection[Row]): self.pgconn.finish() @overload - def cursor(self, *, binary: bool = False) -> Cursor[Row]: - ... + def cursor(self, *, binary: bool = False) -> Cursor[Row]: ... 
@overload def cursor( self, *, binary: bool = False, row_factory: RowFactory[CursorRow] - ) -> Cursor[CursorRow]: - ... + ) -> Cursor[CursorRow]: ... @overload def cursor( @@ -182,8 +184,7 @@ class Connection(BaseConnection[Row]): binary: bool = False, scrollable: Optional[bool] = None, withhold: bool = False, - ) -> ServerCursor[Row]: - ... + ) -> ServerCursor[Row]: ... @overload def cursor( @@ -194,8 +195,7 @@ class Connection(BaseConnection[Row]): row_factory: RowFactory[CursorRow], scrollable: Optional[bool] = None, withhold: bool = False, - ) -> ServerCursor[CursorRow]: - ... + ) -> ServerCursor[CursorRow]: ... def cursor( self, @@ -280,20 +280,56 @@ class Connection(BaseConnection[Row]): with tx: yield tx - def notifies(self) -> Generator[Notify, None, None]: + def notifies( + self, *, timeout: Optional[float] = None, stop_after: Optional[int] = None + ) -> Generator[Notify, None, None]: """ Yield `Notify` objects as soon as they are received from the database. + + :param timeout: maximum amount of time to wait for notifications. + `!None` means no timeout. + :param stop_after: stop after receiving this number of notifications. + You might actually receive more than this number if more than one + notification arrives in the same packet. """ + # Allow interrupting the wait with a signal by reducing a long timeout + # into shorter intervals. + if timeout is not None: + deadline = monotonic() + timeout + timeout = min(timeout, _WAIT_INTERVAL) + else: + deadline = None + timeout = _WAIT_INTERVAL + + nreceived = 0 + while True: - with self.lock: - try: - ns = self.wait(notifies(self.pgconn)) - except e._NO_TRACEBACK as ex: - raise ex.with_traceback(None) - enc = pgconn_encoding(self.pgconn) + # Collect notifications. Also get the connection encoding if any + # notification is received to make sure that they are consistent. + try: + with self.lock: + ns = self.wait(notifies(self.pgconn), timeout=timeout) + if ns: + enc = pgconn_encoding(self.pgconn) + except e._NO_TRACEBACK as ex: + raise ex.with_traceback(None) + + # Emit the notifications received. for pgn in ns: n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid) yield n + nreceived += 1 + + # Stop if we have received enough notifications. + if stop_after is not None and nreceived >= stop_after: + break + + # Check the deadline after the loop to ensure that timeout=0 + # polls at least once. + if deadline: + timeout = min(_WAIT_INTERVAL, deadline - monotonic()) + if timeout < 0.0: + break @contextmanager def pipeline(self) -> Iterator[Pipeline]: @@ -315,7 +351,7 @@ class Connection(BaseConnection[Row]): assert pipeline is self._pipeline self._pipeline = None - def wait(self, gen: PQGen[RV], timeout: Optional[float] = 0.1) -> RV: + def wait(self, gen: PQGen[RV], timeout: Optional[float] = _WAIT_INTERVAL) -> RV: """ Consume a generator operating on the connection. @@ -325,13 +361,14 @@ class Connection(BaseConnection[Row]): try: return waiting.wait(gen, self.pgconn.socket, timeout=timeout) except _INTERRUPTED: - # On Ctrl-C, try to cancel the query in the server, otherwise - # the connection will remain stuck in ACTIVE state. - self._try_cancel(self.pgconn) - try: - waiting.wait(gen, self.pgconn.socket, timeout=timeout) - except e.QueryCanceled: - pass # as expected + if self.pgconn.transaction_status == ACTIVE: + # On Ctrl-C, try to cancel the query in the server, otherwise + # the connection will remain stuck in ACTIVE state.
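# A usage sketch for the notifies() parameters added above (illustrative
# only, not part of the patch; the connection string is hypothetical):
#
#     import psycopg
#
#     with psycopg.connect("dbname=test", autocommit=True) as conn:
#         conn.execute("LISTEN mychan")
#         # Wait at most 5 seconds overall and stop after 10 notifications,
#         # whichever comes first; timeout=None (the default) waits forever.
#         for n in conn.notifies(timeout=5.0, stop_after=10):
#             print(n.channel, n.payload, n.pid)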
+ self._try_cancel(self.pgconn) + try: + waiting.wait(gen, self.pgconn.socket, timeout=timeout) + except e.QueryCanceled: + pass # as expected raise @classmethod diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py index 46269fda4..d810d45b2 100644 --- a/psycopg/psycopg/connection_async.py +++ b/psycopg/psycopg/connection_async.py @@ -7,6 +7,7 @@ Psycopg connection object (async version) from __future__ import annotations import logging +from time import monotonic from types import TracebackType from typing import Any, AsyncGenerator, AsyncIterator, List, Optional from typing import Type, Union, cast, overload, TYPE_CHECKING @@ -15,13 +16,13 @@ from contextlib import asynccontextmanager from . import pq from . import errors as e from . import waiting -from .abc import AdaptContext, Params, PQGen, PQGenConn, Query, RV +from .abc import AdaptContext, ConnDict, ConnParam, Params, PQGen, PQGenConn, Query, RV from ._tpc import Xid from .rows import Row, AsyncRowFactory, tuple_row, args_row from .adapt import AdaptersMap from ._enums import IsolationLevel from ._compat import Self -from .conninfo import ConnDict, make_conninfo, conninfo_to_dict +from .conninfo import make_conninfo, conninfo_to_dict from .conninfo import conninfo_attempts_async, timeout_from_conninfo from ._pipeline import AsyncPipeline from ._encodings import pgconn_encoding @@ -41,10 +42,13 @@ else: if TYPE_CHECKING: from .pq.abc import PGconn +_WAIT_INTERVAL = 0.1 + TEXT = pq.Format.TEXT BINARY = pq.Format.BINARY IDLE = pq.TransactionStatus.IDLE +ACTIVE = pq.TransactionStatus.ACTIVE INTRANS = pq.TransactionStatus.INTRANS if True: # ASYNC @@ -88,7 +92,7 @@ class AsyncConnection(BaseConnection[Row]): context: Optional[AdaptContext] = None, row_factory: Optional[AsyncRowFactory[Row]] = None, cursor_factory: Optional[Type[AsyncCursor[Row]]] = None, - **kwargs: Any, + **kwargs: ConnParam, ) -> Self: """ Connect to a database server and return a new `AsyncConnection` instance. @@ -110,7 +114,7 @@ class AsyncConnection(BaseConnection[Row]): attempts = await conninfo_attempts_async(params) for attempt in attempts: try: - conninfo = make_conninfo(**attempt) + conninfo = make_conninfo("", **attempt) rv = await cls._wait_conn(cls._connect_gen(conninfo), timeout=timeout) break except e._NO_TRACEBACK as ex: @@ -180,14 +184,12 @@ class AsyncConnection(BaseConnection[Row]): self.pgconn.finish() @overload - def cursor(self, *, binary: bool = False) -> AsyncCursor[Row]: - ... + def cursor(self, *, binary: bool = False) -> AsyncCursor[Row]: ... @overload def cursor( self, *, binary: bool = False, row_factory: AsyncRowFactory[CursorRow] - ) -> AsyncCursor[CursorRow]: - ... + ) -> AsyncCursor[CursorRow]: ... @overload def cursor( @@ -197,8 +199,7 @@ class AsyncConnection(BaseConnection[Row]): binary: bool = False, scrollable: Optional[bool] = None, withhold: bool = False, - ) -> AsyncServerCursor[Row]: - ... + ) -> AsyncServerCursor[Row]: ... @overload def cursor( @@ -209,8 +210,7 @@ class AsyncConnection(BaseConnection[Row]): row_factory: AsyncRowFactory[CursorRow], scrollable: Optional[bool] = None, withhold: bool = False, - ) -> AsyncServerCursor[CursorRow]: - ... + ) -> AsyncServerCursor[CursorRow]: ... 
def cursor( self, @@ -296,20 +296,56 @@ class AsyncConnection(BaseConnection[Row]): async with tx: yield tx - async def notifies(self) -> AsyncGenerator[Notify, None]: + async def notifies( + self, *, timeout: Optional[float] = None, stop_after: Optional[int] = None + ) -> AsyncGenerator[Notify, None]: """ Yield `Notify` objects as soon as they are received from the database. + + :param timeout: maximum amount of time to wait for notifications. + `!None` means no timeout. + :param stop_after: stop after receiving this number of notifications. + You might actually receive more than this number if more than one + notification arrives in the same packet. """ + # Allow interrupting the wait with a signal by reducing a long timeout + # into shorter intervals. + if timeout is not None: + deadline = monotonic() + timeout + timeout = min(timeout, _WAIT_INTERVAL) + else: + deadline = None + timeout = _WAIT_INTERVAL + + nreceived = 0 + while True: - async with self.lock: - try: - ns = await self.wait(notifies(self.pgconn)) - except e._NO_TRACEBACK as ex: - raise ex.with_traceback(None) - enc = pgconn_encoding(self.pgconn) + # Collect notifications. Also get the connection encoding if any + # notification is received to make sure that they are consistent. + try: + async with self.lock: + ns = await self.wait(notifies(self.pgconn), timeout=timeout) + if ns: + enc = pgconn_encoding(self.pgconn) + except e._NO_TRACEBACK as ex: + raise ex.with_traceback(None) + + # Emit the notifications received. for pgn in ns: n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid) yield n + nreceived += 1 + + # Stop if we have received enough notifications. + if stop_after is not None and nreceived >= stop_after: + break + + # Check the deadline after the loop to ensure that timeout=0 + # polls at least once. + if deadline: + timeout = min(_WAIT_INTERVAL, deadline - monotonic()) + if timeout < 0.0: + break @asynccontextmanager async def pipeline(self) -> AsyncIterator[AsyncPipeline]: @@ -331,7 +367,9 @@ class AsyncConnection(BaseConnection[Row]): assert pipeline is self._pipeline self._pipeline = None - async def wait(self, gen: PQGen[RV], timeout: Optional[float] = 0.1) -> RV: + async def wait( + self, gen: PQGen[RV], timeout: Optional[float] = _WAIT_INTERVAL + ) -> RV: """ Consume a generator operating on the connection. @@ -341,13 +379,14 @@ class AsyncConnection(BaseConnection[Row]): try: return await waiting.wait_async(gen, self.pgconn.socket, timeout=timeout) except _INTERRUPTED: - # On Ctrl-C, try to cancel the query in the server, otherwise - # the connection will remain stuck in ACTIVE state. - self._try_cancel(self.pgconn) - try: - await waiting.wait_async(gen, self.pgconn.socket, timeout=timeout) - except e.QueryCanceled: - pass # as expected + if self.pgconn.transaction_status == ACTIVE: + # On Ctrl-C, try to cancel the query in the server, otherwise + # the connection will remain stuck in ACTIVE state. + self._try_cancel(self.pgconn) + try: + await waiting.wait_async(gen, self.pgconn.socket, timeout=timeout) + except e.QueryCanceled: + pass # as expected raise @classmethod diff --git a/psycopg/psycopg/conninfo.py b/psycopg/psycopg/conninfo.py index 82da58822..1401426b2 100644 --- a/psycopg/psycopg/conninfo.py +++ b/psycopg/psycopg/conninfo.py @@ -7,17 +7,15 @@ Functions to manipulate conninfo strings from __future__ import annotations import re -from typing import Any from . import pq from . import errors as e - from . import _conninfo_utils from . import _conninfo_attempts from . 
import _conninfo_attempts_async +from .abc import ConnParam, ConnDict # re-exports -ConnDict = _conninfo_utils.ConnDict conninfo_attempts = _conninfo_attempts.conninfo_attempts conninfo_attempts_async = _conninfo_attempts_async.conninfo_attempts_async @@ -27,7 +25,7 @@ conninfo_attempts_async = _conninfo_attempts_async.conninfo_attempts_async _DEFAULT_CONNECT_TIMEOUT = 130 -def make_conninfo(conninfo: str = "", **kwargs: Any) -> str: +def make_conninfo(conninfo: str = "", **kwargs: ConnParam) -> str: """ Merge a string and keyword params into a single conninfo string. @@ -68,7 +66,7 @@ def make_conninfo(conninfo: str = "", **kwargs: Any) -> str: return conninfo -def conninfo_to_dict(conninfo: str = "", **kwargs: Any) -> ConnDict: +def conninfo_to_dict(conninfo: str = "", **kwargs: ConnParam) -> ConnDict: """ Convert the `!conninfo` string into a dictionary of parameters. @@ -84,7 +82,9 @@ def conninfo_to_dict(conninfo: str = "", **kwargs: Any) -> ConnDict: #LIBPQ-CONNSTRING """ opts = _parse_conninfo(conninfo) - rv = {opt.keyword.decode(): opt.val.decode() for opt in opts if opt.val is not None} + rv: ConnDict = { + opt.keyword.decode(): opt.val.decode() for opt in opts if opt.val is not None + } for k, v in kwargs.items(): if v is not None: rv[k] = v diff --git a/psycopg/psycopg/cursor.py b/psycopg/psycopg/cursor.py index 10741c95f..6b48929bc 100644 --- a/psycopg/psycopg/cursor.py +++ b/psycopg/psycopg/cursor.py @@ -34,12 +34,12 @@ class Cursor(BaseCursor["Connection[Any]", Row]): __slots__ = () @overload - def __init__(self, connection: Connection[Row]): - ... + def __init__(self, connection: Connection[Row]): ... @overload - def __init__(self, connection: Connection[Any], *, row_factory: RowFactory[Row]): - ... + def __init__( + self, connection: Connection[Any], *, row_factory: RowFactory[Row] + ): ... def __init__( self, diff --git a/psycopg/psycopg/cursor_async.py b/psycopg/psycopg/cursor_async.py index 603560155..55dc9a5c2 100644 --- a/psycopg/psycopg/cursor_async.py +++ b/psycopg/psycopg/cursor_async.py @@ -31,14 +31,12 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]): __slots__ = () @overload - def __init__(self, connection: AsyncConnection[Row]): - ... + def __init__(self, connection: AsyncConnection[Row]): ... @overload def __init__( self, connection: AsyncConnection[Any], *, row_factory: AsyncRowFactory[Row] - ): - ... + ): ... def __init__( self, diff --git a/psycopg/psycopg/errors.py b/psycopg/psycopg/errors.py index d2cd81207..d2e2a955e 100644 --- a/psycopg/psycopg/errors.py +++ b/psycopg/psycopg/errors.py @@ -21,12 +21,11 @@ DBAPI-defined Exceptions are defined in the following hierarchy:: from dataclasses import dataclass, field, fields from typing import Any, Callable, Dict, List, NoReturn, Optional, Sequence, Tuple, Type from typing import Union, TYPE_CHECKING -from typing_extensions import TypeAlias from asyncio import CancelledError from .pq.abc import PGconn, PGresult from .pq._enums import ConnStatus, DiagnosticField, PipelineStatus, TransactionStatus -from ._compat import TypeGuard +from ._compat import TypeAlias, TypeGuard if TYPE_CHECKING: from .pq.misc import PGnotify, ConninfoOption diff --git a/psycopg/psycopg/generators.py b/psycopg/psycopg/generators.py index 4f2ec878b..2e463196e 100644 --- a/psycopg/psycopg/generators.py +++ b/psycopg/psycopg/generators.py @@ -7,10 +7,15 @@ the operations, yielding a polling state whenever there is a need to wait.
The functions in the `waiting` module are the ones that wait more or less cooperatively for the socket to be ready and make these generators continue. -All these generators yield pairs (fileno, `Wait`) whenever an operation would -block. The generator can be restarted sending the appropriate `Ready` state -when the file descriptor is ready. - +These generators yield `Wait` objects whenever an operation would block. These +generators assume the connection fileno will not change. In the case of the +connection function, where the fileno may change, the generators yield pairs +(fileno, `Wait`). + +The generator can be restarted by sending the appropriate `Ready` state when +the file descriptor is ready. If a None value is sent, it means that the wait +function timed out without any file descriptor becoming ready; in this case the +generator should probably yield the same value again in order to wait more. """ # Copyright (C) 2020 The Psycopg Team @@ -119,7 +124,11 @@ def _send(pgconn: PGconn) -> PQGen[None]: if f == 0: break - ready = yield WAIT_RW + while True: + ready = yield WAIT_RW + if ready: + break + if ready & READY_R: # This call may read notifies: they will be saved in the # PGconn buffer and passed to Python later, in `fetch()`. @@ -168,12 +177,19 @@ def _fetch(pgconn: PGconn) -> PQGen[Optional[PGresult]]: Return a result from the database (whether success or error). """ if pgconn.is_busy(): - yield WAIT_R + while True: + ready = yield WAIT_R + if ready: + break + while True: pgconn.consume_input() if not pgconn.is_busy(): break - yield WAIT_R + while True: + ready = yield WAIT_R + if ready: + break _consume_notifies(pgconn) @@ -191,7 +207,10 @@ def _pipeline_communicate( results = [] while True: - ready = yield WAIT_RW + while True: + ready = yield WAIT_RW + if ready: + break if ready & READY_R: pgconn.consume_input() @@ -263,7 +282,10 @@ def copy_from(pgconn: PGconn) -> PQGen[Union[memoryview, PGresult]]: break # would block - yield WAIT_R + while True: + ready = yield WAIT_R + if ready: + break pgconn.consume_input() if nbytes > 0: @@ -291,17 +313,26 @@ def copy_to(pgconn: PGconn, buffer: Buffer) -> PQGen[None]: # into smaller ones. We prefer to do it there instead of here in order to # do it upstream of the queue, decoupling the writer task from the producer one.
while pgconn.put_copy_data(buffer) == 0: - yield WAIT_W + while True: + ready = yield WAIT_W + if ready: + break def copy_end(pgconn: PGconn, error: Optional[bytes]) -> PQGen[PGresult]: # Retry enqueuing end copy message until successful while pgconn.put_copy_end(error) == 0: - yield WAIT_W + while True: + ready = yield WAIT_W + if ready: + break # Repeat until the message is flushed to the server while True: - yield WAIT_W + while True: + ready = yield WAIT_W + if ready: + break f = pgconn.flush() if f == 0: break diff --git a/psycopg/psycopg/pq/_pq_ctypes.py b/psycopg/psycopg/pq/_pq_ctypes.py index 9d4dd1814..1b0f391f2 100644 --- a/psycopg/psycopg/pq/_pq_ctypes.py +++ b/psycopg/psycopg/pq/_pq_ctypes.py @@ -29,7 +29,10 @@ FILE_ptr = POINTER(FILE) if sys.platform == "linux": libcname = ctypes.util.find_library("c") - assert libcname + if not libcname: + # Likely this is a system using musl libc, see the following bug: + # https://github.com/python/cpython/issues/65821 + libcname = "libc.so" libc = ctypes.cdll.LoadLibrary(libcname) fdopen = libc.fdopen diff --git a/psycopg/psycopg/pq/abc.py b/psycopg/psycopg/pq/abc.py index 3a76d56c0..13a077211 100644 --- a/psycopg/psycopg/pq/abc.py +++ b/psycopg/psycopg/pq/abc.py @@ -6,9 +6,9 @@ Protocol objects to represent objects exposed by different pq implementations. from typing import Any, Callable, List, Optional, Protocol, Sequence, Tuple from typing import Union, TYPE_CHECKING -from typing_extensions import TypeAlias from ._enums import Format, Trace +from .._compat import TypeAlias if TYPE_CHECKING: from .misc import PGnotify, ConninfoOption, PGresAttDesc @@ -22,112 +22,83 @@ class PGconn(Protocol): notify_handler: Optional[Callable[["PGnotify"], None]] @classmethod - def connect(cls, conninfo: bytes) -> "PGconn": - ... + def connect(cls, conninfo: bytes) -> "PGconn": ... @classmethod - def connect_start(cls, conninfo: bytes) -> "PGconn": - ... + def connect_start(cls, conninfo: bytes) -> "PGconn": ... - def connect_poll(self) -> int: - ... + def connect_poll(self) -> int: ... - def finish(self) -> None: - ... + def finish(self) -> None: ... @property - def info(self) -> List["ConninfoOption"]: - ... + def info(self) -> List["ConninfoOption"]: ... - def reset(self) -> None: - ... + def reset(self) -> None: ... - def reset_start(self) -> None: - ... + def reset_start(self) -> None: ... - def reset_poll(self) -> int: - ... + def reset_poll(self) -> int: ... @classmethod - def ping(self, conninfo: bytes) -> int: - ... + def ping(self, conninfo: bytes) -> int: ... @property - def db(self) -> bytes: - ... + def db(self) -> bytes: ... @property - def user(self) -> bytes: - ... + def user(self) -> bytes: ... @property - def password(self) -> bytes: - ... + def password(self) -> bytes: ... @property - def host(self) -> bytes: - ... + def host(self) -> bytes: ... @property - def hostaddr(self) -> bytes: - ... + def hostaddr(self) -> bytes: ... @property - def port(self) -> bytes: - ... + def port(self) -> bytes: ... @property - def tty(self) -> bytes: - ... + def tty(self) -> bytes: ... @property - def options(self) -> bytes: - ... + def options(self) -> bytes: ... @property - def status(self) -> int: - ... + def status(self) -> int: ... @property - def transaction_status(self) -> int: - ... + def transaction_status(self) -> int: ... - def parameter_status(self, name: bytes) -> Optional[bytes]: - ... + def parameter_status(self, name: bytes) -> Optional[bytes]: ... @property - def error_message(self) -> bytes: - ...
+ def error_message(self) -> bytes: ... @property - def server_version(self) -> int: - ... + def server_version(self) -> int: ... @property - def socket(self) -> int: - ... + def socket(self) -> int: ... @property - def backend_pid(self) -> int: - ... + def backend_pid(self) -> int: ... @property - def needs_password(self) -> bool: - ... + def needs_password(self) -> bool: ... @property - def used_password(self) -> bool: - ... + def used_password(self) -> bool: ... @property - def ssl_in_use(self) -> bool: - ... + def ssl_in_use(self) -> bool: ... - def exec_(self, command: bytes) -> "PGresult": - ... + def exec_(self, command: bytes) -> "PGresult": ... - def send_query(self, command: bytes) -> None: - ... + def send_query(self, command: bytes) -> None: ... def exec_params( self, @@ -136,8 +107,7 @@ class PGconn(Protocol): param_types: Optional[Sequence[int]] = None, param_formats: Optional[Sequence[int]] = None, result_format: int = Format.TEXT, - ) -> "PGresult": - ... + ) -> "PGresult": ... def send_query_params( self, @@ -146,16 +116,14 @@ class PGconn(Protocol): param_types: Optional[Sequence[int]] = None, param_formats: Optional[Sequence[int]] = None, result_format: int = Format.TEXT, - ) -> None: - ... + ) -> None: ... def send_prepare( self, name: bytes, command: bytes, param_types: Optional[Sequence[int]] = None, - ) -> None: - ... + ) -> None: ... def send_query_prepared( self, @@ -163,16 +131,14 @@ class PGconn(Protocol): param_values: Optional[Sequence[Optional[Buffer]]], param_formats: Optional[Sequence[int]] = None, result_format: int = Format.TEXT, - ) -> None: - ... + ) -> None: ... def prepare( self, name: bytes, command: bytes, param_types: Optional[Sequence[int]] = None, - ) -> "PGresult": - ... + ) -> "PGresult": ... def exec_prepared( self, @@ -180,216 +146,153 @@ class PGconn(Protocol): param_values: Optional[Sequence[Buffer]], param_formats: Optional[Sequence[int]] = None, result_format: int = 0, - ) -> "PGresult": - ... + ) -> "PGresult": ... - def describe_prepared(self, name: bytes) -> "PGresult": - ... + def describe_prepared(self, name: bytes) -> "PGresult": ... - def send_describe_prepared(self, name: bytes) -> None: - ... + def send_describe_prepared(self, name: bytes) -> None: ... - def describe_portal(self, name: bytes) -> "PGresult": - ... + def describe_portal(self, name: bytes) -> "PGresult": ... - def send_describe_portal(self, name: bytes) -> None: - ... + def send_describe_portal(self, name: bytes) -> None: ... - def close_prepared(self, name: bytes) -> "PGresult": - ... + def close_prepared(self, name: bytes) -> "PGresult": ... - def send_close_prepared(self, name: bytes) -> None: - ... + def send_close_prepared(self, name: bytes) -> None: ... - def close_portal(self, name: bytes) -> "PGresult": - ... + def close_portal(self, name: bytes) -> "PGresult": ... - def send_close_portal(self, name: bytes) -> None: - ... + def send_close_portal(self, name: bytes) -> None: ... - def get_result(self) -> Optional["PGresult"]: - ... + def get_result(self) -> Optional["PGresult"]: ... - def consume_input(self) -> None: - ... + def consume_input(self) -> None: ... - def is_busy(self) -> int: - ... + def is_busy(self) -> int: ... @property - def nonblocking(self) -> int: - ... + def nonblocking(self) -> int: ... @nonblocking.setter - def nonblocking(self, arg: int) -> None: - ... + def nonblocking(self, arg: int) -> None: ... - def flush(self) -> int: - ... + def flush(self) -> int: ... - def set_single_row_mode(self) -> None: - ... 
+ def set_single_row_mode(self) -> None: ... - def get_cancel(self) -> "PGcancel": - ... + def get_cancel(self) -> "PGcancel": ... - def notifies(self) -> Optional["PGnotify"]: - ... + def notifies(self) -> Optional["PGnotify"]: ... - def put_copy_data(self, buffer: Buffer) -> int: - ... + def put_copy_data(self, buffer: Buffer) -> int: ... - def put_copy_end(self, error: Optional[bytes] = None) -> int: - ... + def put_copy_end(self, error: Optional[bytes] = None) -> int: ... - def get_copy_data(self, async_: int) -> Tuple[int, memoryview]: - ... + def get_copy_data(self, async_: int) -> Tuple[int, memoryview]: ... - def trace(self, fileno: int) -> None: - ... + def trace(self, fileno: int) -> None: ... - def set_trace_flags(self, flags: Trace) -> None: - ... + def set_trace_flags(self, flags: Trace) -> None: ... - def untrace(self) -> None: - ... + def untrace(self) -> None: ... def encrypt_password( self, passwd: bytes, user: bytes, algorithm: Optional[bytes] = None - ) -> bytes: - ... + ) -> bytes: ... - def make_empty_result(self, exec_status: int) -> "PGresult": - ... + def make_empty_result(self, exec_status: int) -> "PGresult": ... @property - def pipeline_status(self) -> int: - ... + def pipeline_status(self) -> int: ... - def enter_pipeline_mode(self) -> None: - ... + def enter_pipeline_mode(self) -> None: ... - def exit_pipeline_mode(self) -> None: - ... + def exit_pipeline_mode(self) -> None: ... - def pipeline_sync(self) -> None: - ... + def pipeline_sync(self) -> None: ... - def send_flush_request(self) -> None: - ... + def send_flush_request(self) -> None: ... class PGresult(Protocol): - def clear(self) -> None: - ... + def clear(self) -> None: ... @property - def status(self) -> int: - ... + def status(self) -> int: ... @property - def error_message(self) -> bytes: - ... + def error_message(self) -> bytes: ... - def error_field(self, fieldcode: int) -> Optional[bytes]: - ... + def error_field(self, fieldcode: int) -> Optional[bytes]: ... @property - def ntuples(self) -> int: - ... + def ntuples(self) -> int: ... @property - def nfields(self) -> int: - ... + def nfields(self) -> int: ... - def fname(self, column_number: int) -> Optional[bytes]: - ... + def fname(self, column_number: int) -> Optional[bytes]: ... - def ftable(self, column_number: int) -> int: - ... + def ftable(self, column_number: int) -> int: ... - def ftablecol(self, column_number: int) -> int: - ... + def ftablecol(self, column_number: int) -> int: ... - def fformat(self, column_number: int) -> int: - ... + def fformat(self, column_number: int) -> int: ... - def ftype(self, column_number: int) -> int: - ... + def ftype(self, column_number: int) -> int: ... - def fmod(self, column_number: int) -> int: - ... + def fmod(self, column_number: int) -> int: ... - def fsize(self, column_number: int) -> int: - ... + def fsize(self, column_number: int) -> int: ... @property - def binary_tuples(self) -> int: - ... + def binary_tuples(self) -> int: ... - def get_value(self, row_number: int, column_number: int) -> Optional[bytes]: - ... + def get_value(self, row_number: int, column_number: int) -> Optional[bytes]: ... @property - def nparams(self) -> int: - ... + def nparams(self) -> int: ... - def param_type(self, param_number: int) -> int: - ... + def param_type(self, param_number: int) -> int: ... @property - def command_status(self) -> Optional[bytes]: - ... + def command_status(self) -> Optional[bytes]: ... @property - def command_tuples(self) -> Optional[int]: - ... + def command_tuples(self) -> Optional[int]: ... 
@property - def oid_value(self) -> int: - ... + def oid_value(self) -> int: ... - def set_attributes(self, descriptions: List["PGresAttDesc"]) -> None: - ... + def set_attributes(self, descriptions: List["PGresAttDesc"]) -> None: ... class PGcancel(Protocol): - def free(self) -> None: - ... + def free(self) -> None: ... - def cancel(self) -> None: - ... + def cancel(self) -> None: ... class Conninfo(Protocol): @classmethod - def get_defaults(cls) -> List["ConninfoOption"]: - ... + def get_defaults(cls) -> List["ConninfoOption"]: ... @classmethod - def parse(cls, conninfo: bytes) -> List["ConninfoOption"]: - ... + def parse(cls, conninfo: bytes) -> List["ConninfoOption"]: ... @classmethod - def _options_from_array(cls, opts: Sequence[Any]) -> List["ConninfoOption"]: - ... + def _options_from_array(cls, opts: Sequence[Any]) -> List["ConninfoOption"]: ... class Escaping(Protocol): - def __init__(self, conn: Optional[PGconn] = None): - ... + def __init__(self, conn: Optional[PGconn] = None): ... - def escape_literal(self, data: Buffer) -> bytes: - ... + def escape_literal(self, data: Buffer) -> bytes: ... - def escape_identifier(self, data: Buffer) -> bytes: - ... + def escape_identifier(self, data: Buffer) -> bytes: ... - def escape_string(self, data: Buffer) -> bytes: - ... + def escape_string(self, data: Buffer) -> bytes: ... - def escape_bytea(self, data: Buffer) -> bytes: - ... + def escape_bytea(self, data: Buffer) -> bytes: ... - def unescape_bytea(self, data: Buffer) -> bytes: - ... + def unescape_bytea(self, data: Buffer) -> bytes: ... diff --git a/psycopg/psycopg/rows.py b/psycopg/psycopg/rows.py index 4c2f7781b..07e0dbcaf 100644 --- a/psycopg/psycopg/rows.py +++ b/psycopg/psycopg/rows.py @@ -8,11 +8,10 @@ import functools from typing import Any, Callable, Dict, List, Optional, NamedTuple, NoReturn from typing import TYPE_CHECKING, Protocol, Sequence, Tuple, Type from collections import namedtuple -from typing_extensions import TypeAlias from . import pq from . import errors as e -from ._compat import TypeVar +from ._compat import TypeAlias, TypeVar from ._encodings import _as_python_identifier if TYPE_CHECKING: @@ -44,8 +43,7 @@ class RowMaker(Protocol[Row]): Typically, `!RowMaker` functions are returned by `RowFactory`. """ - def __call__(self, __values: Sequence[Any]) -> Row: - ... + def __call__(self, __values: Sequence[Any]) -> Row: ... class RowFactory(Protocol[Row]): @@ -62,8 +60,7 @@ class RowFactory(Protocol[Row]): use the values to create a dictionary for each record. """ - def __call__(self, __cursor: "Cursor[Any]") -> RowMaker[Row]: - ... + def __call__(self, __cursor: "Cursor[Any]") -> RowMaker[Row]: ... class AsyncRowFactory(Protocol[Row]): @@ -71,8 +68,7 @@ class AsyncRowFactory(Protocol[Row]): Like `RowFactory`, taking an async cursor as argument. """ - def __call__(self, __cursor: "AsyncCursor[Any]") -> RowMaker[Row]: - ... + def __call__(self, __cursor: "AsyncCursor[Any]") -> RowMaker[Row]: ... class BaseRowFactory(Protocol[Row]): @@ -80,8 +76,7 @@ class BaseRowFactory(Protocol[Row]): Like `RowFactory`, taking either type of cursor as argument. """ - def __call__(self, __cursor: "BaseCursor[Any, Any]") -> RowMaker[Row]: - ... + def __call__(self, __cursor: "BaseCursor[Any, Any]") -> RowMaker[Row]: ... TupleRow: TypeAlias = Tuple[Any, ...] 
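A row factory, per the protocols above, takes a cursor and returns a `RowMaker` that is applied to the values of each fetched record. A minimal sketch of a conforming factory (illustrative only, not part of the patch; the `scalar_row` factory added in the next hunk follows the same shape):

    from typing import Any, Sequence

    import psycopg

    def list_row(cursor: psycopg.Cursor[Any]):
        """Row factory returning each row as a plain list."""

        def list_row_(values: Sequence[Any]) -> list:
            return list(values)

        return list_row_

    # Usage (the connection string is hypothetical):
    # conn = psycopg.connect("dbname=test", row_factory=list_row)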
@@ -212,6 +207,28 @@ def kwargs_row(func: Callable[..., T]) -> BaseRowFactory[T]: return kwargs_row_ +def scalar_row(cursor: "BaseCursor[Any, Any]") -> "RowMaker[Any]": + """ + Generate a row factory returning the first column + as a scalar value. + """ + res = cursor.pgresult + if not res: + return no_result + + nfields = _get_nfields(res) + if nfields is None: + return no_result + + if nfields < 1: + raise e.ProgrammingError("at least one column expected") + + def scalar_row_(values: Sequence[Any]) -> Any: + return values[0] + + return scalar_row_ + + def no_result(values: Sequence[Any]) -> NoReturn: """A `RowMaker` that always fails. diff --git a/psycopg/psycopg/server_cursor.py b/psycopg/psycopg/server_cursor.py index 1c6e77aa1..2f5f44739 100644 --- a/psycopg/psycopg/server_cursor.py +++ b/psycopg/psycopg/server_cursor.py @@ -222,8 +222,7 @@ class ServerCursor(ServerCursorMixin["Connection[Any]", Row], Cursor[Row]): *, scrollable: Optional[bool] = None, withhold: bool = False, - ): - ... + ): ... @overload def __init__( @@ -234,8 +233,7 @@ class ServerCursor(ServerCursorMixin["Connection[Any]", Row], Cursor[Row]): row_factory: RowFactory[Row], scrollable: Optional[bool] = None, withhold: bool = False, - ): - ... + ): ... def __init__( self, @@ -363,8 +361,7 @@ class AsyncServerCursor( *, scrollable: Optional[bool] = None, withhold: bool = False, - ): - ... + ): ... @overload def __init__( @@ -375,8 +372,7 @@ class AsyncServerCursor( row_factory: AsyncRowFactory[Row], scrollable: Optional[bool] = None, withhold: bool = False, - ): - ... + ): ... def __init__( self, diff --git a/psycopg/psycopg/sql.py b/psycopg/psycopg/sql.py index d793bf389..a94f77f6e 100644 --- a/psycopg/psycopg/sql.py +++ b/psycopg/psycopg/sql.py @@ -55,7 +55,7 @@ class Composable(ABC): return f"{self.__class__.__name__}({self._obj!r})" @abstractmethod - def as_bytes(self, context: Optional[AdaptContext]) -> bytes: + def as_bytes(self, context: Optional[AdaptContext] = None) -> bytes: """ Return the value of the object as bytes. @@ -69,7 +69,7 @@ class Composable(ABC): """ raise NotImplementedError - def as_string(self, context: Optional[AdaptContext]) -> str: + def as_string(self, context: Optional[AdaptContext] = None) -> str: """ Return the value of the object as a string.
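The scalar_row factory added above returns only the first column of each record, so single-column queries yield plain values rather than one-element tuples; the new test in tests/test_rows.py further down exercises it. A usage sketch (the DSN is a placeholder):

    import psycopg
    from psycopg.rows import scalar_row

    with psycopg.connect("dbname=test") as conn:
        cur = conn.cursor(row_factory=scalar_row)
        cur.execute("select count(*) from generate_series(1, 10)")
        print(cur.fetchone())  # 10, not (10,)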
@@ -130,7 +130,7 @@ class Composed(Composable): seq = [obj if isinstance(obj, Composable) else Literal(obj) for obj in seq] super().__init__(seq) - def as_bytes(self, context: Optional[AdaptContext]) -> bytes: + def as_bytes(self, context: Optional[AdaptContext] = None) -> bytes: return b"".join(obj.as_bytes(context) for obj in self._obj) def __iter__(self) -> Iterator[Composable]: @@ -200,10 +200,10 @@ class SQL(Composable): if not isinstance(obj, str): raise TypeError(f"SQL values must be strings, got {obj!r} instead") - def as_string(self, context: Optional[AdaptContext]) -> str: + def as_string(self, context: Optional[AdaptContext] = None) -> str: return self._obj - def as_bytes(self, context: Optional[AdaptContext]) -> bytes: + def as_bytes(self, context: Optional[AdaptContext] = None) -> bytes: conn = context.connection if context else None enc = conn_encoding(conn) return self._obj.encode(enc) @@ -362,15 +362,22 @@ class Identifier(Composable): def __repr__(self) -> str: return f"{self.__class__.__name__}({', '.join(map(repr, self._obj))})" - def as_bytes(self, context: Optional[AdaptContext]) -> bytes: + def as_bytes(self, context: Optional[AdaptContext] = None) -> bytes: conn = context.connection if context else None - if not conn: - raise ValueError("a connection is necessary for Identifier") - esc = Escaping(conn.pgconn) - enc = conn_encoding(conn) - escs = [esc.escape_identifier(s.encode(enc)) for s in self._obj] + if conn: + esc = Escaping(conn.pgconn) + enc = conn_encoding(conn) + escs = [esc.escape_identifier(s.encode(enc)) for s in self._obj] + else: + escs = [self._escape_identifier(s.encode()) for s in self._obj] return b".".join(escs) + def _escape_identifier(self, s: bytes) -> bytes: + """ + Approximation of PQescapeIdentifier taking no connection. + """ + return b'"' + s.replace(b'"', b'""') + b'"' + class Literal(Composable): """ @@ -393,7 +400,7 @@ class Literal(Composable): """ - def as_bytes(self, context: Optional[AdaptContext]) -> bytes: + def as_bytes(self, context: Optional[AdaptContext] = None) -> bytes: tx = Transformer.from_context(context) return tx.as_literal(self._obj) @@ -452,11 +459,11 @@ class Placeholder(Composable): return f"{self.__class__.__name__}({', '.join(parts)})" - def as_string(self, context: Optional[AdaptContext]) -> str: + def as_string(self, context: Optional[AdaptContext] = None) -> str: code = self._format.value return f"%({self._obj}){code}" if self._obj else f"%{code}" - def as_bytes(self, context: Optional[AdaptContext]) -> bytes: + def as_bytes(self, context: Optional[AdaptContext] = None) -> bytes: conn = context.connection if context else None enc = conn_encoding(conn) return self.as_string(context).encode(enc) diff --git a/psycopg/psycopg/types/enum.py b/psycopg/psycopg/types/enum.py index e15c11299..6e20dd3ce 100644 --- a/psycopg/psycopg/types/enum.py +++ b/psycopg/psycopg/types/enum.py @@ -1,10 +1,10 @@ """ Adapters for the enum type. """ + from enum import Enum from typing import Any, Dict, Generic, Optional, Mapping, Sequence from typing import Tuple, Type, Union, cast, TYPE_CHECKING -from typing_extensions import TypeAlias from .. import sql from .. import postgres @@ -12,7 +12,7 @@ from .. 
import errors as e from ..pq import Format from ..abc import AdaptContext, Query from ..adapt import Buffer, Dumper, Loader -from .._compat import cache, TypeVar +from .._compat import cache, TypeAlias, TypeVar from .._encodings import conn_encoding from .._typeinfo import TypeInfo diff --git a/psycopg/psycopg/types/hstore.py b/psycopg/psycopg/types/hstore.py index 851a0556f..5bc261f55 100644 --- a/psycopg/psycopg/types/hstore.py +++ b/psycopg/psycopg/types/hstore.py @@ -6,14 +6,13 @@ Dict to hstore adaptation import re from typing import Dict, List, Optional, Type -from typing_extensions import TypeAlias from .. import errors as e from .. import postgres from ..abc import Buffer, AdaptContext from .._oids import TEXT_OID from ..adapt import PyFormat, RecursiveDumper, RecursiveLoader -from .._compat import cache +from .._compat import cache, TypeAlias from .._typeinfo import TypeInfo _re_escape = re.compile(r'(["\\])') diff --git a/psycopg/psycopg/types/multirange.py b/psycopg/psycopg/types/multirange.py index d672f6be8..51f61d1a7 100644 --- a/psycopg/psycopg/types/multirange.py +++ b/psycopg/psycopg/types/multirange.py @@ -91,12 +91,10 @@ class Multirange(MutableSequence[Range[T]]): return f"{{{', '.join(map(str, self._ranges))}}}" @overload - def __getitem__(self, index: int) -> Range[T]: - ... + def __getitem__(self, index: int) -> Range[T]: ... @overload - def __getitem__(self, index: slice) -> "Multirange[T]": - ... + def __getitem__(self, index: slice) -> "Multirange[T]": ... def __getitem__(self, index: Union[int, slice]) -> "Union[Range[T],Multirange[T]]": if isinstance(index, int): @@ -108,12 +106,10 @@ class Multirange(MutableSequence[Range[T]]): return len(self._ranges) @overload - def __setitem__(self, index: int, value: Range[T]) -> None: - ... + def __setitem__(self, index: int, value: Range[T]) -> None: ... @overload - def __setitem__(self, index: slice, value: Iterable[Range[T]]) -> None: - ... + def __setitem__(self, index: slice, value: Iterable[Range[T]]) -> None: ... def __setitem__( self, diff --git a/psycopg/psycopg/types/net.py b/psycopg/psycopg/types/net.py index 983de9a03..76522dcbb 100644 --- a/psycopg/psycopg/types/net.py +++ b/psycopg/psycopg/types/net.py @@ -5,12 +5,12 @@ Adapters for network types. # Copyright (C) 2020 The Psycopg Team from typing import Callable, Optional, Type, Union, TYPE_CHECKING -from typing_extensions import TypeAlias from .. import _oids from ..pq import Format from ..abc import AdaptContext from ..adapt import Buffer, Dumper, Loader +from .._compat import TypeAlias if TYPE_CHECKING: import ipaddress diff --git a/psycopg/psycopg/types/numeric.py b/psycopg/psycopg/types/numeric.py index f394bdac7..1817740fd 100644 --- a/psycopg/psycopg/types/numeric.py +++ b/psycopg/psycopg/types/numeric.py @@ -379,8 +379,7 @@ class _MixedNumericDumper(Dumper, ABC): _MixedNumericDumper.int_classes = int @abstractmethod - def dump(self, obj: Union[Decimal, int, "numpy.integer[Any]"]) -> Buffer: - ... + def dump(self, obj: Union[Decimal, int, "numpy.integer[Any]"]) -> Buffer: ... 
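Stepping back to the sql.py hunks above: with context now defaulting to None in as_string()/as_bytes(), and Identifier gaining a connection-less quote-doubling fallback that approximates PQescapeIdentifier, a statement can be composed and rendered before any connection exists (the encoding falls back to UTF-8). For example:

    from psycopg import sql

    query = sql.SQL("select {field} from {table}").format(
        field=sql.Identifier("my field"),
        table=sql.Identifier("public", 'odd"name'),
    )
    # No connection argument needed any more:
    print(query.as_string())  # select "my field" from "public"."odd""name"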
class NumericDumper(_MixedNumericDumper): diff --git a/psycopg/psycopg/waiting.py b/psycopg/psycopg/waiting.py index d6db0d922..4f307b6ef 100644 --- a/psycopg/psycopg/waiting.py +++ b/psycopg/psycopg/waiting.py @@ -26,6 +26,7 @@ from ._cmodule import _psycopg WAIT_R = Wait.R WAIT_W = Wait.W WAIT_RW = Wait.RW +READY_NONE = Ready.NONE READY_R = Ready.R READY_W = Ready.W READY_RW = Ready.RW @@ -51,16 +52,17 @@ def wait_selector(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) try: s = next(gen) with DefaultSelector() as sel: + sel.register(fileno, s) while True: - sel.register(fileno, s) - rlist = None - while not rlist: - rlist = sel.select(timeout=timeout) + rlist = sel.select(timeout=timeout) + if not rlist: + gen.send(READY_NONE) + continue + sel.unregister(fileno) - # note: this line should require a cast, but mypy doesn't complain - ready: Ready = rlist[0][1] - assert s & ready + ready = rlist[0][1] s = gen.send(ready) + sel.register(fileno, s) except StopIteration as ex: rv: RV = ex.args[0] if ex.args else None @@ -92,7 +94,7 @@ def wait_conn(gen: PQGenConn[RV], timeout: Optional[float] = None) -> RV: sel.unregister(fileno) if not rlist: raise e.ConnectionTimeout("connection timeout expired") - ready: Ready = rlist[0][1] # type: ignore[assignment] + ready = rlist[0][1] fileno, s = gen.send(ready) except StopIteration as ex: @@ -110,7 +112,7 @@ async def wait_async( `Ready` values when it would block. :param fileno: the file descriptor to wait on. :param timeout: timeout (in seconds) to check for other interrupt, e.g. - to allow Ctrl-C. If zero or None, wait indefinitely. + to allow Ctrl-C. If zero, wait indefinitely. :return: whatever `!gen` returns on completion. Behave like in `wait()`, but exposing an `asyncio` interface. @@ -119,12 +121,12 @@ async def wait_async( # Not sure this is the best implementation but it's a start. 
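The waiting.py rewrite above changes the contract between the wait functions and the connection's I/O generators: on timeout, a wait function now sends Ready.NONE into the generator instead of silently retrying, and the generators (see the Cython hunks further down) loop on a falsy ready state themselves, which keeps long waits responsive to interrupts such as Ctrl-C. A condensed restatement of the new wait_selector loop, under those assumptions:

    import selectors

    READY_NONE = 0  # stands in for psycopg's Ready.NONE

    def wait_demo(gen, fileno, timeout=None):
        # Drive a psycopg-style wait generator; mirrors the new wait_selector.
        try:
            s = next(gen)  # the generator yields a Wait state
            with selectors.DefaultSelector() as sel:
                sel.register(fileno, s)
                while True:
                    rlist = sel.select(timeout=timeout)
                    if not rlist:
                        gen.send(READY_NONE)  # timeout: let the generator react
                        continue
                    sel.unregister(fileno)
                    s = gen.send(rlist[0][1])  # report readiness, get next Wait
                    sel.register(fileno, s)
        except StopIteration as ex:
            return ex.args[0] if ex.args else None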
ev = Event() loop = get_event_loop() - ready: Ready + ready: int s: Wait def wakeup(state: Ready) -> None: nonlocal ready - ready |= state # type: ignore[assignment] + ready |= state ev.set() try: @@ -135,19 +137,19 @@ async def wait_async( if not reader and not writer: raise e.InternalError(f"bad poll status: {s}") ev.clear() - ready = 0 # type: ignore[assignment] + ready = 0 if reader: loop.add_reader(fileno, wakeup, READY_R) if writer: loop.add_writer(fileno, wakeup, READY_W) try: - if timeout is None: - await ev.wait() - else: + if timeout is not None: try: await wait_for(ev.wait(), timeout) except TimeoutError: pass + else: + await ev.wait() finally: if reader: loop.remove_reader(fileno) @@ -155,6 +157,9 @@ async def wait_async( loop.remove_writer(fileno) s = gen.send(ready) + except OSError as ex: + # Assume the connection was closed + raise e.OperationalError(str(ex)) except StopIteration as ex: rv: RV = ex.args[0] if ex.args else None return rv @@ -245,9 +250,10 @@ def wait_select(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> if wl: ready |= READY_W if not ready: + gen.send(READY_NONE) continue - # assert s & ready - s = gen.send(ready) # type: ignore + + s = gen.send(ready) except StopIteration as ex: rv: RV = ex.args[0] if ex.args else None @@ -285,24 +291,22 @@ def wait_epoll(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> s = next(gen) if timeout is None or timeout < 0: - timeout = 0 - else: - timeout = int(timeout * 1000.0) + timeout = 0.0 with select.epoll() as epoll: evmask = _epoll_evmasks[s] epoll.register(fileno, evmask) while True: - fileevs = None - while not fileevs: - fileevs = epoll.poll(timeout) + fileevs = epoll.poll(timeout) + if not fileevs: + gen.send(READY_NONE) + continue ev = fileevs[0][1] ready = 0 if ev & ~select.EPOLLOUT: ready = READY_R if ev & ~select.EPOLLIN: ready |= READY_W - # assert s & ready s = gen.send(ready) evmask = _epoll_evmasks[s] epoll.modify(fileno, evmask) @@ -340,16 +344,17 @@ def wait_poll(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> R evmask = _poll_evmasks[s] poll.register(fileno, evmask) while True: - fileevs = None - while not fileevs: - fileevs = poll.poll(timeout) + fileevs = poll.poll(timeout) + if not fileevs: + gen.send(READY_NONE) + continue + ev = fileevs[0][1] ready = 0 if ev & ~select.POLLOUT: ready = READY_R if ev & ~select.POLLIN: ready |= READY_W - # assert s & ready s = gen.send(ready) evmask = _poll_evmasks[s] poll.modify(fileno, evmask) diff --git a/psycopg/setup.cfg b/psycopg/setup.cfg index f734c40ec..fbb544677 100644 --- a/psycopg/setup.cfg +++ b/psycopg/setup.cfg @@ -74,7 +74,7 @@ test = pytest-randomly >= 3.5 dev = ast-comments >= 1.1.2 - black >= 23.1.0 + black >= 24.1.0 codespell >= 2.2 dnspython >= 2.1 flake8 >= 4.0 diff --git a/psycopg_c/.flake8 b/psycopg_c/.flake8 index 2ae629c2d..40a061b1e 100644 --- a/psycopg_c/.flake8 +++ b/psycopg_c/.flake8 @@ -1,3 +1,3 @@ [flake8] max-line-length = 88 -ignore = W503, E203 +ignore = W503, E203, E704 diff --git a/psycopg_c/psycopg_c/_psycopg/generators.pyx b/psycopg_c/psycopg_c/_psycopg/generators.pyx index a51fce5e2..70335cf89 100644 --- a/psycopg_c/psycopg_c/_psycopg/generators.pyx +++ b/psycopg_c/psycopg_c/_psycopg/generators.pyx @@ -18,9 +18,11 @@ from psycopg._encodings import conninfo_encoding cdef object WAIT_W = Wait.W cdef object WAIT_R = Wait.R cdef object WAIT_RW = Wait.RW +cdef object PY_READY_NONE = Ready.NONE cdef object PY_READY_R = Ready.R cdef object PY_READY_W = Ready.W cdef object PY_READY_RW = 
Ready.RW +cdef int READY_NONE = Ready.NONE cdef int READY_R = Ready.R cdef int READY_W = Ready.W cdef int READY_RW = Ready.RW @@ -96,15 +98,19 @@ def send(pq.PGconn pgconn) -> PQGen[None]: to retrieve the results available. """ cdef libpq.PGconn *pgconn_ptr = pgconn._pgconn_ptr - cdef int status + cdef int ready cdef int cires while True: if pgconn.flush() == 0: break - status = yield WAIT_RW - if status & READY_R: + while True: + ready = yield WAIT_RW + if ready: + break + + if ready & READY_R: with nogil: # This call may read notifies which will be saved in the # PGconn buffer and passed to Python later. @@ -166,11 +172,16 @@ def fetch(pq.PGconn pgconn) -> PQGen[Optional[PGresult]]: cdef libpq.PGconn *pgconn_ptr = pgconn._pgconn_ptr cdef int cires, ibres cdef libpq.PGresult *pgres + cdef object ready with nogil: ibres = libpq.PQisBusy(pgconn_ptr) if ibres: - yield WAIT_R + while True: + ready = yield WAIT_R + if ready: + break + while True: with nogil: cires = libpq.PQconsumeInput(pgconn_ptr) @@ -182,7 +193,10 @@ def fetch(pq.PGconn pgconn) -> PQGen[Optional[PGresult]]: f"consuming input failed: {error_message(pgconn)}") if not ibres: break - yield WAIT_R + while True: + ready = yield WAIT_R + if ready: + break _consume_notifies(pgconn) @@ -211,7 +225,10 @@ def pipeline_communicate( cdef pq.PGresult r while True: - ready = yield WAIT_RW + while True: + ready = yield WAIT_RW + if ready: + break if ready & READY_R: with nogil: diff --git a/psycopg_c/psycopg_c/_psycopg/waiting.pyx b/psycopg_c/psycopg_c/_psycopg/waiting.pyx index 33c54c513..3a6cc6e25 100644 --- a/psycopg_c/psycopg_c/_psycopg/waiting.pyx +++ b/psycopg_c/psycopg_c/_psycopg/waiting.pyx @@ -51,7 +51,7 @@ static int wait_c_impl(int fileno, int wait, float timeout) { int select_rv; - int rv = 0; + int rv = -1; #if defined(HAVE_POLL) && !defined(HAVE_BROKEN_POLL) @@ -83,11 +83,14 @@ retry_eintr: goto retry_eintr; } - if (select_rv < 0) { goto error; } if (PyErr_CheckSignals()) { goto finally; } + if (select_rv < 0) { goto finally; } /* poll error */ - if (input_fd.events & POLLIN) { rv |= SELECT_EV_READ; } - if (input_fd.events & POLLOUT) { rv |= SELECT_EV_WRITE; } + rv = 0; /* success, maybe with timeout */ + if (select_rv >= 0) { + if (input_fd.events & POLLIN) { rv |= SELECT_EV_READ; } + if (input_fd.events & POLLOUT) { rv |= SELECT_EV_WRITE; } + } #else @@ -135,11 +138,14 @@ retry_eintr: goto retry_eintr; } - if (select_rv < 0) { goto error; } if (PyErr_CheckSignals()) { goto finally; } + if (select_rv < 0) { goto error; } /* select error */ - if (FD_ISSET(fileno, &ifds)) { rv |= SELECT_EV_READ; } - if (FD_ISSET(fileno, &ofds)) { rv |= SELECT_EV_WRITE; } + rv = 0; + if (select_rv > 0) { + if (FD_ISSET(fileno, &ifds)) { rv |= SELECT_EV_READ; } + if (FD_ISSET(fileno, &ofds)) { rv |= SELECT_EV_WRITE; } + } #endif /* HAVE_POLL */ @@ -147,6 +153,8 @@ retry_eintr: error: + rv = -1; + #ifdef MS_WINDOWS if (select_rv == SOCKET_ERROR) { PyErr_SetExcFromWindowsErr(PyExc_OSError, WSAGetLastError()); @@ -162,7 +170,7 @@ error: finally: - return -1; + return rv; } """ @@ -191,8 +199,8 @@ def wait_c(gen: PQGen[RV], int fileno, timeout = None) -> RV: while True: ready = wait_c_impl(fileno, wait, ctimeout) - if ready == 0: - continue + if ready == READY_NONE: + pyready = PY_READY_NONE elif ready == READY_R: pyready = PY_READY_R elif ready == READY_RW: diff --git a/psycopg_c/psycopg_c/types/datetime.pyx b/psycopg_c/psycopg_c/types/datetime.pyx index 0ec4179a2..4b0784bde 100644 --- a/psycopg_c/psycopg_c/types/datetime.pyx +++ 
b/psycopg_c/psycopg_c/types/datetime.pyx @@ -4,6 +4,7 @@ Cython adapters for date/time types. # Copyright (C) 2021 The Psycopg Team +from libc.stdint cimport int64_t from libc.string cimport memset, strchr from cpython cimport datetime as cdt from cpython.dict cimport PyDict_GetItem @@ -391,7 +392,7 @@ cdef class DateLoader(CLoader): if length != 10: self._error_date(data, "unexpected length") - cdef int vals[3] + cdef int64_t vals[3] memset(vals, 0, sizeof(vals)) cdef const char *ptr @@ -437,7 +438,7 @@ cdef class TimeLoader(CLoader): cdef object cload(self, const char *data, size_t length): - cdef int vals[3] + cdef int64_t vals[3] memset(vals, 0, sizeof(vals)) cdef const char *ptr cdef const char *end = data + length @@ -494,7 +495,7 @@ cdef class TimetzLoader(CLoader): cdef object cload(self, const char *data, size_t length): - cdef int vals[3] + cdef int64_t vals[3] memset(vals, 0, sizeof(vals)) cdef const char *ptr cdef const char *end = data + length @@ -581,7 +582,7 @@ cdef class TimestampLoader(CLoader): if self._order == ORDER_PGDM or self._order == ORDER_PGMD: return self._cload_pg(data, end) - cdef int vals[6] + cdef int64_t vals[6] memset(vals, 0, sizeof(vals)) cdef const char *ptr @@ -611,7 +612,7 @@ cdef class TimestampLoader(CLoader): raise _get_timestamp_load_error(self._pgconn, data, ex) from None cdef object _cload_pg(self, const char *data, const char *end): - cdef int vals[4] + cdef int64_t vals[4] memset(vals, 0, sizeof(vals)) cdef const char *ptr @@ -721,7 +722,7 @@ cdef class TimestamptzLoader(_BaseTimestamptzLoader): if end[-1] == b'C': # ends with BC raise _get_timestamp_load_error(self._pgconn, data) from None - cdef int vals[6] + cdef int64_t vals[6] memset(vals, 0, sizeof(vals)) # Parse the first 6 groups of digits (date and time) @@ -862,9 +863,10 @@ cdef class IntervalLoader(CLoader): if self._style == INTERVALSTYLE_OTHERS: return self._cload_notimpl(data, length) - cdef int days = 0, secs = 0, us = 0 + cdef int days = 0, us = 0 + cdef int64_t secs = 0 cdef char sign - cdef int val + cdef int64_t val cdef const char *ptr = data cdef const char *sep cdef const char *end = ptr + length @@ -908,7 +910,7 @@ cdef class IntervalLoader(CLoader): break # Parse the time part. An eventual sign was already consumed in the loop - cdef int vals[3] + cdef int64_t vals[3] memset(vals, 0, sizeof(vals)) if ptr != NULL: ptr = _parse_date_values(ptr, end, vals, ARRAYSIZE(vals)) @@ -918,6 +920,10 @@ cdef class IntervalLoader(CLoader): secs = vals[2] + 60 * (vals[1] + 60 * vals[0]) + if secs > 86_400: + days += secs // 86_400 + secs %= 86_400 + if ptr[0] == b'.': ptr = _parse_micros(ptr + 1, &us) @@ -966,11 +972,11 @@ cdef class IntervalBinaryLoader(CLoader): # Work only with positive values as the cdivision behaves differently # with negative values, and cdivision=False adds overhead. cdef int64_t aval = val if val >= 0 else -val - cdef int us, ussecs, usdays + cdef int64_t us, ussecs, usdays - # Group the micros in biggers stuff or timedelta_new might overflow + # Group the micros in bigger stuff or timedelta_new might overflow with cython.cdivision(True): - ussecs = (aval // 1_000_000) + ussecs = (aval // 1_000_000) us = aval % 1_000_000 usdays = ussecs // 86_400 @@ -988,7 +994,7 @@ cdef class IntervalBinaryLoader(CLoader): cdef const char *_parse_date_values( - const char *ptr, const char *end, int *vals, int nvals + const char *ptr, const char *end, int64_t *vals, int nvals ): """ Parse *nvals* numeric values separated by non-numeric chars. 
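The datetime.pyx changes above widen the parse buffers from C int to int64_t and fold oversized seconds into days: an interval expressed in hours alone can exceed 2**31 seconds, which previously overflowed. The test case added to tests/types/test_datetime.py further down ("83063d,81640s,447000m" for '1993534:40:40.447') covers exactly this. A sketch, assuming a reachable database:

    import psycopg

    with psycopg.connect() as conn:  # connection parameters assumed
        cur = conn.execute("select interval '1993534:40:40.447'")
        print(repr(cur.fetchone()[0]))
        # datetime.timedelta(days=83063, seconds=81640, microseconds=447000),
        # about 7.18e9 seconds, well past the 32-bit limit of 2147483647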
@@ -1046,7 +1052,7 @@ cdef int _parse_timezone_to_seconds(const char **bufptr, const char *end): cdef char sgn = ptr[0] # Parse at most three groups of digits - cdef int vals[3] + cdef int64_t vals[3] memset(vals, 0, sizeof(vals)) ptr = _parse_date_values(ptr + 1, end, vals, ARRAYSIZE(vals)) diff --git a/psycopg_pool/.flake8 b/psycopg_pool/.flake8 index 2ae629c2d..40a061b1e 100644 --- a/psycopg_pool/.flake8 +++ b/psycopg_pool/.flake8 @@ -1,3 +1,3 @@ [flake8] max-line-length = 88 -ignore = W503, E203 +ignore = W503, E203, E704 diff --git a/psycopg_pool/psycopg_pool/_acompat.py b/psycopg_pool/psycopg_pool/_acompat.py index 4e4fa20b0..d58548515 100644 --- a/psycopg_pool/psycopg_pool/_acompat.py +++ b/psycopg_pool/psycopg_pool/_acompat.py @@ -17,9 +17,7 @@ import logging import threading from typing import Any, Callable, Coroutine, TYPE_CHECKING -from typing_extensions import TypeAlias - -from ._compat import TypeVar +from ._compat import TypeAlias, TypeVar logger = logging.getLogger("psycopg.pool") T = TypeVar("T") diff --git a/psycopg_pool/psycopg_pool/_compat.py b/psycopg_pool/psycopg_pool/_compat.py index 5917ff31b..3fc645cbe 100644 --- a/psycopg_pool/psycopg_pool/_compat.py +++ b/psycopg_pool/psycopg_pool/_compat.py @@ -14,6 +14,11 @@ if sys.version_info >= (3, 9): else: from typing import Counter, Deque +if sys.version_info >= (3, 10): + from typing import TypeAlias +else: + from typing_extensions import TypeAlias + if sys.version_info >= (3, 11): from typing import Self else: @@ -28,6 +33,7 @@ __all__ = [ "Counter", "Deque", "Self", + "TypeAlias", "TypeVar", ] diff --git a/psycopg_pool/psycopg_pool/abc.py b/psycopg_pool/psycopg_pool/abc.py index 6cc85a2c5..07209a64c 100644 --- a/psycopg_pool/psycopg_pool/abc.py +++ b/psycopg_pool/psycopg_pool/abc.py @@ -8,9 +8,7 @@ from __future__ import annotations from typing import Any, Awaitable, Callable, Union, TYPE_CHECKING -from typing_extensions import TypeAlias - -from ._compat import TypeVar +from ._compat import TypeAlias, TypeVar if TYPE_CHECKING: from .pool import ConnectionPool diff --git a/psycopg_pool/psycopg_pool/null_pool.py b/psycopg_pool/psycopg_pool/null_pool.py index a5408d08e..0be5d6148 100644 --- a/psycopg_pool/psycopg_pool/null_pool.py +++ b/psycopg_pool/psycopg_pool/null_pool.py @@ -26,6 +26,7 @@ logger = logging.getLogger("psycopg.pool") class NullConnectionPool(_BaseNullConnectionPool, ConnectionPool[CT]): + def __init__( self, conninfo: str = "", @@ -47,6 +48,7 @@ class NullConnectionPool(_BaseNullConnectionPool, ConnectionPool[CT]): reconnect_failed: Optional[ConnectFailedCB] = None, num_workers: int = 3, ): # Note: min_size default value changed to 0. + super().__init__( conninfo, open=open, diff --git a/psycopg_pool/psycopg_pool/pool.py b/psycopg_pool/psycopg_pool/pool.py index 7ab999b1e..a021e14a0 100644 --- a/psycopg_pool/psycopg_pool/pool.py +++ b/psycopg_pool/psycopg_pool/pool.py @@ -913,8 +913,7 @@ class MaintenanceTask(ABC): pool.run_task(self) @abstractmethod - def _run(self, pool: ConnectionPool[Any]) -> None: - ... + def _run(self, pool: ConnectionPool[Any]) -> None: ... 
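The psycopg_pool/_compat.py hunk above is the pattern behind most of the import churn in this commit: TypeAlias is taken from typing on Python 3.10+ and from typing_extensions only on older interpreters, so modules stop importing typing_extensions unconditionally. A consumer then writes, for instance (the alias name here is hypothetical):

    from typing import Any, Callable

    from psycopg_pool._compat import TypeAlias

    # Identical on 3.8 and 3.12; the shim picks the right origin.
    ReconnectCB: TypeAlias = Callable[[Any], None]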
class StopWorker(MaintenanceTask): @@ -925,6 +924,7 @@ class StopWorker(MaintenanceTask): class AddConnection(MaintenanceTask): + def __init__( self, pool: ConnectionPool[Any], diff --git a/psycopg_pool/psycopg_pool/pool_async.py b/psycopg_pool/psycopg_pool/pool_async.py index fd8f63782..d0770dd77 100644 --- a/psycopg_pool/psycopg_pool/pool_async.py +++ b/psycopg_pool/psycopg_pool/pool_async.py @@ -962,8 +962,7 @@ class MaintenanceTask(ABC): pool.run_task(self) @abstractmethod - async def _run(self, pool: AsyncConnectionPool[Any]) -> None: - ... + async def _run(self, pool: AsyncConnectionPool[Any]) -> None: ... class StopWorker(MaintenanceTask): diff --git a/psycopg_pool/psycopg_pool/sched.py b/psycopg_pool/psycopg_pool/sched.py index 58cadc36d..40954b996 100644 --- a/psycopg_pool/psycopg_pool/sched.py +++ b/psycopg_pool/psycopg_pool/sched.py @@ -27,6 +27,7 @@ logger = logging.getLogger(__name__) class Scheduler: + def __init__(self) -> None: self._queue: List[Task] = [] self._lock = Lock() diff --git a/tests/conftest.py b/tests/conftest.py index 98b03426e..cc1273f78 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -73,9 +73,9 @@ def pytest_sessionstart(session): asyncio_options: Dict[str, Any] = {} if sys.platform == "win32": - asyncio_options[ - "loop_factory" - ] = asyncio.WindowsSelectorEventLoopPolicy().new_event_loop + asyncio_options["loop_factory"] = ( + asyncio.WindowsSelectorEventLoopPolicy().new_event_loop + ) @pytest.fixture( diff --git a/tests/constraints.txt b/tests/constraints.txt index a676a993f..5c106971c 100644 --- a/tests/constraints.txt +++ b/tests/constraints.txt @@ -17,7 +17,7 @@ pytest-cov == 3.0.0 pytest-randomly == 3.5.0 # From the 'dev' extra -black == 23.1.0 +black == 24.1.0 dnspython == 2.1.0 flake8 == 4.0.0 types-setuptools == 57.4.0 diff --git a/tests/dbapi20.py b/tests/dbapi20.py index c873a4e66..ea98800f5 100644 --- a/tests/dbapi20.py +++ b/tests/dbapi20.py @@ -13,6 +13,8 @@ -- Ian Bicking ''' +from __future__ import annotations + __rcs_id__ = '$Id: dbapi20.py,v 1.11 2005/01/02 02:41:01 zenzen Exp $' __version__ = '$Revision: 1.12 $'[11:-2] __author__ = 'Stuart Bishop ' @@ -101,7 +103,7 @@ class DatabaseAPI20Test(unittest.TestCase): # method is to be found driver: Any = None connect_args = () # List of arguments to pass to connect - connect_kw_args: Dict[str, Any] = {} # Keyword arguments for connect + connect_kw_args: Dict[Any, Any] = {} # Keyword arguments for connect table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables ddl1 = 'create table %sbooze (name varchar(20))' % table_prefix diff --git a/tests/fix_db.py b/tests/fix_db.py index 9abeda79e..37ee7ac32 100644 --- a/tests/fix_db.py +++ b/tests/fix_db.py @@ -119,7 +119,7 @@ def dsn_env(dsn): continue args[opt.keyword.decode()] = os.environ[opt.envvar.decode()] - return make_conninfo(**args) + return make_conninfo("", **args) @pytest.fixture(scope="session") diff --git a/tests/fix_proxy.py b/tests/fix_proxy.py index 1d566b5e5..6a7487786 100644 --- a/tests/fix_proxy.py +++ b/tests/fix_proxy.py @@ -58,7 +58,8 @@ class Proxy: cdict = conninfo.conninfo_to_dict(server_dsn) # Get server params - host = cdict.get("host") or os.environ.get("PGHOST") + host = cdict.get("host") or os.environ.get("PGHOST", "") + assert isinstance(host, str) self.server_host = host if host and not host.startswith("/") else "localhost" self.server_port = cdict.get("port") or os.environ.get("PGPORT", "5432") @@ -70,7 +71,7 @@ class Proxy: cdict["host"] = self.client_host cdict["port"] = self.client_port 
cdict["sslmode"] = "disable" # not supported by the proxy - self.client_dsn = conninfo.make_conninfo(**cdict) + self.client_dsn = conninfo.make_conninfo("", **cdict) # The running proxy process self.proc = None diff --git a/tests/fix_psycopg.py b/tests/fix_psycopg.py index 80e0c626d..bc054fdb7 100644 --- a/tests/fix_psycopg.py +++ b/tests/fix_psycopg.py @@ -24,7 +24,6 @@ def global_adapters(): @pytest.fixture -@pytest.mark.crdb_skip("2-phase commit") def tpc(svcconn): tpc = Tpc(svcconn) tpc.check_tpc() diff --git a/tests/pool/test_pool.py b/tests/pool/test_pool.py index e2aeb7cc3..ff8fafd88 100644 --- a/tests/pool/test_pool.py +++ b/tests/pool/test_pool.py @@ -42,10 +42,11 @@ def test_bad_size(dsn, min_size, max_size): class MyRow(Dict[str, Any]): - ... + pass def test_generic_connection_type(dsn): + def configure(conn: psycopg.Connection[Any]) -> None: set_autocommit(conn, True) @@ -78,10 +79,12 @@ def test_generic_connection_type(dsn): def test_non_generic_connection_type(dsn): + def configure(conn: psycopg.Connection[Any]) -> None: set_autocommit(conn, True) class MyConnection(psycopg.Connection[MyRow]): + def __init__(self, *args: Any, **kwargs: Any): kwargs["row_factory"] = class_row(MyRow) super().__init__(*args, **kwargs) @@ -638,6 +641,7 @@ def test_uniform_use(dsn): @pytest.mark.slow @pytest.mark.timing def test_resize(dsn): + def sampler(): sleep(0.05) # ensure sampling happens after shrink check while True: diff --git a/tests/pool/test_pool_async.py b/tests/pool/test_pool_async.py index 0bc84f8bc..6a699e40b 100644 --- a/tests/pool/test_pool_async.py +++ b/tests/pool/test_pool_async.py @@ -42,7 +42,7 @@ async def test_bad_size(dsn, min_size, max_size): class MyRow(Dict[str, Any]): - ... + pass async def test_generic_connection_type(dsn): diff --git a/tests/pool/test_pool_common.py b/tests/pool/test_pool_common.py index ddf78a693..a8815da20 100644 --- a/tests/pool/test_pool_common.py +++ b/tests/pool/test_pool_common.py @@ -35,6 +35,7 @@ def test_defaults(pool_cls, dsn): def test_connection_class(pool_cls, dsn): + class MyConn(psycopg.Connection[Any]): pass @@ -158,6 +159,7 @@ def test_configure_broken(pool_cls, dsn, caplog): @pytest.mark.timing @pytest.mark.crdb_skip("backend pid") def test_queue(pool_cls, dsn): + def worker(n): t0 = time() with p.connection() as conn: @@ -182,6 +184,7 @@ def test_queue(pool_cls, dsn): @pytest.mark.slow def test_queue_size(pool_cls, dsn): + def worker(t, ev=None): try: with p.connection(): @@ -217,6 +220,7 @@ def test_queue_size(pool_cls, dsn): @pytest.mark.timing @pytest.mark.crdb_skip("backend pid") def test_queue_timeout(pool_cls, dsn): + def worker(n): t0 = time() try: @@ -246,6 +250,7 @@ def test_queue_timeout(pool_cls, dsn): @pytest.mark.slow @pytest.mark.timing def test_dead_client(pool_cls, dsn): + def worker(i, timeout): try: with p.connection(timeout=timeout) as conn: @@ -273,6 +278,7 @@ def test_dead_client(pool_cls, dsn): @pytest.mark.timing @pytest.mark.crdb_skip("backend pid") def test_queue_timeout_override(pool_cls, dsn): + def worker(n): t0 = time() timeout = 0.25 if n == 3 else None @@ -382,6 +388,7 @@ def test_close_connection_on_pool_close(pool_cls, dsn): def test_closed_queue(pool_cls, dsn): + def w1(): with p.connection() as conn: e1.set() # Tell w0 that w1 got a connection @@ -493,6 +500,7 @@ def test_jitter(pool_cls): @pytest.mark.slow @pytest.mark.timing def test_stats_measures(pool_cls, dsn): + def worker(n): with p.connection() as conn: conn.execute("select pg_sleep(0.2)") @@ -532,6 +540,7 @@ def 
test_stats_measures(pool_cls, dsn): @pytest.mark.slow @pytest.mark.timing def test_stats_usage(pool_cls, dsn): + def worker(n): try: with p.connection(timeout=0.3) as conn: @@ -613,6 +622,7 @@ def test_check_init(pool_cls, dsn): @pytest.mark.slow def test_check_timeout(pool_cls, dsn): + def check(conn): raise Exception() diff --git a/tests/pool/test_pool_null.py b/tests/pool/test_pool_null.py index fbe698df6..c54014572 100644 --- a/tests/pool/test_pool_null.py +++ b/tests/pool/test_pool_null.py @@ -40,10 +40,11 @@ def test_bad_size(dsn, min_size, max_size): class MyRow(Dict[str, Any]): - ... + pass def test_generic_connection_type(dsn): + def configure(conn: psycopg.Connection[Any]) -> None: set_autocommit(conn, True) @@ -76,10 +77,12 @@ def test_generic_connection_type(dsn): def test_non_generic_connection_type(dsn): + def configure(conn: psycopg.Connection[Any]) -> None: set_autocommit(conn, True) class MyConnection(psycopg.Connection[MyRow]): + def __init__(self, *args: Any, **kwargs: Any): kwargs["row_factory"] = class_row(MyRow) super().__init__(*args, **kwargs) diff --git a/tests/pool/test_pool_null_async.py b/tests/pool/test_pool_null_async.py index 09c0e2150..b610045cc 100644 --- a/tests/pool/test_pool_null_async.py +++ b/tests/pool/test_pool_null_async.py @@ -40,7 +40,7 @@ async def test_bad_size(dsn, min_size, max_size): class MyRow(Dict[str, Any]): - ... + pass async def test_generic_connection_type(dsn): diff --git a/tests/scripts/dectest.py b/tests/scripts/dectest.py index a49f11685..0fb65ad90 100644 --- a/tests/scripts/dectest.py +++ b/tests/scripts/dectest.py @@ -1,6 +1,7 @@ """ A quick and rough performance comparison of text vs. binary Decimal adaptation """ + from random import randrange from decimal import Decimal import psycopg diff --git a/tests/scripts/pipeline-demo.py b/tests/scripts/pipeline-demo.py index ec952293a..74cc04d1f 100644 --- a/tests/scripts/pipeline-demo.py +++ b/tests/scripts/pipeline-demo.py @@ -7,6 +7,7 @@ We do not fetch results explicitly (using cursor.fetch*()), this is handled by execute() calls when pgconn socket is read-ready, which happens when the output buffer is full. 
""" + import argparse import asyncio import logging diff --git a/tests/test_concurrency_async.py b/tests/test_concurrency_async.py index 150d77477..a14fb93e6 100644 --- a/tests/test_concurrency_async.py +++ b/tests/test_concurrency_async.py @@ -6,7 +6,7 @@ import threading import subprocess as sp from asyncio import create_task from asyncio.queues import Queue -from typing import List, Tuple +from typing import List import pytest @@ -58,50 +58,6 @@ async def test_concurrent_execution(aconn_cls, dsn): assert time.time() - t0 < 0.8, "something broken in concurrency" -@pytest.mark.slow -@pytest.mark.timing -@pytest.mark.crdb_skip("notify") -async def test_notifies(aconn_cls, aconn, dsn): - nconn = await aconn_cls.connect(dsn, autocommit=True) - npid = nconn.pgconn.backend_pid - - async def notifier(): - cur = nconn.cursor() - await asyncio.sleep(0.25) - await cur.execute("notify foo, '1'") - await asyncio.sleep(0.25) - await cur.execute("notify foo, '2'") - await nconn.close() - - async def receiver(): - await aconn.set_autocommit(True) - cur = aconn.cursor() - await cur.execute("listen foo") - gen = aconn.notifies() - async for n in gen: - ns.append((n, time.time())) - if len(ns) >= 2: - await gen.aclose() - - ns: List[Tuple[psycopg.Notify, float]] = [] - t0 = time.time() - workers = [notifier(), receiver()] - await asyncio.gather(*workers) - assert len(ns) == 2 - - n, t1 = ns[0] - assert n.pid == npid - assert n.channel == "foo" - assert n.payload == "1" - assert t1 - t0 == pytest.approx(0.25, abs=0.05) - - n, t1 = ns[1] - assert n.pid == npid - assert n.channel == "foo" - assert n.payload == "2" - assert t1 - t0 == pytest.approx(0.5, abs=0.05) - - async def canceller(aconn, errors): try: await asyncio.sleep(0.5) diff --git a/tests/test_connection.py b/tests/test_connection.py index 8456dba45..f17f8d2a0 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -9,7 +9,7 @@ import weakref from typing import Any, List import psycopg -from psycopg import Notify, pq, errors as e +from psycopg import pq, errors as e from psycopg.rows import tuple_row from psycopg.conninfo import conninfo_to_dict, timeout_from_conninfo @@ -34,6 +34,7 @@ def test_connect_bad(conn_cls): def test_connect_str_subclass(conn_cls, dsn): + class MyString(str): pass @@ -467,6 +468,7 @@ def test_connect_args( ], ) def test_connect_badargs(conn_cls, monkeypatch, pgconn, args, kwargs, exctype): + def fake_connect(conninfo): return pgconn yield @@ -526,47 +528,6 @@ def test_notice_handlers(conn, caplog): conn.remove_notice_handler(cb1) -@pytest.mark.crdb_skip("notify") -def test_notify_handlers(conn): - nots1 = [] - nots2 = [] - - def cb1(n): - nots1.append(n) - - conn.add_notify_handler(cb1) - conn.add_notify_handler(lambda n: nots2.append(n)) - - conn.set_autocommit(True) - cur = conn.cursor() - cur.execute("listen foo") - cur.execute("notify foo, 'n1'") - - assert len(nots1) == 1 - n = nots1[0] - assert n.channel == "foo" - assert n.payload == "n1" - assert n.pid == conn.pgconn.backend_pid - - assert len(nots2) == 1 - assert nots2[0] == nots1[0] - - conn.remove_notify_handler(cb1) - cur.execute("notify foo, 'n2'") - - assert len(nots1) == 1 - assert len(nots2) == 2 - n = nots2[1] - assert isinstance(n, Notify) - assert n.channel == "foo" - assert n.payload == "n2" - assert n.pid == conn.pgconn.backend_pid - assert hash(n) - - with pytest.raises(ValueError): - conn.remove_notify_handler(cb1) - - def test_execute(conn): cur = conn.execute("select %s, %s", [10, 20]) assert cur.fetchone() == (10, 20) @@ -642,6 
+603,7 @@ def test_cursor_factory(conn): def test_cursor_factory_connect(conn_cls, dsn): + class MyCursor(psycopg.Cursor[psycopg.rows.Row]): pass diff --git a/tests/test_connection_async.py b/tests/test_connection_async.py index 2e950aff4..d7aa7ca8b 100644 --- a/tests/test_connection_async.py +++ b/tests/test_connection_async.py @@ -6,7 +6,7 @@ import weakref from typing import Any, List import psycopg -from psycopg import Notify, pq, errors as e +from psycopg import pq, errors as e from psycopg.rows import tuple_row from psycopg.conninfo import conninfo_to_dict, timeout_from_conninfo @@ -526,47 +526,6 @@ async def test_notice_handlers(aconn, caplog): aconn.remove_notice_handler(cb1) -@pytest.mark.crdb_skip("notify") -async def test_notify_handlers(aconn): - nots1 = [] - nots2 = [] - - def cb1(n): - nots1.append(n) - - aconn.add_notify_handler(cb1) - aconn.add_notify_handler(lambda n: nots2.append(n)) - - await aconn.set_autocommit(True) - cur = aconn.cursor() - await cur.execute("listen foo") - await cur.execute("notify foo, 'n1'") - - assert len(nots1) == 1 - n = nots1[0] - assert n.channel == "foo" - assert n.payload == "n1" - assert n.pid == aconn.pgconn.backend_pid - - assert len(nots2) == 1 - assert nots2[0] == nots1[0] - - aconn.remove_notify_handler(cb1) - await cur.execute("notify foo, 'n2'") - - assert len(nots1) == 1 - assert len(nots2) == 2 - n = nots2[1] - assert isinstance(n, Notify) - assert n.channel == "foo" - assert n.payload == "n2" - assert n.pid == aconn.pgconn.backend_pid - assert hash(n) - - with pytest.raises(ValueError): - aconn.remove_notify_handler(cb1) - - async def test_execute(aconn): cur = await aconn.execute("select %s, %s", [10, 20]) assert await cur.fetchone() == (10, 20) @@ -637,7 +596,7 @@ async def test_cursor_factory(aconn): async with aconn.cursor() as cur: assert isinstance(cur, MyCursor) - async with (await aconn.execute("select 1")) as cur: + async with await aconn.execute("select 1") as cur: assert isinstance(cur, MyCursor) diff --git a/tests/test_conninfo_attempts.py b/tests/test_conninfo_attempts.py index c2855760a..0f4ba1b11 100644 --- a/tests/test_conninfo_attempts.py +++ b/tests/test_conninfo_attempts.py @@ -165,14 +165,14 @@ def test_conninfo_random_multi_host(): def test_conninfo_random_multi_ips(fake_resolve): args = {"host": "alot.com"} - hostaddrs = [att["hostaddr"] for att in conninfo_attempts(args)] + hostaddrs = [str(att["hostaddr"]) for att in conninfo_attempts(args)] assert len(hostaddrs) == 20 assert hostaddrs == sorted(hostaddrs) args["load_balance_hosts"] = "disable" - hostaddrs = [att["hostaddr"] for att in conninfo_attempts(args)] + hostaddrs = [str(att["hostaddr"]) for att in conninfo_attempts(args)] assert hostaddrs == sorted(hostaddrs) args["load_balance_hosts"] = "random" - hostaddrs = [att["hostaddr"] for att in conninfo_attempts(args)] + hostaddrs = [str(att["hostaddr"]) for att in conninfo_attempts(args)] assert hostaddrs != sorted(hostaddrs) diff --git a/tests/test_conninfo_attempts_async.py b/tests/test_conninfo_attempts_async.py index bf6da880f..aada9f1e0 100644 --- a/tests/test_conninfo_attempts_async.py +++ b/tests/test_conninfo_attempts_async.py @@ -172,14 +172,14 @@ async def test_conninfo_random_multi_host(): async def test_conninfo_random_multi_ips(fake_resolve): args = {"host": "alot.com"} - hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)] + hostaddrs = [str(att["hostaddr"]) for att in await conninfo_attempts_async(args)] assert len(hostaddrs) == 20 assert hostaddrs == 
sorted(hostaddrs) args["load_balance_hosts"] = "disable" - hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)] + hostaddrs = [str(att["hostaddr"]) for att in await conninfo_attempts_async(args)] assert hostaddrs == sorted(hostaddrs) args["load_balance_hosts"] = "random" - hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)] + hostaddrs = [str(att["hostaddr"]) for att in await conninfo_attempts_async(args)] assert hostaddrs != sorted(hostaddrs) diff --git a/tests/test_copy.py b/tests/test_copy.py index fda854e60..55f9e3b77 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -306,6 +306,7 @@ def test_subclass_adapter(conn, format): BaseDumper = StrBinaryDumper # type: ignore class MyStrDumper(BaseDumper): + def dump(self, obj): return super().dump(obj) * 2 @@ -641,6 +642,7 @@ def test_worker_life(conn, format, buffer): def test_worker_error_propagated(conn, monkeypatch): + def copy_to_broken(pgconn, buffer): raise ZeroDivisionError yield @@ -803,6 +805,7 @@ def test_copy_table_across(conn_cls, dsn, faker, mode): class DataGenerator: + def __init__(self, conn, nrecs, srec, offset=0, block_size=8192): self.conn = conn self.nrecs = nrecs diff --git a/tests/test_cursor_common.py b/tests/test_cursor_common.py index 159e67cb1..d5b0d1513 100644 --- a/tests/test_cursor_common.py +++ b/tests/test_cursor_common.py @@ -576,6 +576,7 @@ def test_row_factory_none(conn): def test_bad_row_factory(conn): + def broken_factory(cur): 1 / 0 @@ -584,6 +585,7 @@ def test_bad_row_factory(conn): cur.execute("select 1") def broken_maker(cur): + def make_row(seq): 1 / 0 diff --git a/tests/test_generators.py b/tests/test_generators.py index ecb8da987..2df55e3e0 100644 --- a/tests/test_generators.py +++ b/tests/test_generators.py @@ -25,7 +25,7 @@ def test_connect_operationalerror_pgconn(generators, dsn, monkeypatch): except KeyError: info = conninfo_to_dict(dsn) del info["password"] # should not raise per check above. - dsn = make_conninfo(**info) + dsn = make_conninfo("", **info) gen = generators.connect(dsn) with pytest.raises( diff --git a/tests/test_notify.py b/tests/test_notify.py new file mode 100644 index 000000000..c67b33115 --- /dev/null +++ b/tests/test_notify.py @@ -0,0 +1,195 @@ +# WARNING: this file is auto-generated by 'async_to_sync.py' +# from the original file 'test_notify_async.py' +# DO NOT CHANGE! Change the original file instead. 
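As the header just above warns, the new tests/test_notify.py is generated: tools/async_to_sync.py (whose input list gains test_notify_async.py at the end of this commit) produces the blocking twin of the async test module, stripping await and swapping in the .acompat helpers. Roughly, as an illustrative pair rather than actual tool output:

    # hand-written async source:
    async def ping_async(aconn):
        await aconn.execute("select 1")

    # what the generated sync twin looks like:
    def ping(conn):
        conn.execute("select 1")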
+from __future__ import annotations + +from time import time + +import pytest +from psycopg import Notify + +from .acompat import sleep, gather, spawn + +pytestmark = pytest.mark.crdb_skip("notify") + + +def test_notify_handlers(conn): + nots1 = [] + nots2 = [] + + def cb1(n): + nots1.append(n) + + conn.add_notify_handler(cb1) + conn.add_notify_handler(lambda n: nots2.append(n)) + + conn.set_autocommit(True) + conn.execute("listen foo") + conn.execute("notify foo, 'n1'") + + assert len(nots1) == 1 + n = nots1[0] + assert n.channel == "foo" + assert n.payload == "n1" + assert n.pid == conn.pgconn.backend_pid + + assert len(nots2) == 1 + assert nots2[0] == nots1[0] + + conn.remove_notify_handler(cb1) + conn.execute("notify foo, 'n2'") + + assert len(nots1) == 1 + assert len(nots2) == 2 + n = nots2[1] + assert isinstance(n, Notify) + assert n.channel == "foo" + assert n.payload == "n2" + assert n.pid == conn.pgconn.backend_pid + assert hash(n) + + with pytest.raises(ValueError): + conn.remove_notify_handler(cb1) + + +@pytest.mark.slow +@pytest.mark.timing +def test_notify(conn_cls, conn, dsn): + npid = None + + def notifier(): + with conn_cls.connect(dsn, autocommit=True) as nconn: + nonlocal npid + npid = nconn.pgconn.backend_pid + + sleep(0.25) + nconn.execute("notify foo, '1'") + sleep(0.25) + nconn.execute("notify foo, '2'") + + def receiver(): + conn.set_autocommit(True) + cur = conn.cursor() + cur.execute("listen foo") + gen = conn.notifies() + for n in gen: + ns.append((n, time())) + if len(ns) >= 2: + gen.close() + + ns: list[tuple[Notify, float]] = [] + t0 = time() + workers = [spawn(notifier), spawn(receiver)] + gather(*workers) + assert len(ns) == 2 + + n, t1 = ns[0] + assert n.pid == npid + assert n.channel == "foo" + assert n.payload == "1" + assert t1 - t0 == pytest.approx(0.25, abs=0.05) + + n, t1 = ns[1] + assert n.pid == npid + assert n.channel == "foo" + assert n.payload == "2" + assert t1 - t0 == pytest.approx(0.5, abs=0.05) + + +@pytest.mark.slow +@pytest.mark.timing +def test_no_notify_timeout(conn): + conn.set_autocommit(True) + t0 = time() + for n in conn.notifies(timeout=0.5): + assert False + dt = time() - t0 + assert 0.5 <= dt < 0.75 + + +@pytest.mark.slow +@pytest.mark.timing +def test_notify_timeout(conn_cls, conn, dsn): + conn.set_autocommit(True) + conn.execute("listen foo") + + def notifier(): + with conn_cls.connect(dsn, autocommit=True) as nconn: + sleep(0.25) + nconn.execute("notify foo, '1'") + + worker = spawn(notifier) + try: + times = [time()] + for n in conn.notifies(timeout=0.5): + times.append(time()) + times.append(time()) + finally: + gather(worker) + + assert len(times) == 3 + assert times[1] - times[0] == pytest.approx(0.25, 0.1) + assert times[2] - times[1] == pytest.approx(0.25, 0.1) + + +@pytest.mark.slow +def test_notify_timeout_0(conn_cls, conn, dsn): + conn.set_autocommit(True) + conn.execute("listen foo") + + ns = list(conn.notifies(timeout=0)) + assert not ns + + with conn_cls.connect(dsn, autocommit=True) as nconn: + nconn.execute("notify foo, '1'") + sleep(0.1) + + ns = list(conn.notifies(timeout=0)) + assert len(ns) == 1 + + +@pytest.mark.slow +def test_stop_after(conn_cls, conn, dsn): + conn.set_autocommit(True) + conn.execute("listen foo") + + def notifier(): + with conn_cls.connect(dsn, autocommit=True) as nconn: + nconn.execute("notify foo, '1'") + sleep(0.1) + nconn.execute("notify foo, '2'") + sleep(0.1) + nconn.execute("notify foo, '3'") + + worker = spawn(notifier) + try: + ns = list(conn.notifies(timeout=1.0, stop_after=2)) + 
assert len(ns) == 2 + assert ns[0].payload == "1" + assert ns[1].payload == "2" + finally: + gather(worker) + + ns = list(conn.notifies(timeout=0.0)) + assert len(ns) == 1 + assert ns[0].payload == "3" + + +def test_stop_after_batch(conn_cls, conn, dsn): + conn.set_autocommit(True) + conn.execute("listen foo") + + def notifier(): + with conn_cls.connect(dsn, autocommit=True) as nconn: + with nconn.transaction(): + nconn.execute("notify foo, '1'") + nconn.execute("notify foo, '2'") + + worker = spawn(notifier) + try: + ns = list(conn.notifies(timeout=1.0, stop_after=1)) + assert len(ns) == 2 + assert ns[0].payload == "1" + assert ns[1].payload == "2" + finally: + gather(worker) diff --git a/tests/test_notify_async.py b/tests/test_notify_async.py new file mode 100644 index 000000000..f4f0901d6 --- /dev/null +++ b/tests/test_notify_async.py @@ -0,0 +1,192 @@ +from __future__ import annotations + +from time import time + +import pytest +from psycopg import Notify + +from .acompat import alist, asleep, gather, spawn + +pytestmark = pytest.mark.crdb_skip("notify") + + +async def test_notify_handlers(aconn): + nots1 = [] + nots2 = [] + + def cb1(n): + nots1.append(n) + + aconn.add_notify_handler(cb1) + aconn.add_notify_handler(lambda n: nots2.append(n)) + + await aconn.set_autocommit(True) + await aconn.execute("listen foo") + await aconn.execute("notify foo, 'n1'") + + assert len(nots1) == 1 + n = nots1[0] + assert n.channel == "foo" + assert n.payload == "n1" + assert n.pid == aconn.pgconn.backend_pid + + assert len(nots2) == 1 + assert nots2[0] == nots1[0] + + aconn.remove_notify_handler(cb1) + await aconn.execute("notify foo, 'n2'") + + assert len(nots1) == 1 + assert len(nots2) == 2 + n = nots2[1] + assert isinstance(n, Notify) + assert n.channel == "foo" + assert n.payload == "n2" + assert n.pid == aconn.pgconn.backend_pid + assert hash(n) + + with pytest.raises(ValueError): + aconn.remove_notify_handler(cb1) + + +@pytest.mark.slow +@pytest.mark.timing +async def test_notify(aconn_cls, aconn, dsn): + npid = None + + async def notifier(): + async with await aconn_cls.connect(dsn, autocommit=True) as nconn: + nonlocal npid + npid = nconn.pgconn.backend_pid + + await asleep(0.25) + await nconn.execute("notify foo, '1'") + await asleep(0.25) + await nconn.execute("notify foo, '2'") + + async def receiver(): + await aconn.set_autocommit(True) + cur = aconn.cursor() + await cur.execute("listen foo") + gen = aconn.notifies() + async for n in gen: + ns.append((n, time())) + if len(ns) >= 2: + await gen.aclose() + + ns: list[tuple[Notify, float]] = [] + t0 = time() + workers = [spawn(notifier), spawn(receiver)] + await gather(*workers) + assert len(ns) == 2 + + n, t1 = ns[0] + assert n.pid == npid + assert n.channel == "foo" + assert n.payload == "1" + assert t1 - t0 == pytest.approx(0.25, abs=0.05) + + n, t1 = ns[1] + assert n.pid == npid + assert n.channel == "foo" + assert n.payload == "2" + assert t1 - t0 == pytest.approx(0.5, abs=0.05) + + +@pytest.mark.slow +@pytest.mark.timing +async def test_no_notify_timeout(aconn): + await aconn.set_autocommit(True) + t0 = time() + async for n in aconn.notifies(timeout=0.5): + assert False + dt = time() - t0 + assert 0.5 <= dt < 0.75 + + +@pytest.mark.slow +@pytest.mark.timing +async def test_notify_timeout(aconn_cls, aconn, dsn): + await aconn.set_autocommit(True) + await aconn.execute("listen foo") + + async def notifier(): + async with await aconn_cls.connect(dsn, autocommit=True) as nconn: + await asleep(0.25) + await nconn.execute("notify foo, '1'") + 
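The notify tests above and below pin down the parameters added to Connection.notifies() in 3.2: timeout caps how long the generator blocks overall, and stop_after ends iteration after that many notifications (a batch committed together may still deliver an extra one, as test_stop_after_batch shows). Minimal usage, with a placeholder DSN:

    import psycopg

    with psycopg.connect("dbname=test", autocommit=True) as conn:
        conn.execute("listen mychan")
        # Block at most 5 seconds in total; stop after two notifications:
        for n in conn.notifies(timeout=5.0, stop_after=2):
            print(n.channel, n.payload, n.pid)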
+ worker = spawn(notifier) + try: + times = [time()] + async for n in aconn.notifies(timeout=0.5): + times.append(time()) + times.append(time()) + finally: + await gather(worker) + + assert len(times) == 3 + assert times[1] - times[0] == pytest.approx(0.25, 0.1) + assert times[2] - times[1] == pytest.approx(0.25, 0.1) + + +@pytest.mark.slow +async def test_notify_timeout_0(aconn_cls, aconn, dsn): + await aconn.set_autocommit(True) + await aconn.execute("listen foo") + + ns = await alist(aconn.notifies(timeout=0)) + assert not ns + + async with await aconn_cls.connect(dsn, autocommit=True) as nconn: + await nconn.execute("notify foo, '1'") + await asleep(0.1) + + ns = await alist(aconn.notifies(timeout=0)) + assert len(ns) == 1 + + +@pytest.mark.slow +async def test_stop_after(aconn_cls, aconn, dsn): + await aconn.set_autocommit(True) + await aconn.execute("listen foo") + + async def notifier(): + async with await aconn_cls.connect(dsn, autocommit=True) as nconn: + await nconn.execute("notify foo, '1'") + await asleep(0.1) + await nconn.execute("notify foo, '2'") + await asleep(0.1) + await nconn.execute("notify foo, '3'") + + worker = spawn(notifier) + try: + ns = await alist(aconn.notifies(timeout=1.0, stop_after=2)) + assert len(ns) == 2 + assert ns[0].payload == "1" + assert ns[1].payload == "2" + finally: + await gather(worker) + + ns = await alist(aconn.notifies(timeout=0.0)) + assert len(ns) == 1 + assert ns[0].payload == "3" + + +async def test_stop_after_batch(aconn_cls, aconn, dsn): + await aconn.set_autocommit(True) + await aconn.execute("listen foo") + + async def notifier(): + async with await aconn_cls.connect(dsn, autocommit=True) as nconn: + async with nconn.transaction(): + await nconn.execute("notify foo, '1'") + await nconn.execute("notify foo, '2'") + + worker = spawn(notifier) + try: + ns = await alist(aconn.notifies(timeout=1.0, stop_after=1)) + assert len(ns) == 2 + assert ns[0].payload == "1" + assert ns[1].payload == "2" + finally: + await gather(worker) diff --git a/tests/test_psycopg_dbapi20.py b/tests/test_psycopg_dbapi20.py index a89344974..00e37017e 100644 --- a/tests/test_psycopg_dbapi20.py +++ b/tests/test_psycopg_dbapi20.py @@ -18,7 +18,7 @@ def with_dsn(request, session_dsn): class PsycopgTests(dbapi20.DatabaseAPI20Test): driver = psycopg # connect_args = () # set by the fixture - connect_kw_args: Dict[str, Any] = {} + connect_kw_args: Dict[Any, Any] = {} def test_nextset(self): # tested elsewhere diff --git a/tests/test_rows.py b/tests/test_rows.py index 5165b8007..93240b5eb 100644 --- a/tests/test_rows.py +++ b/tests/test_rows.py @@ -102,6 +102,16 @@ def test_kwargs_row(conn): assert p.age == 42 +def test_scalar_row(conn): + cur = conn.cursor(row_factory=rows.scalar_row) + cur.execute("select 1") + assert cur.fetchone() == 1 + cur.execute("select 1, 2") + assert cur.fetchone() == 1 + with pytest.raises(psycopg.ProgrammingError): + cur.execute("select") + + @pytest.mark.parametrize( "factory", "tuple_row dict_row namedtuple_row class_row args_row kwargs_row".split(), diff --git a/tests/test_sql.py b/tests/test_sql.py index b1ec8d85d..b5f1b37ca 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -267,35 +267,51 @@ class TestIdentifier: assert sql.Identifier("foo") != "foo" assert sql.Identifier("foo") != sql.SQL("foo") - @pytest.mark.parametrize( - "args, want", - [ - (("foo",), '"foo"'), - (("foo", "bar"), '"foo"."bar"'), - (("fo'o", 'ba"r'), '"fo\'o"."ba""r"'), - ], - ) + _as_string_params = [ + (("foo",), '"foo"'), + (("foo", "bar"), 
'"foo"."bar"'), + (("fo'o", 'ba"r'), '"fo\'o"."ba""r"'), + ] + + @pytest.mark.parametrize("args, want", _as_string_params) def test_as_string(self, conn, args, want): assert sql.Identifier(*args).as_string(conn) == want - @pytest.mark.parametrize( - "args, want, enc", - [ - crdb_encoding(("foo",), '"foo"', "ascii"), - crdb_encoding(("foo", "bar"), '"foo"."bar"', "ascii"), - crdb_encoding(("fo'o", 'ba"r'), '"fo\'o"."ba""r"', "ascii"), - (("foo", eur), f'"foo"."{eur}"', "utf8"), - crdb_encoding(("foo", eur), f'"foo"."{eur}"', "latin9"), - ], - ) + @pytest.mark.parametrize("args, want", _as_string_params) + def test_as_string_no_conn(self, args, want): + assert sql.Identifier(*args).as_string(None) == want + assert sql.Identifier(*args).as_string() == want + + _as_bytes_params = [ + crdb_encoding(("foo",), '"foo"', "ascii"), + crdb_encoding(("foo", "bar"), '"foo"."bar"', "ascii"), + crdb_encoding(("fo'o", 'ba"r'), '"fo\'o"."ba""r"', "ascii"), + (("foo", eur), f'"foo"."{eur}"', "utf8"), + crdb_encoding(("foo", eur), f'"foo"."{eur}"', "latin9"), + ] + + @pytest.mark.parametrize("args, want, enc", _as_bytes_params) def test_as_bytes(self, conn, args, want, enc): want = want.encode(enc) conn.execute(f"set client_encoding to {py2pgenc(enc).decode()}") assert sql.Identifier(*args).as_bytes(conn) == want + @pytest.mark.parametrize("args, want, enc", _as_bytes_params) + def test_as_bytes_no_conn(self, conn, args, want, enc): + want = want.encode() + assert sql.Identifier(*args).as_bytes(None) == want + assert sql.Identifier(*args).as_bytes() == want + def test_join(self): assert not hasattr(sql.Identifier("foo"), "join") + def test_escape_no_conn(self, conn): + conn.execute("set client_encoding to 'utf8'") + for c in range(1, 128): + s = chr(c) + want = sql.Identifier(s).as_bytes(conn) + assert want == sql.Identifier(s).as_bytes(None) + class TestLiteral: def test_class(self): @@ -312,17 +328,40 @@ class TestLiteral: assert repr(sql.Literal("foo")) == "Literal('foo')" assert str(sql.Literal("foo")) == "Literal('foo')" - def test_as_string(self, conn): - assert sql.Literal(None).as_string(conn) == "NULL" - assert no_e(sql.Literal("foo").as_string(conn)) == "'foo'" - assert sql.Literal(42).as_string(conn) == "42" - assert sql.Literal(dt.date(2017, 1, 1)).as_string(conn) == "'2017-01-01'::date" - - def test_as_bytes(self, conn): - assert sql.Literal(None).as_bytes(conn) == b"NULL" - assert no_e(sql.Literal("foo").as_bytes(conn)) == b"'foo'" - assert sql.Literal(42).as_bytes(conn) == b"42" - assert sql.Literal(dt.date(2017, 1, 1)).as_bytes(conn) == b"'2017-01-01'::date" + _params = [ + (None, "NULL"), + ("foo", "'foo'"), + (42, "42"), + (dt.date(2017, 1, 1), "'2017-01-01'::date"), + ] + + @pytest.mark.parametrize("obj, want", _params) + def test_as_string(self, conn, obj, want): + got = sql.Literal(obj).as_string(conn) + if isinstance(obj, str): + got = no_e(got) + assert got == want + + @pytest.mark.parametrize("obj, want", _params) + def test_as_bytes(self, conn, obj, want): + got = sql.Literal(obj).as_bytes(conn) + if isinstance(obj, str): + got = no_e(got) + assert got == want.encode() + + @pytest.mark.parametrize("obj, want", _params) + def test_as_string_no_conn(self, obj, want): + got = sql.Literal(obj).as_string() + if isinstance(obj, str): + got = no_e(got) + assert got == want + + @pytest.mark.parametrize("obj, want", _params) + def test_as_bytes_no_conn(self, obj, want): + got = sql.Literal(obj).as_bytes() + if isinstance(obj, str): + got = no_e(got) + assert got == want.encode() 
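The reworked Literal tests above show that quoting values also works without a connection now: Literal.as_bytes() builds a Transformer from a null context, so rendering falls back to connection-independent defaults. For instance:

    import datetime as dt

    from psycopg import sql

    print(sql.Literal(None).as_string())                 # NULL
    print(sql.Literal(42).as_string())                   # 42
    print(sql.Literal(dt.date(2017, 1, 1)).as_string())  # '2017-01-01'::date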
diff --git a/tests/test_sql.py b/tests/test_sql.py
index b1ec8d85d..b5f1b37ca 100644
--- a/tests/test_sql.py
+++ b/tests/test_sql.py
@@ -267,35 +267,51 @@ class TestIdentifier:
         assert sql.Identifier("foo") != "foo"
         assert sql.Identifier("foo") != sql.SQL("foo")
 
-    @pytest.mark.parametrize(
-        "args, want",
-        [
-            (("foo",), '"foo"'),
-            (("foo", "bar"), '"foo"."bar"'),
-            (("fo'o", 'ba"r'), '"fo\'o"."ba""r"'),
-        ],
-    )
+    _as_string_params = [
+        (("foo",), '"foo"'),
+        (("foo", "bar"), '"foo"."bar"'),
+        (("fo'o", 'ba"r'), '"fo\'o"."ba""r"'),
+    ]
+
+    @pytest.mark.parametrize("args, want", _as_string_params)
     def test_as_string(self, conn, args, want):
         assert sql.Identifier(*args).as_string(conn) == want
 
-    @pytest.mark.parametrize(
-        "args, want, enc",
-        [
-            crdb_encoding(("foo",), '"foo"', "ascii"),
-            crdb_encoding(("foo", "bar"), '"foo"."bar"', "ascii"),
-            crdb_encoding(("fo'o", 'ba"r'), '"fo\'o"."ba""r"', "ascii"),
-            (("foo", eur), f'"foo"."{eur}"', "utf8"),
-            crdb_encoding(("foo", eur), f'"foo"."{eur}"', "latin9"),
-        ],
-    )
+    @pytest.mark.parametrize("args, want", _as_string_params)
+    def test_as_string_no_conn(self, args, want):
+        assert sql.Identifier(*args).as_string(None) == want
+        assert sql.Identifier(*args).as_string() == want
+
+    _as_bytes_params = [
+        crdb_encoding(("foo",), '"foo"', "ascii"),
+        crdb_encoding(("foo", "bar"), '"foo"."bar"', "ascii"),
+        crdb_encoding(("fo'o", 'ba"r'), '"fo\'o"."ba""r"', "ascii"),
+        (("foo", eur), f'"foo"."{eur}"', "utf8"),
+        crdb_encoding(("foo", eur), f'"foo"."{eur}"', "latin9"),
+    ]
+
+    @pytest.mark.parametrize("args, want, enc", _as_bytes_params)
     def test_as_bytes(self, conn, args, want, enc):
         want = want.encode(enc)
         conn.execute(f"set client_encoding to {py2pgenc(enc).decode()}")
         assert sql.Identifier(*args).as_bytes(conn) == want
 
+    @pytest.mark.parametrize("args, want, enc", _as_bytes_params)
+    def test_as_bytes_no_conn(self, conn, args, want, enc):
+        want = want.encode()
+        assert sql.Identifier(*args).as_bytes(None) == want
+        assert sql.Identifier(*args).as_bytes() == want
+
     def test_join(self):
         assert not hasattr(sql.Identifier("foo"), "join")
 
+    def test_escape_no_conn(self, conn):
+        conn.execute("set client_encoding to 'utf8'")
+        for c in range(1, 128):
+            s = chr(c)
+            want = sql.Identifier(s).as_bytes(conn)
+            assert want == sql.Identifier(s).as_bytes(None)
+
 
 class TestLiteral:
     def test_class(self):
@@ -312,17 +328,40 @@ class TestLiteral:
         assert repr(sql.Literal("foo")) == "Literal('foo')"
         assert str(sql.Literal("foo")) == "Literal('foo')"
 
-    def test_as_string(self, conn):
-        assert sql.Literal(None).as_string(conn) == "NULL"
-        assert no_e(sql.Literal("foo").as_string(conn)) == "'foo'"
-        assert sql.Literal(42).as_string(conn) == "42"
-        assert sql.Literal(dt.date(2017, 1, 1)).as_string(conn) == "'2017-01-01'::date"
-
-    def test_as_bytes(self, conn):
-        assert sql.Literal(None).as_bytes(conn) == b"NULL"
-        assert no_e(sql.Literal("foo").as_bytes(conn)) == b"'foo'"
-        assert sql.Literal(42).as_bytes(conn) == b"42"
-        assert sql.Literal(dt.date(2017, 1, 1)).as_bytes(conn) == b"'2017-01-01'::date"
+    _params = [
+        (None, "NULL"),
+        ("foo", "'foo'"),
+        (42, "42"),
+        (dt.date(2017, 1, 1), "'2017-01-01'::date"),
+    ]
+
+    @pytest.mark.parametrize("obj, want", _params)
+    def test_as_string(self, conn, obj, want):
+        got = sql.Literal(obj).as_string(conn)
+        if isinstance(obj, str):
+            got = no_e(got)
+        assert got == want
+
+    @pytest.mark.parametrize("obj, want", _params)
+    def test_as_bytes(self, conn, obj, want):
+        got = sql.Literal(obj).as_bytes(conn)
+        if isinstance(obj, str):
+            got = no_e(got)
+        assert got == want.encode()
+
+    @pytest.mark.parametrize("obj, want", _params)
+    def test_as_string_no_conn(self, obj, want):
+        got = sql.Literal(obj).as_string()
+        if isinstance(obj, str):
+            got = no_e(got)
+        assert got == want
+
+    @pytest.mark.parametrize("obj, want", _params)
+    def test_as_bytes_no_conn(self, obj, want):
+        got = sql.Literal(obj).as_bytes()
+        if isinstance(obj, str):
+            got = no_e(got)
+        assert got == want.encode()
 
     @pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
     def test_as_bytes_encoding(self, conn, encoding):
@@ -459,6 +498,7 @@ class TestSQL:
 
     def test_as_string(self, conn):
         assert sql.SQL("foo").as_string(conn) == "foo"
+        assert sql.SQL("foo").as_string() == "foo"
 
     @pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
     def test_as_bytes(self, conn, encoding):
@@ -467,6 +507,10 @@ class TestSQL:
 
         assert sql.SQL(eur).as_bytes(conn) == eur.encode(encoding)
 
+    def test_no_conn(self):
+        assert sql.SQL(eur).as_string() == eur
+        assert sql.SQL(eur).as_bytes() == eur.encode()
+
 
 class TestComposed:
     def test_class(self):
@@ -526,10 +570,12 @@ class TestComposed:
     def test_as_string(self, conn):
         obj = sql.Composed([sql.SQL("foo"), sql.SQL("bar")])
         assert obj.as_string(conn) == "foobar"
+        assert obj.as_string() == "foobar"
 
     def test_as_bytes(self, conn):
         obj = sql.Composed([sql.SQL("foo"), sql.SQL("bar")])
         assert obj.as_bytes(conn) == b"foobar"
+        assert obj.as_bytes() == b"foobar"
 
     @pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
     def test_as_bytes_encoding(self, conn, encoding):
@@ -570,17 +616,21 @@ class TestPlaceholder:
     def test_as_string(self, conn, format):
         ph = sql.Placeholder(format=format)
         assert ph.as_string(conn) == f"%{format.value}"
+        assert ph.as_string() == f"%{format.value}"
 
         ph = sql.Placeholder(name="foo", format=format)
         assert ph.as_string(conn) == f"%(foo){format.value}"
+        assert ph.as_string() == f"%(foo){format.value}"
 
     @pytest.mark.parametrize("format", PyFormat)
     def test_as_bytes(self, conn, format):
         ph = sql.Placeholder(format=format)
-        assert ph.as_bytes(conn) == f"%{format.value}".encode("ascii")
+        assert ph.as_bytes(conn) == f"%{format.value}".encode()
+        assert ph.as_bytes() == f"%{format.value}".encode()
 
         ph = sql.Placeholder(name="foo", format=format)
-        assert ph.as_bytes(conn) == f"%(foo){format.value}".encode("ascii")
+        assert ph.as_bytes(conn) == f"%(foo){format.value}".encode()
+        assert ph.as_bytes() == f"%(foo){format.value}".encode()
 
 
 class TestValues:
@@ -595,7 +645,7 @@
 
 def no_e(s):
     """Drop an eventual E from E'' quotes"""
-    if isinstance(s, memoryview):
+    if isinstance(s, (memoryview, bytearray)):
         s = bytes(s)
 
     if isinstance(s, str):
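The common thread in the test_sql.py changes above: as_string() and as_bytes() now accept None, or no argument at all, in which case strings are encoded as UTF-8. A minimal sketch of rendering a composed query without a connection (identifiers are illustrative):

    from psycopg import sql

    query = sql.SQL("select {field} from {table}").format(
        field=sql.Identifier("id"), table=sql.Identifier("my_table")
    )
    print(query.as_string())  # select "id" from "my_table"
    print(query.as_bytes())   # b'select "id" from "my_table"'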
diff --git a/tests/test_tpc.py b/tests/test_tpc.py
index 41023ccc8..864f8a786 100644
--- a/tests/test_tpc.py
+++ b/tests/test_tpc.py
@@ -22,6 +22,7 @@ def test_tpc_disabled(conn, pipeline):
 
 
 class TestTPC:
+
     def test_tpc_commit(self, conn, tpc):
         xid = conn.xid(1, "gtrid", "bqual")
         assert conn.info.transaction_status == TransactionStatus.IDLE
diff --git a/tests/test_waiting.py b/tests/test_waiting.py
index 6a9ad88f3..c4d8915e8 100644
--- a/tests/test_waiting.py
+++ b/tests/test_waiting.py
@@ -1,6 +1,7 @@
+import sys
+import time
 import select  # noqa: used in pytest.mark.skipif
 import socket
-import sys
 
 import pytest
 
@@ -26,6 +27,7 @@ waitfns = [
     pytest.param("wait_c", marks=pytest.mark.skipif("not psycopg._cmodule._psycopg")),
 ]
 
+events = ["R", "W", "RW"]
 timeouts = [pytest.param({}, id="blank")]
 timeouts += [pytest.param({"timeout": x}, id=str(x)) for x in [None, 0, 0.2, 10]]
 
@@ -44,9 +46,11 @@ def test_wait_conn_bad(dsn):
 
 
 @pytest.mark.parametrize("waitfn", waitfns)
-@pytest.mark.parametrize("wait, ready", zip(waiting.Wait, waiting.Ready))
+@pytest.mark.parametrize("event", events)
 @skip_if_not_linux
-def test_wait_ready(waitfn, wait, ready):
+def test_wait_ready(waitfn, event):
+    wait = getattr(waiting.Wait, event)
+    ready = getattr(waiting.Ready, event)
     waitfn = getattr(waiting, waitfn)
 
     def gen():
@@ -80,6 +84,34 @@ def test_wait_bad(pgconn, waitfn):
         waitfn(gen, pgconn.socket)
 
 
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.parametrize("waitfn", waitfns)
+def test_wait_timeout(pgconn, waitfn):
+    waitfn = getattr(waiting, waitfn)
+
+    pgconn.send_query(b"select pg_sleep(0.5)")
+    gen = generators.execute(pgconn)
+
+    ts = [time.time()]
+
+    def gen_wrapper():
+        try:
+            for x in gen:
+                res = yield x
+                ts.append(time.time())
+                gen.send(res)
+        except StopIteration as ex:
+            return ex.value
+
+    (res,) = waitfn(gen_wrapper(), pgconn.socket, timeout=0.1)
+    assert res.status == ExecStatus.TUPLES_OK
+    ds = [t1 - t0 for t0, t1 in zip(ts[:-1], ts[1:])]
+    assert len(ds) >= 5
+    for d in ds[:5]:
+        assert d == pytest.approx(0.1, 0.05)
+
+
 @pytest.mark.slow
 @pytest.mark.skipif(
     "sys.platform == 'win32'", reason="win32 works ok, but FDs are mysterious"
@@ -130,9 +162,12 @@ async def test_wait_conn_async_bad(dsn):
 
 
 @pytest.mark.anyio
-@pytest.mark.parametrize("wait, ready", zip(waiting.Wait, waiting.Ready))
+@pytest.mark.parametrize("event", events)
 @skip_if_not_linux
-async def test_wait_ready_async(wait, ready):
+async def test_wait_ready_async(event):
+    wait = getattr(waiting.Wait, event)
+    ready = getattr(waiting.Ready, event)
+
     def gen():
         r = yield wait
         return r
diff --git a/tests/types/test_datetime.py b/tests/types/test_datetime.py
index b582a96d9..2f283885f 100644
--- a/tests/types/test_datetime.py
+++ b/tests/types/test_datetime.py
@@ -713,6 +713,7 @@ class TestInterval:
         ("-90d", "-3 month"),
         ("186d", "6 mons 6 days"),
         ("736d", "2 years 6 days"),
+        ("83063d,81640s,447000m", "1993534:40:40.447"),
     ],
 )
 @pytest.mark.parametrize("fmt_out", pq.Format)
diff --git a/tests/types/test_numeric.py b/tests/types/test_numeric.py
index fffc32f31..8e5db7668 100644
--- a/tests/types/test_numeric.py
+++ b/tests/types/test_numeric.py
@@ -402,9 +402,11 @@ def test_dump_numeric_binary(conn, expr):
 @pytest.mark.parametrize(
     "fmt_in",
     [
-        f
-        if f != PyFormat.BINARY
-        else pytest.param(f, marks=pytest.mark.crdb_skip("binary decimal"))
+        (
+            f
+            if f != PyFormat.BINARY
+            else pytest.param(f, marks=pytest.mark.crdb_skip("binary decimal"))
+        )
         for f in PyFormat
     ],
 )
diff --git a/tools/async_to_sync.py b/tools/async_to_sync.py
index d264c78bd..c4c053a1a 100755
--- a/tools/async_to_sync.py
+++ b/tools/async_to_sync.py
@@ -48,6 +48,7 @@ ALL_INPUTS = """
     tests/test_cursor_common_async.py
     tests/test_cursor_raw_async.py
     tests/test_cursor_server_async.py
+    tests/test_notify_async.py
     tests/test_pipeline_async.py
     tests/test_prepared_async.py
     tests/test_tpc_async.py
diff --git a/tools/update_oids.py b/tools/update_oids.py
index 22f04ec53..7e303bb51 100755
--- a/tools/update_oids.py
+++ b/tools/update_oids.py
@@ -19,11 +19,11 @@
 import argparse
 import subprocess as sp
 from typing import List
 from pathlib import Path
-from typing_extensions import TypeAlias
 
 import psycopg
 from psycopg.rows import TupleRow
 from psycopg.crdb import CrdbConnection
+from psycopg._compat import TypeAlias
 
 Connection: TypeAlias = psycopg.Connection[TupleRow]
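As a closing sanity check, the interval vector added to tests/types/test_datetime.py above is arithmetically consistent (a quick verification, not part of the patch; the variable names are illustrative):

    from datetime import timedelta

    td = timedelta(days=83063, seconds=81640, microseconds=447000)
    hours = td.days * 24 + td.seconds // 3600       # 1993512 + 22 = 1993534
    minutes = td.seconds % 3600 // 60               # 40
    secs = td.seconds % 60 + td.microseconds / 1e6  # 40.447
    print(f"{hours}:{minutes:02d}:{secs:06.3f}")    # 1993534:40:40.447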