[flake8]
max-line-length = 88
-ignore = W503, E203
+ignore = W503, E203, E704
extend-exclude = .venv build
per-file-ignores =
# Autogenerated section
--health-retries 5
steps:
- - uses: actions/checkout@v3
+ - uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
--health-retries 5
steps:
- - uses: actions/checkout@v3
+ - uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
if: true
steps:
- - uses: actions/checkout@v3
+ - uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
platform: [manylinux, musllinux]
steps:
- - uses: actions/checkout@v3
+ - uses: actions/checkout@v4
- name: Set up QEMU for multi-arch build
# Check https://github.com/docker/setup-qemu-action for newer versions.
run: python3 ./tools/build/copy_to_binary.py
- name: Build wheels
- uses: pypa/cibuildwheel@v2.16.0
+ uses: pypa/cibuildwheel@v2.16.5
with:
package-dir: psycopg_binary
env:
pyver: [cp38, cp39, cp310, cp311, cp312]
steps:
- - uses: actions/checkout@v3
+ - uses: actions/checkout@v4
- name: Create the binary package source tree
run: python3 ./tools/build/copy_to_binary.py
- name: Build wheels
- uses: pypa/cibuildwheel@v2.16.0
+ uses: pypa/cibuildwheel@v2.16.5
with:
package-dir: psycopg_binary
env:
pyver: [cp38, cp39, cp310, cp311, cp312]
steps:
- - uses: actions/checkout@v3
+ - uses: actions/checkout@v4
- name: Start PostgreSQL service for test
run: |
run: python3 ./tools/build/copy_to_binary.py
- name: Build wheels
- uses: pypa/cibuildwheel@v2.16.0
+ uses: pypa/cibuildwheel@v2.16.5
with:
package-dir: psycopg_binary
env:
- {package: psycopg_pool, format: wheel}
steps:
- - uses: actions/checkout@v3
+ - uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
- {package: psycopg_c, format: sdist, impl: c}
steps:
- - uses: actions/checkout@v3
+ - uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
MARKERS: ""
steps:
- - uses: actions/checkout@v3
+ - uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
NOT_MARKERS: "timing proxy mypy"
steps:
- - uses: actions/checkout@v3
+ - uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
shell: bash
steps:
- - uses: actions/checkout@v3
+ - uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
PSYCOPG_TEST_DSN: "host=127.0.0.1 port=26257 user=root dbname=defaultdb"
steps:
- - uses: actions/checkout@v3
+ - uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
Notifications are received as instances of `Notify`. If you are reserving a
connection only to receive notifications, the simplest way is to consume the
`Connection.notifies` generator. The generator can be stopped using
-`!close()`.
+`!close()`. Starting from Psycopg 3.2, the method supports options to receive
+notifications only for a certain time or up to a certain number of
+notifications.
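+
+For example, a minimal sketch (assuming an autocommit connection that already
+executed :sql:`LISTEN` on a hypothetical ``mychan`` channel)::
+
+    # stop after 60 seconds or after 10 notifications, whichever comes first
+    for notify in conn.notifies(timeout=60.0, stop_after=10):
+        print(notify.channel, notify.payload)
+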
.. note::
any session in the database generates a :sql:`NOTIFY` on one of the
listened channels.
+ .. versionchanged:: 3.2
+
+ Added `!timeout` and `!stop_after` parameters.
+
.. automethod:: add_notify_handler
See :ref:`async-notify` for details.
...
.. automethod:: notifies
+
+ .. versionchanged:: 3.2
+
+ Added `!timeout` and `!stop_after` parameters.
+
.. automethod:: set_autocommit
.. automethod:: set_isolation_level
.. automethod:: set_read_only
.. autofunction:: tuple_row
.. autofunction:: dict_row
.. autofunction:: namedtuple_row
+.. autofunction:: scalar_row
+
+ .. versionadded:: 3.2
+
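+    A minimal usage sketch (the table name is hypothetical)::
+
+        cur = conn.cursor(row_factory=scalar_row)
+        cur.execute("select count(*) from mytable").fetchone()
+        # e.g. 42, rather than (42,)
+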
.. autofunction:: class_row
This is not a row factory, but rather a factory of row factories.
.. autoclass:: Composable()
- .. automethod:: as_bytes
.. automethod:: as_string
+ .. versionchanged:: 3.2
+
+ The `!context` parameter is optional.
+
+ .. warning::
+
+ If a context is not specified, the results are "generic" and not
+ tailored for a specific target connection. Details such as the
+ connection encoding and escaping style will not be taken into
+ account.
+
+ .. automethod:: as_bytes
+
+ .. versionchanged:: 3.2
+
+ The `!context` parameter is optional. See `as_string` for details.
+
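+    For example, a rough sketch of the difference (the exact quoting is
+    indicative only)::
+
+        from psycopg import sql
+
+        query = sql.SQL("select {}").format(sql.Identifier("foo"))
+        query.as_string()      # "generic" escaping, no connection needed
+        query.as_string(conn)  # escaping tailored to the connection
+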
.. autoclass:: SQL
- Add support for integer, floating point, boolean `NumPy scalar types`__
(:ticket:`#332`).
+- Add `!timeout` and `!stop_after` parameters to `Connection.notifies()`
+  (:ticket:`#340`).
- Add :ref:`raw-query-cursors` to execute queries using placeholders in
PostgreSQL format (`$1`, `$2`...) (:ticket:`#560`).
+- Add `~rows.scalar_row` to return scalar values from a query (:ticket:`#723`).
- Add `~Connection.set_autocommit()` on sync connections, and similar
transaction control methods available on the async connections.
- Add support for libpq functions to close prepared statements and portals
introduced in libpq v17 (:ticket:`#603`).
+- The `!context` parameter of the `sql` objects' `~sql.Composable.as_string()`
+  and `~sql.Composable.as_bytes()` methods is now optional (:ticket:`#716`).
- Disable receiving more than one result on the same cursor in pipeline mode,
to iterate through `~Cursor.nextset()`. The behaviour was different than
in non-pipeline mode and not totally reliable (:ticket:`#604`).
.. __: https://numpy.org/doc/stable/reference/arrays.scalars.html#built-in-scalar-types
-Psycopg 3.1.17 (unreleased)
+Psycopg 3.1.18 (unreleased)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
-- Fix multiple connection attempts when a host name resolve to multiple
- IP addresses (:ticket:`699`).
-- Use `typing.Self` as a more correct return value annotation of context
- managers and other self-returning methods (see :ticket:`708`).
+- Fix possible deadlock on pipeline exit (:ticket:`#685`).
+- Fix overflow loading large intervals in C module (:ticket:`#719`).
+- Fix compatibility with musl libc distributions affected by `CPython issue
+ #65821`__ (:ticket:`#725`).
+
+.. __: https://github.com/python/cpython/issues/65821
Current release
---------------
+Psycopg 3.1.17
+^^^^^^^^^^^^^^
+
+- Fix multiple connection attempts when a host name resolves to multiple
+ IP addresses (:ticket:`#699`).
+- Use `typing.Self` as a more correct return value annotation of context
+ managers and other self-returning methods (see :ticket:`#708`).
+
+
Psycopg 3.1.16
^^^^^^^^^^^^^^
[flake8]
max-line-length = 88
-ignore = W503, E203
+ignore = W503, E203, E704
per-file-ignores =
# Autogenerated section
psycopg/errors.py: E125, E128, E302
import threading
from typing import Any, Callable, Coroutine, TYPE_CHECKING
-from typing_extensions import TypeAlias
-
-from ._compat import TypeVar
+from ._compat import TypeAlias, TypeVar
Worker: TypeAlias = threading.Thread
AWorker: TypeAlias = "asyncio.Task[None]"
cache = lru_cache(maxsize=None)
if sys.version_info >= (3, 10):
- from typing import TypeGuard
+ from typing import TypeGuard, TypeAlias
else:
- from typing_extensions import TypeGuard
+ from typing_extensions import TypeGuard, TypeAlias
if sys.version_info >= (3, 11):
from typing import LiteralString, Self
"Deque",
"LiteralString",
"Self",
+ "TypeAlias",
"TypeGuard",
"TypeVar",
"ZoneInfo",
from weakref import ref, ReferenceType
from warnings import warn
from functools import partial
-from typing_extensions import TypeAlias
from . import pq
from . import errors as e
from .rows import Row
from .adapt import AdaptersMap
from ._enums import IsolationLevel
-from ._compat import LiteralString, Self, TypeVar
+from ._compat import LiteralString, Self, TypeAlias, TypeVar
from .pq.misc import connection_summary
from ._pipeline import BasePipeline
from ._encodings import pgconn_encoding
from random import shuffle
from . import errors as e
-from ._conninfo_utils import ConnDict, get_param, is_ip_address, get_param_def
+from .abc import ConnDict, ConnMapping
+from ._conninfo_utils import get_param, is_ip_address, get_param_def
from ._conninfo_utils import split_attempts
logger = logging.getLogger("psycopg")
-def conninfo_attempts(params: ConnDict) -> list[ConnDict]:
+def conninfo_attempts(params: ConnMapping) -> list[ConnDict]:
"""Split a set of connection params on the single attempts to perform.
A connection param can perform more than one attempt more than one ``host``
from random import shuffle
from . import errors as e
-from ._conninfo_utils import ConnDict, get_param, is_ip_address, get_param_def
+from .abc import ConnDict, ConnMapping
+from ._conninfo_utils import get_param, is_ip_address, get_param_def
from ._conninfo_utils import split_attempts
if True: # ASYNC:
logger = logging.getLogger("psycopg")
-async def conninfo_attempts_async(params: ConnDict) -> list[ConnDict]:
+async def conninfo_attempts_async(params: ConnMapping) -> list[ConnDict]:
"""Split a set of connection params on the single attempts to perform.
A connection param can perform more than one attempt more than one ``host``
from __future__ import annotations
import os
-from typing import Any
+from typing import TYPE_CHECKING
from functools import lru_cache
from ipaddress import ip_address
from dataclasses import dataclass
-from typing_extensions import TypeAlias
from . import pq
+from .abc import ConnDict, ConnMapping
from . import errors as e
-ConnDict: TypeAlias = "dict[str, Any]"
+if TYPE_CHECKING:
+ from typing import Any # noqa: F401
-def split_attempts(params: ConnDict) -> list[ConnDict]:
+def split_attempts(params: ConnMapping) -> list[ConnDict]:
"""
Split connection parameters with a sequence of hosts into separate attempts.
"""
# A single attempt to make. Don't mangle the conninfo string.
if nhosts <= 1:
- return [params]
+ return [{**params}]
if len(ports) == 1:
ports *= nhosts
# Now all lists are either empty or have the same length
rv = []
for i in range(nhosts):
- attempt = params.copy()
+ attempt = {**params}
if hosts:
attempt["host"] = hosts[i]
if hostaddrs:
return rv
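
# A sketch of the intended splitting (hypothetical values): params such as
# {"host": "h1,h2", "port": "5432"} yield one attempt per host, e.g.
# [{"host": "h1", "port": "5432"}, {"host": "h2", "port": "5432"}];
# single-valued parameters are repeated in each attempt.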
-def get_param(params: ConnDict, name: str) -> str | None:
+def get_param(params: ConnMapping, name: str) -> str | None:
"""
Return a value from a connection string.
self._row_mode = False # true if the user is using write_row()
@abstractmethod
- def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]:
- ...
+ def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]: ...
@abstractmethod
- def write(self, buffer: Union[Buffer, str]) -> Buffer:
- ...
+ def write(self, buffer: Union[Buffer, str]) -> Buffer: ...
@abstractmethod
- def write_row(self, row: Sequence[Any]) -> Buffer:
- ...
+ def write_row(self, row: Sequence[Any]) -> Buffer: ...
@abstractmethod
- def end(self) -> Buffer:
- ...
+ def end(self) -> Buffer: ...
class TextFormatter(Formatter):
pgenc = params.get("client_encoding")
if pgenc:
try:
- return pg2pyenc(pgenc.encode())
+ return pg2pyenc(str(pgenc).encode())
except NotSupportedError:
pass
class Ready(IntEnum):
+ NONE = 0
R = EVENT_READ
W = EVENT_WRITE
RW = EVENT_READ | EVENT_WRITE
import logging
from types import TracebackType
from typing import Any, List, Optional, Union, Tuple, Type, TYPE_CHECKING
-from typing_extensions import TypeAlias
from . import pq
from . import errors as e
from .abc import PipelineCommand, PQGen
-from ._compat import Deque, Self
+from ._compat import Deque, Self, TypeAlias
from .pq.misc import connection_summary
from ._encodings import pgconn_encoding
from ._preparing import Key, Prepare
self._enqueue_sync()
yield from self._communicate_gen()
finally:
- # No need to force flush since we emitted a sync just before.
- yield from self._fetch_gen(flush=False)
+ yield from self._fetch_gen(flush=True)
def _communicate_gen(self) -> PQGen[None]:
"""Communicate with pipeline to send commands and possibly fetch
from enum import IntEnum, auto
from typing import Iterator, Optional, Sequence, Tuple, TYPE_CHECKING
from collections import OrderedDict
-from typing_extensions import TypeAlias
from . import pq
-from ._compat import Deque
+from ._compat import Deque, TypeAlias
from ._queries import PostgresQuery
if TYPE_CHECKING:
from typing import Any, Dict, List, Optional, Sequence, Tuple
from typing import DefaultDict, TYPE_CHECKING
from collections import defaultdict
-from typing_extensions import TypeAlias
from . import pq
from . import abc
from .abc import Buffer, LoadFunc, AdaptContext, PyFormat, DumperKey, NoneType
from .rows import Row, RowMaker
from ._oids import INVALID_OID, TEXT_OID
+from ._compat import TypeAlias
from ._encodings import conn_encoding
if TYPE_CHECKING:
from typing import Any, Callable, Dict, List, Mapping, Match, NamedTuple, Optional
from typing import Sequence, Tuple, Union, TYPE_CHECKING
from functools import lru_cache
-from typing_extensions import TypeAlias
from . import pq
from . import errors as e
from .sql import Composable
from .abc import Buffer, Query, Params
from ._enums import PyFormat
-from ._compat import TypeGuard
+from ._compat import TypeAlias, TypeGuard
from ._encodings import conn_encoding
if TYPE_CHECKING:
import struct
from typing import Callable, cast, Optional, Protocol, Tuple
-from typing_extensions import TypeAlias
-from .abc import Buffer
from . import errors as e
+from .abc import Buffer
+from ._compat import TypeAlias
PackInt: TypeAlias = Callable[[int], bytes]
UnpackInt: TypeAlias = Callable[[Buffer], Tuple[int]]
class UnpackLen(Protocol):
- def __call__(self, data: Buffer, start: Optional[int]) -> Tuple[int]:
- ...
+ def __call__(self, data: Buffer, start: Optional[int]) -> Tuple[int]: ...
pack_int2 = cast(PackInt, struct.Struct("!h").pack)
from typing import Any, Dict, Iterator, Optional, overload
from typing import Sequence, Tuple, Type, Union, TYPE_CHECKING
-from typing_extensions import TypeAlias
from . import sql
from . import errors as e
from .abc import AdaptContext, Query
from .rows import dict_row
-from ._compat import TypeVar
+from ._compat import TypeAlias, TypeVar
from ._encodings import conn_encoding
if TYPE_CHECKING:
@classmethod
def fetch(
cls: Type[T], conn: "Connection[Any]", name: Union[str, sql.Identifier]
- ) -> Optional[T]:
- ...
+ ) -> Optional[T]: ...
@overload
@classmethod
async def fetch(
cls: Type[T], conn: "AsyncConnection[Any]", name: Union[str, sql.Identifier]
- ) -> Optional[T]:
- ...
+ ) -> Optional[T]: ...
@classmethod
def fetch(
yield t
@overload
- def __getitem__(self, key: Union[str, int]) -> TypeInfo:
- ...
+ def __getitem__(self, key: Union[str, int]) -> TypeInfo: ...
@overload
- def __getitem__(self, key: Tuple[Type[T], int]) -> T:
- ...
+ def __getitem__(self, key: Tuple[Type[T], int]) -> T: ...
def __getitem__(self, key: RegistryKey) -> TypeInfo:
"""
raise KeyError(f"couldn't find the type {key!r} in the types registry")
@overload
- def get(self, key: Union[str, int]) -> Optional[TypeInfo]:
- ...
+ def get(self, key: Union[str, int]) -> Optional[TypeInfo]: ...
@overload
- def get(self, key: Tuple[Type[T], int]) -> Optional[T]:
- ...
+ def get(self, key: Tuple[Type[T], int]) -> Optional[T]: ...
def get(self, key: RegistryKey) -> Optional[TypeInfo]:
"""
# Copyright (C) 2020 The Psycopg Team
-from typing import Any, Callable, Generator, Mapping
+from typing import Any, Dict, Callable, Generator, Mapping
from typing import List, Optional, Protocol, Sequence, Tuple, Union
from typing import TYPE_CHECKING
-from typing_extensions import TypeAlias
from . import pq
from ._enums import PyFormat as PyFormat
-from ._compat import LiteralString, TypeVar
+from ._compat import LiteralString, TypeAlias, TypeVar
if TYPE_CHECKING:
from . import sql
ConnectionType = TypeVar("ConnectionType", bound="BaseConnection[Any]")
PipelineCommand: TypeAlias = Callable[[], None]
DumperKey: TypeAlias = Union[type, Tuple["DumperKey", ...]]
+ConnParam: TypeAlias = Union[str, int, None]
+ConnDict: TypeAlias = Dict[str, ConnParam]
+ConnMapping: TypeAlias = Mapping[str, ConnParam]
+
# Waiting protocol types
RV = TypeVar("RV")
-PQGenConn: TypeAlias = Generator[Tuple[int, "Wait"], "Ready", RV]
+PQGenConn: TypeAlias = Generator[Tuple[int, "Wait"], Union["Ready", int], RV]
"""Generator for processes where the connection file number can change.
This can happen in connection and reset, but not in normal querying.
"""
-PQGen: TypeAlias = Generator["Wait", "Ready", RV]
+PQGen: TypeAlias = Generator["Wait", Union["Ready", int], RV]
"""Generator for processes where the connection file number won't change.
"""
def __call__(
self, gen: PQGen[RV], fileno: int, timeout: Optional[float] = None
- ) -> RV:
- ...
+ ) -> RV: ...
# Adaptation types
oid: int
"""The oid to pass to the server, if known; 0 otherwise (class attribute)."""
- def __init__(self, cls: type, context: Optional[AdaptContext] = None):
- ...
+ def __init__(self, cls: type, context: Optional[AdaptContext] = None): ...
def dump(self, obj: Any) -> Buffer:
"""Convert the object `!obj` to PostgreSQL representation.
This is a class attribute.
"""
- def __init__(self, oid: int, context: Optional[AdaptContext] = None):
- ...
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None): ...
def load(self, data: Buffer) -> Any:
"""
types: Optional[Tuple[int, ...]]
formats: Optional[List[pq.Format]]
- def __init__(self, context: Optional[AdaptContext] = None):
- ...
+ def __init__(self, context: Optional[AdaptContext] = None): ...
@classmethod
- def from_context(cls, context: Optional[AdaptContext]) -> "Transformer":
- ...
+ def from_context(cls, context: Optional[AdaptContext]) -> "Transformer": ...
@property
- def connection(self) -> Optional["BaseConnection[Any]"]:
- ...
+ def connection(self) -> Optional["BaseConnection[Any]"]: ...
@property
- def encoding(self) -> str:
- ...
+ def encoding(self) -> str: ...
@property
- def adapters(self) -> "AdaptersMap":
- ...
+ def adapters(self) -> "AdaptersMap": ...
@property
- def pgresult(self) -> Optional["PGresult"]:
- ...
+ def pgresult(self) -> Optional["PGresult"]: ...
def set_pgresult(
self,
*,
set_loaders: bool = True,
format: Optional[pq.Format] = None
- ) -> None:
- ...
+ ) -> None: ...
- def set_dumper_types(self, types: Sequence[int], format: pq.Format) -> None:
- ...
+ def set_dumper_types(self, types: Sequence[int], format: pq.Format) -> None: ...
- def set_loader_types(self, types: Sequence[int], format: pq.Format) -> None:
- ...
+ def set_loader_types(self, types: Sequence[int], format: pq.Format) -> None: ...
def dump_sequence(
self, params: Sequence[Any], formats: Sequence[PyFormat]
- ) -> Sequence[Optional[Buffer]]:
- ...
+ ) -> Sequence[Optional[Buffer]]: ...
- def as_literal(self, obj: Any) -> bytes:
- ...
+ def as_literal(self, obj: Any) -> bytes: ...
- def get_dumper(self, obj: Any, format: PyFormat) -> Dumper:
- ...
+ def get_dumper(self, obj: Any, format: PyFormat) -> Dumper: ...
- def load_rows(self, row0: int, row1: int, make_row: "RowMaker[Row]") -> List["Row"]:
- ...
+ def load_rows(
+ self, row0: int, row1: int, make_row: "RowMaker[Row]"
+ ) -> List["Row"]: ...
- def load_row(self, row: int, make_row: "RowMaker[Row]") -> Optional["Row"]:
- ...
+ def load_row(self, row: int, make_row: "RowMaker[Row]") -> Optional["Row"]: ...
- def load_sequence(self, record: Sequence[Optional[Buffer]]) -> Tuple[Any, ...]:
- ...
+ def load_sequence(self, record: Sequence[Optional[Buffer]]) -> Tuple[Any, ...]: ...
- def get_loader(self, oid: int, format: pq.Format) -> Loader:
- ...
+ def get_loader(self, oid: int, format: pq.Format) -> Loader: ...
)
@abstractmethod
- def dump(self, obj: Any) -> Buffer:
- ...
+ def dump(self, obj: Any) -> Buffer: ...
def quote(self, obj: Any) -> Buffer:
"""
from __future__ import annotations
import logging
+from time import monotonic
from types import TracebackType
from typing import Any, Generator, Iterator, List, Optional
from typing import Type, Union, cast, overload, TYPE_CHECKING
from . import pq
from . import errors as e
from . import waiting
-from .abc import AdaptContext, Params, PQGen, PQGenConn, Query, RV
+from .abc import AdaptContext, ConnDict, ConnParam, Params, PQGen, PQGenConn, Query, RV
from ._tpc import Xid
from .rows import Row, RowFactory, tuple_row, args_row
from .adapt import AdaptersMap
from ._enums import IsolationLevel
from ._compat import Self
-from .conninfo import ConnDict, make_conninfo, conninfo_to_dict
+from .conninfo import make_conninfo, conninfo_to_dict
from .conninfo import conninfo_attempts, timeout_from_conninfo
from ._pipeline import Pipeline
from ._encodings import pgconn_encoding
if TYPE_CHECKING:
from .pq.abc import PGconn
+_WAIT_INTERVAL = 0.1
+
TEXT = pq.Format.TEXT
BINARY = pq.Format.BINARY
IDLE = pq.TransactionStatus.IDLE
+ACTIVE = pq.TransactionStatus.ACTIVE
INTRANS = pq.TransactionStatus.INTRANS
_INTERRUPTED = KeyboardInterrupt
context: Optional[AdaptContext] = None,
row_factory: Optional[RowFactory[Row]] = None,
cursor_factory: Optional[Type[Cursor[Row]]] = None,
- **kwargs: Any,
+ **kwargs: ConnParam,
) -> Self:
"""
Connect to a database server and return a new `Connection` instance.
attempts = conninfo_attempts(params)
for attempt in attempts:
try:
- conninfo = make_conninfo(**attempt)
+ conninfo = make_conninfo("", **attempt)
rv = cls._wait_conn(cls._connect_gen(conninfo), timeout=timeout)
break
except e._NO_TRACEBACK as ex:
self.pgconn.finish()
@overload
- def cursor(self, *, binary: bool = False) -> Cursor[Row]:
- ...
+ def cursor(self, *, binary: bool = False) -> Cursor[Row]: ...
@overload
def cursor(
self, *, binary: bool = False, row_factory: RowFactory[CursorRow]
- ) -> Cursor[CursorRow]:
- ...
+ ) -> Cursor[CursorRow]: ...
@overload
def cursor(
binary: bool = False,
scrollable: Optional[bool] = None,
withhold: bool = False,
- ) -> ServerCursor[Row]:
- ...
+ ) -> ServerCursor[Row]: ...
@overload
def cursor(
row_factory: RowFactory[CursorRow],
scrollable: Optional[bool] = None,
withhold: bool = False,
- ) -> ServerCursor[CursorRow]:
- ...
+ ) -> ServerCursor[CursorRow]: ...
def cursor(
self,
with tx:
yield tx
- def notifies(self) -> Generator[Notify, None, None]:
+ def notifies(
+ self, *, timeout: Optional[float] = None, stop_after: Optional[int] = None
+ ) -> Generator[Notify, None, None]:
"""
Yield `Notify` objects as soon as they are received from the database.
+
+ :param timeout: maximum amount of time to wait for notifications.
+ `!None` means no timeout.
+ :param stop_after: stop after receiving this number of notifications.
+ You might actually receive more than this number if more than one
+ notification arrives in the same packet.
"""
+ # Allow interrupting the wait with a signal by breaking a long timeout
+ # into shorter intervals.
+ if timeout is not None:
+ deadline = monotonic() + timeout
+ timeout = min(timeout, _WAIT_INTERVAL)
+ else:
+ deadline = None
+ timeout = _WAIT_INTERVAL
+
+ nreceived = 0
+
while True:
- with self.lock:
- try:
- ns = self.wait(notifies(self.pgconn))
- except e._NO_TRACEBACK as ex:
- raise ex.with_traceback(None)
- enc = pgconn_encoding(self.pgconn)
+ # Collect notifications. Also get the connection encoding, if any
+ # notification is received, to make sure they are decoded consistently.
+ try:
+ with self.lock:
+ ns = self.wait(notifies(self.pgconn), timeout=timeout)
+ if ns:
+ enc = pgconn_encoding(self.pgconn)
+ except e._NO_TRACEBACK as ex:
+ raise ex.with_traceback(None)
+
+ # Emit the notifications received.
for pgn in ns:
n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid)
yield n
+ nreceived += 1
+
+ # Stop if we have received enough notifications.
+ if stop_after is not None and nreceived >= stop_after:
+ break
+
+ # Check the deadline after the loop to ensure that timeout=0
+ # polls at least once.
+ if deadline:
+ timeout = min(_WAIT_INTERVAL, deadline - monotonic())
+ if timeout < 0.0:
+ break
@contextmanager
def pipeline(self) -> Iterator[Pipeline]:
assert pipeline is self._pipeline
self._pipeline = None
- def wait(self, gen: PQGen[RV], timeout: Optional[float] = 0.1) -> RV:
+ def wait(self, gen: PQGen[RV], timeout: Optional[float] = _WAIT_INTERVAL) -> RV:
"""
Consume a generator operating on the connection.
try:
return waiting.wait(gen, self.pgconn.socket, timeout=timeout)
except _INTERRUPTED:
- # On Ctrl-C, try to cancel the query in the server, otherwise
- # the connection will remain stuck in ACTIVE state.
- self._try_cancel(self.pgconn)
- try:
- waiting.wait(gen, self.pgconn.socket, timeout=timeout)
- except e.QueryCanceled:
- pass # as expected
+ if self.pgconn.transaction_status == ACTIVE:
+ # On Ctrl-C, try to cancel the query in the server, otherwise
+ # the connection will remain stuck in ACTIVE state.
+ self._try_cancel(self.pgconn)
+ try:
+ waiting.wait(gen, self.pgconn.socket, timeout=timeout)
+ except e.QueryCanceled:
+ pass # as expected
raise
@classmethod
from __future__ import annotations
import logging
+from time import monotonic
from types import TracebackType
from typing import Any, AsyncGenerator, AsyncIterator, List, Optional
from typing import Type, Union, cast, overload, TYPE_CHECKING
from . import pq
from . import errors as e
from . import waiting
-from .abc import AdaptContext, Params, PQGen, PQGenConn, Query, RV
+from .abc import AdaptContext, ConnDict, ConnParam, Params, PQGen, PQGenConn, Query, RV
from ._tpc import Xid
from .rows import Row, AsyncRowFactory, tuple_row, args_row
from .adapt import AdaptersMap
from ._enums import IsolationLevel
from ._compat import Self
-from .conninfo import ConnDict, make_conninfo, conninfo_to_dict
+from .conninfo import make_conninfo, conninfo_to_dict
from .conninfo import conninfo_attempts_async, timeout_from_conninfo
from ._pipeline import AsyncPipeline
from ._encodings import pgconn_encoding
if TYPE_CHECKING:
from .pq.abc import PGconn
+_WAIT_INTERVAL = 0.1
+
TEXT = pq.Format.TEXT
BINARY = pq.Format.BINARY
IDLE = pq.TransactionStatus.IDLE
+ACTIVE = pq.TransactionStatus.ACTIVE
INTRANS = pq.TransactionStatus.INTRANS
if True: # ASYNC
context: Optional[AdaptContext] = None,
row_factory: Optional[AsyncRowFactory[Row]] = None,
cursor_factory: Optional[Type[AsyncCursor[Row]]] = None,
- **kwargs: Any,
+ **kwargs: ConnParam,
) -> Self:
"""
Connect to a database server and return a new `AsyncConnection` instance.
attempts = await conninfo_attempts_async(params)
for attempt in attempts:
try:
- conninfo = make_conninfo(**attempt)
+ conninfo = make_conninfo("", **attempt)
rv = await cls._wait_conn(cls._connect_gen(conninfo), timeout=timeout)
break
except e._NO_TRACEBACK as ex:
self.pgconn.finish()
@overload
- def cursor(self, *, binary: bool = False) -> AsyncCursor[Row]:
- ...
+ def cursor(self, *, binary: bool = False) -> AsyncCursor[Row]: ...
@overload
def cursor(
self, *, binary: bool = False, row_factory: AsyncRowFactory[CursorRow]
- ) -> AsyncCursor[CursorRow]:
- ...
+ ) -> AsyncCursor[CursorRow]: ...
@overload
def cursor(
binary: bool = False,
scrollable: Optional[bool] = None,
withhold: bool = False,
- ) -> AsyncServerCursor[Row]:
- ...
+ ) -> AsyncServerCursor[Row]: ...
@overload
def cursor(
row_factory: AsyncRowFactory[CursorRow],
scrollable: Optional[bool] = None,
withhold: bool = False,
- ) -> AsyncServerCursor[CursorRow]:
- ...
+ ) -> AsyncServerCursor[CursorRow]: ...
def cursor(
self,
async with tx:
yield tx
- async def notifies(self) -> AsyncGenerator[Notify, None]:
+ async def notifies(
+ self, *, timeout: Optional[float] = None, stop_after: Optional[int] = None
+ ) -> AsyncGenerator[Notify, None]:
"""
Yield `Notify` objects as soon as they are received from the database.
+
+ :param timeout: maximum amount of time to wait for notifications.
+ `!None` means no timeout.
+ :param stop_after: stop after receiving this number of notifications.
+ You might actually receive more than this number if more than one
+ notification arrives in the same packet.
"""
+ # Allow interrupting the wait with a signal by breaking a long timeout
+ # into shorter intervals.
+ if timeout is not None:
+ deadline = monotonic() + timeout
+ timeout = min(timeout, _WAIT_INTERVAL)
+ else:
+ deadline = None
+ timeout = _WAIT_INTERVAL
+
+ nreceived = 0
+
while True:
- async with self.lock:
- try:
- ns = await self.wait(notifies(self.pgconn))
- except e._NO_TRACEBACK as ex:
- raise ex.with_traceback(None)
- enc = pgconn_encoding(self.pgconn)
+ # Collect notifications. Also get the connection encoding, if any
+ # notification is received, to make sure they are decoded consistently.
+ try:
+ async with self.lock:
+ ns = await self.wait(notifies(self.pgconn), timeout=timeout)
+ if ns:
+ enc = pgconn_encoding(self.pgconn)
+ except e._NO_TRACEBACK as ex:
+ raise ex.with_traceback(None)
+
+ # Emit the notifications received.
for pgn in ns:
n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid)
yield n
+ nreceived += 1
+
+ # Stop if we have received enough notifications.
+ if stop_after is not None and nreceived >= stop_after:
+ break
+
+ # Check the deadline after the loop to ensure that timeout=0
+ # polls at least once.
+ if deadline:
+ timeout = min(_WAIT_INTERVAL, deadline - monotonic())
+ if timeout < 0.0:
+ break
@asynccontextmanager
async def pipeline(self) -> AsyncIterator[AsyncPipeline]:
assert pipeline is self._pipeline
self._pipeline = None
- async def wait(self, gen: PQGen[RV], timeout: Optional[float] = 0.1) -> RV:
+ async def wait(
+ self, gen: PQGen[RV], timeout: Optional[float] = _WAIT_INTERVAL
+ ) -> RV:
"""
Consume a generator operating on the connection.
try:
return await waiting.wait_async(gen, self.pgconn.socket, timeout=timeout)
except _INTERRUPTED:
- # On Ctrl-C, try to cancel the query in the server, otherwise
- # the connection will remain stuck in ACTIVE state.
- self._try_cancel(self.pgconn)
- try:
- await waiting.wait_async(gen, self.pgconn.socket, timeout=timeout)
- except e.QueryCanceled:
- pass # as expected
+ if self.pgconn.transaction_status == ACTIVE:
+ # On Ctrl-C, try to cancel the query in the server, otherwise
+ # the connection will remain stuck in ACTIVE state.
+ self._try_cancel(self.pgconn)
+ try:
+ await waiting.wait_async(gen, self.pgconn.socket, timeout=timeout)
+ except e.QueryCanceled:
+ pass # as expected
raise
@classmethod
from __future__ import annotations
import re
-from typing import Any
from . import pq
from . import errors as e
-
from . import _conninfo_utils
from . import _conninfo_attempts
from . import _conninfo_attempts_async
+from .abc import ConnParam, ConnDict
# re-exports
-ConnDict = _conninfo_utils.ConnDict
conninfo_attempts = _conninfo_attempts.conninfo_attempts
conninfo_attempts_async = _conninfo_attempts_async.conninfo_attempts_async
_DEFAULT_CONNECT_TIMEOUT = 130
-def make_conninfo(conninfo: str = "", **kwargs: Any) -> str:
+def make_conninfo(conninfo: str = "", **kwargs: ConnParam) -> str:
"""
Merge a string and keyword params into a single conninfo string.
return conninfo
-def conninfo_to_dict(conninfo: str = "", **kwargs: Any) -> ConnDict:
+def conninfo_to_dict(conninfo: str = "", **kwargs: ConnParam) -> ConnDict:
"""
Convert the `!conninfo` string into a dictionary of parameters.
#LIBPQ-CONNSTRING
"""
opts = _parse_conninfo(conninfo)
- rv = {opt.keyword.decode(): opt.val.decode() for opt in opts if opt.val is not None}
+ rv: ConnDict = {
+ opt.keyword.decode(): opt.val.decode() for opt in opts if opt.val is not None
+ }
for k, v in kwargs.items():
if v is not None:
rv[k] = v
__slots__ = ()
@overload
- def __init__(self, connection: Connection[Row]):
- ...
+ def __init__(self, connection: Connection[Row]): ...
@overload
- def __init__(self, connection: Connection[Any], *, row_factory: RowFactory[Row]):
- ...
+ def __init__(
+ self, connection: Connection[Any], *, row_factory: RowFactory[Row]
+ ): ...
def __init__(
self,
__slots__ = ()
@overload
- def __init__(self, connection: AsyncConnection[Row]):
- ...
+ def __init__(self, connection: AsyncConnection[Row]): ...
@overload
def __init__(
self, connection: AsyncConnection[Any], *, row_factory: AsyncRowFactory[Row]
- ):
- ...
+ ): ...
def __init__(
self,
from dataclasses import dataclass, field, fields
from typing import Any, Callable, Dict, List, NoReturn, Optional, Sequence, Tuple, Type
from typing import Union, TYPE_CHECKING
-from typing_extensions import TypeAlias
from asyncio import CancelledError
from .pq.abc import PGconn, PGresult
from .pq._enums import ConnStatus, DiagnosticField, PipelineStatus, TransactionStatus
-from ._compat import TypeGuard
+from ._compat import TypeAlias, TypeGuard
if TYPE_CHECKING:
from .pq.misc import PGnotify, ConninfoOption
functions in the `waiting` module are the ones who wait more or less
cooperatively for the socket to be ready and make these generators continue.
-All these generators yield pairs (fileno, `Wait`) whenever an operation would
-block. The generator can be restarted sending the appropriate `Ready` state
-when the file descriptor is ready.
-
+These generators yield `Wait` objects whenever an operation would block; they
+assume that the connection fileno will not change. In the case of the
+connection function, where the fileno may change, the generators yield pairs
+(fileno, `Wait`).
+
+The generator can be restarted sending the appropriate `Ready` state when the
+file descriptor is ready. If a false value (`Ready.NONE`) is sent, it means
+that the wait function timed out without the file descriptor becoming ready;
+in this case the generator should yield the same value again to keep waiting.
"""
# Copyright (C) 2020 The Psycopg Team
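
# A minimal sketch of the protocol described above; `wait_ready()` is a
# hypothetical selector-based helper (not part of this module) that returns a
# `Ready` flag, or `Ready.NONE` on timeout:
#
#     def drive(gen, fileno, timeout=None):
#         try:
#             s = next(gen)  # the first Wait request
#             while True:
#                 ready = wait_ready(fileno, s, timeout)
#                 # on Ready.NONE the generator yields the same Wait again
#                 s = gen.send(ready)
#         except StopIteration as ex:
#             return ex.args[0] if ex.args else None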
if f == 0:
break
- ready = yield WAIT_RW
+ while True:
+ ready = yield WAIT_RW
+ if ready:
+ break
+
if ready & READY_R:
# This call may read notifies: they will be saved in the
# PGconn buffer and passed to Python later, in `fetch()`.
Return a result from the database (whether success or error).
"""
if pgconn.is_busy():
- yield WAIT_R
+ while True:
+ ready = yield WAIT_R
+ if ready:
+ break
+
while True:
pgconn.consume_input()
if not pgconn.is_busy():
break
- yield WAIT_R
+ while True:
+ ready = yield WAIT_R
+ if ready:
+ break
_consume_notifies(pgconn)
results = []
while True:
- ready = yield WAIT_RW
+ while True:
+ ready = yield WAIT_RW
+ if ready:
+ break
if ready & READY_R:
pgconn.consume_input()
break
# would block
- yield WAIT_R
+ while True:
+ ready = yield WAIT_R
+ if ready:
+ break
pgconn.consume_input()
if nbytes > 0:
# into smaller ones. We prefer to do it there instead of here in order to
# do it upstream the queue decoupling the writer task from the producer one.
while pgconn.put_copy_data(buffer) == 0:
- yield WAIT_W
+ while True:
+ ready = yield WAIT_W
+ if ready:
+ break
def copy_end(pgconn: PGconn, error: Optional[bytes]) -> PQGen[PGresult]:
# Retry enqueuing end copy message until successful
while pgconn.put_copy_end(error) == 0:
- yield WAIT_W
+ while True:
+ ready = yield WAIT_W
+ if ready:
+ break
# Repeat until the message is flushed to the server
while True:
- yield WAIT_W
+ while True:
+ ready = yield WAIT_W
+ if ready:
+ break
f = pgconn.flush()
if f == 0:
break
if sys.platform == "linux":
libcname = ctypes.util.find_library("c")
- assert libcname
+ if not libcname:
+ # Likely this is a system using musl libc, see the following bug:
+ # https://github.com/python/cpython/issues/65821
+ libcname = "libc.so"
libc = ctypes.cdll.LoadLibrary(libcname)
fdopen = libc.fdopen
from typing import Any, Callable, List, Optional, Protocol, Sequence, Tuple
from typing import Union, TYPE_CHECKING
-from typing_extensions import TypeAlias
from ._enums import Format, Trace
+from .._compat import TypeAlias
if TYPE_CHECKING:
from .misc import PGnotify, ConninfoOption, PGresAttDesc
notify_handler: Optional[Callable[["PGnotify"], None]]
@classmethod
- def connect(cls, conninfo: bytes) -> "PGconn":
- ...
+ def connect(cls, conninfo: bytes) -> "PGconn": ...
@classmethod
- def connect_start(cls, conninfo: bytes) -> "PGconn":
- ...
+ def connect_start(cls, conninfo: bytes) -> "PGconn": ...
- def connect_poll(self) -> int:
- ...
+ def connect_poll(self) -> int: ...
- def finish(self) -> None:
- ...
+ def finish(self) -> None: ...
@property
- def info(self) -> List["ConninfoOption"]:
- ...
+ def info(self) -> List["ConninfoOption"]: ...
- def reset(self) -> None:
- ...
+ def reset(self) -> None: ...
- def reset_start(self) -> None:
- ...
+ def reset_start(self) -> None: ...
- def reset_poll(self) -> int:
- ...
+ def reset_poll(self) -> int: ...
@classmethod
- def ping(self, conninfo: bytes) -> int:
- ...
+ def ping(self, conninfo: bytes) -> int: ...
@property
- def db(self) -> bytes:
- ...
+ def db(self) -> bytes: ...
@property
- def user(self) -> bytes:
- ...
+ def user(self) -> bytes: ...
@property
- def password(self) -> bytes:
- ...
+ def password(self) -> bytes: ...
@property
- def host(self) -> bytes:
- ...
+ def host(self) -> bytes: ...
@property
- def hostaddr(self) -> bytes:
- ...
+ def hostaddr(self) -> bytes: ...
@property
- def port(self) -> bytes:
- ...
+ def port(self) -> bytes: ...
@property
- def tty(self) -> bytes:
- ...
+ def tty(self) -> bytes: ...
@property
- def options(self) -> bytes:
- ...
+ def options(self) -> bytes: ...
@property
- def status(self) -> int:
- ...
+ def status(self) -> int: ...
@property
- def transaction_status(self) -> int:
- ...
+ def transaction_status(self) -> int: ...
- def parameter_status(self, name: bytes) -> Optional[bytes]:
- ...
+ def parameter_status(self, name: bytes) -> Optional[bytes]: ...
@property
- def error_message(self) -> bytes:
- ...
+ def error_message(self) -> bytes: ...
@property
- def server_version(self) -> int:
- ...
+ def server_version(self) -> int: ...
@property
- def socket(self) -> int:
- ...
+ def socket(self) -> int: ...
@property
- def backend_pid(self) -> int:
- ...
+ def backend_pid(self) -> int: ...
@property
- def needs_password(self) -> bool:
- ...
+ def needs_password(self) -> bool: ...
@property
- def used_password(self) -> bool:
- ...
+ def used_password(self) -> bool: ...
@property
- def ssl_in_use(self) -> bool:
- ...
+ def ssl_in_use(self) -> bool: ...
- def exec_(self, command: bytes) -> "PGresult":
- ...
+ def exec_(self, command: bytes) -> "PGresult": ...
- def send_query(self, command: bytes) -> None:
- ...
+ def send_query(self, command: bytes) -> None: ...
def exec_params(
self,
param_types: Optional[Sequence[int]] = None,
param_formats: Optional[Sequence[int]] = None,
result_format: int = Format.TEXT,
- ) -> "PGresult":
- ...
+ ) -> "PGresult": ...
def send_query_params(
self,
param_types: Optional[Sequence[int]] = None,
param_formats: Optional[Sequence[int]] = None,
result_format: int = Format.TEXT,
- ) -> None:
- ...
+ ) -> None: ...
def send_prepare(
self,
name: bytes,
command: bytes,
param_types: Optional[Sequence[int]] = None,
- ) -> None:
- ...
+ ) -> None: ...
def send_query_prepared(
self,
param_values: Optional[Sequence[Optional[Buffer]]],
param_formats: Optional[Sequence[int]] = None,
result_format: int = Format.TEXT,
- ) -> None:
- ...
+ ) -> None: ...
def prepare(
self,
name: bytes,
command: bytes,
param_types: Optional[Sequence[int]] = None,
- ) -> "PGresult":
- ...
+ ) -> "PGresult": ...
def exec_prepared(
self,
param_values: Optional[Sequence[Buffer]],
param_formats: Optional[Sequence[int]] = None,
result_format: int = 0,
- ) -> "PGresult":
- ...
+ ) -> "PGresult": ...
- def describe_prepared(self, name: bytes) -> "PGresult":
- ...
+ def describe_prepared(self, name: bytes) -> "PGresult": ...
- def send_describe_prepared(self, name: bytes) -> None:
- ...
+ def send_describe_prepared(self, name: bytes) -> None: ...
- def describe_portal(self, name: bytes) -> "PGresult":
- ...
+ def describe_portal(self, name: bytes) -> "PGresult": ...
- def send_describe_portal(self, name: bytes) -> None:
- ...
+ def send_describe_portal(self, name: bytes) -> None: ...
- def close_prepared(self, name: bytes) -> "PGresult":
- ...
+ def close_prepared(self, name: bytes) -> "PGresult": ...
- def send_close_prepared(self, name: bytes) -> None:
- ...
+ def send_close_prepared(self, name: bytes) -> None: ...
- def close_portal(self, name: bytes) -> "PGresult":
- ...
+ def close_portal(self, name: bytes) -> "PGresult": ...
- def send_close_portal(self, name: bytes) -> None:
- ...
+ def send_close_portal(self, name: bytes) -> None: ...
- def get_result(self) -> Optional["PGresult"]:
- ...
+ def get_result(self) -> Optional["PGresult"]: ...
- def consume_input(self) -> None:
- ...
+ def consume_input(self) -> None: ...
- def is_busy(self) -> int:
- ...
+ def is_busy(self) -> int: ...
@property
- def nonblocking(self) -> int:
- ...
+ def nonblocking(self) -> int: ...
@nonblocking.setter
- def nonblocking(self, arg: int) -> None:
- ...
+ def nonblocking(self, arg: int) -> None: ...
- def flush(self) -> int:
- ...
+ def flush(self) -> int: ...
- def set_single_row_mode(self) -> None:
- ...
+ def set_single_row_mode(self) -> None: ...
- def get_cancel(self) -> "PGcancel":
- ...
+ def get_cancel(self) -> "PGcancel": ...
- def notifies(self) -> Optional["PGnotify"]:
- ...
+ def notifies(self) -> Optional["PGnotify"]: ...
- def put_copy_data(self, buffer: Buffer) -> int:
- ...
+ def put_copy_data(self, buffer: Buffer) -> int: ...
- def put_copy_end(self, error: Optional[bytes] = None) -> int:
- ...
+ def put_copy_end(self, error: Optional[bytes] = None) -> int: ...
- def get_copy_data(self, async_: int) -> Tuple[int, memoryview]:
- ...
+ def get_copy_data(self, async_: int) -> Tuple[int, memoryview]: ...
- def trace(self, fileno: int) -> None:
- ...
+ def trace(self, fileno: int) -> None: ...
- def set_trace_flags(self, flags: Trace) -> None:
- ...
+ def set_trace_flags(self, flags: Trace) -> None: ...
- def untrace(self) -> None:
- ...
+ def untrace(self) -> None: ...
def encrypt_password(
self, passwd: bytes, user: bytes, algorithm: Optional[bytes] = None
- ) -> bytes:
- ...
+ ) -> bytes: ...
- def make_empty_result(self, exec_status: int) -> "PGresult":
- ...
+ def make_empty_result(self, exec_status: int) -> "PGresult": ...
@property
- def pipeline_status(self) -> int:
- ...
+ def pipeline_status(self) -> int: ...
- def enter_pipeline_mode(self) -> None:
- ...
+ def enter_pipeline_mode(self) -> None: ...
- def exit_pipeline_mode(self) -> None:
- ...
+ def exit_pipeline_mode(self) -> None: ...
- def pipeline_sync(self) -> None:
- ...
+ def pipeline_sync(self) -> None: ...
- def send_flush_request(self) -> None:
- ...
+ def send_flush_request(self) -> None: ...
class PGresult(Protocol):
- def clear(self) -> None:
- ...
+ def clear(self) -> None: ...
@property
- def status(self) -> int:
- ...
+ def status(self) -> int: ...
@property
- def error_message(self) -> bytes:
- ...
+ def error_message(self) -> bytes: ...
- def error_field(self, fieldcode: int) -> Optional[bytes]:
- ...
+ def error_field(self, fieldcode: int) -> Optional[bytes]: ...
@property
- def ntuples(self) -> int:
- ...
+ def ntuples(self) -> int: ...
@property
- def nfields(self) -> int:
- ...
+ def nfields(self) -> int: ...
- def fname(self, column_number: int) -> Optional[bytes]:
- ...
+ def fname(self, column_number: int) -> Optional[bytes]: ...
- def ftable(self, column_number: int) -> int:
- ...
+ def ftable(self, column_number: int) -> int: ...
- def ftablecol(self, column_number: int) -> int:
- ...
+ def ftablecol(self, column_number: int) -> int: ...
- def fformat(self, column_number: int) -> int:
- ...
+ def fformat(self, column_number: int) -> int: ...
- def ftype(self, column_number: int) -> int:
- ...
+ def ftype(self, column_number: int) -> int: ...
- def fmod(self, column_number: int) -> int:
- ...
+ def fmod(self, column_number: int) -> int: ...
- def fsize(self, column_number: int) -> int:
- ...
+ def fsize(self, column_number: int) -> int: ...
@property
- def binary_tuples(self) -> int:
- ...
+ def binary_tuples(self) -> int: ...
- def get_value(self, row_number: int, column_number: int) -> Optional[bytes]:
- ...
+ def get_value(self, row_number: int, column_number: int) -> Optional[bytes]: ...
@property
- def nparams(self) -> int:
- ...
+ def nparams(self) -> int: ...
- def param_type(self, param_number: int) -> int:
- ...
+ def param_type(self, param_number: int) -> int: ...
@property
- def command_status(self) -> Optional[bytes]:
- ...
+ def command_status(self) -> Optional[bytes]: ...
@property
- def command_tuples(self) -> Optional[int]:
- ...
+ def command_tuples(self) -> Optional[int]: ...
@property
- def oid_value(self) -> int:
- ...
+ def oid_value(self) -> int: ...
- def set_attributes(self, descriptions: List["PGresAttDesc"]) -> None:
- ...
+ def set_attributes(self, descriptions: List["PGresAttDesc"]) -> None: ...
class PGcancel(Protocol):
- def free(self) -> None:
- ...
+ def free(self) -> None: ...
- def cancel(self) -> None:
- ...
+ def cancel(self) -> None: ...
class Conninfo(Protocol):
@classmethod
- def get_defaults(cls) -> List["ConninfoOption"]:
- ...
+ def get_defaults(cls) -> List["ConninfoOption"]: ...
@classmethod
- def parse(cls, conninfo: bytes) -> List["ConninfoOption"]:
- ...
+ def parse(cls, conninfo: bytes) -> List["ConninfoOption"]: ...
@classmethod
- def _options_from_array(cls, opts: Sequence[Any]) -> List["ConninfoOption"]:
- ...
+ def _options_from_array(cls, opts: Sequence[Any]) -> List["ConninfoOption"]: ...
class Escaping(Protocol):
- def __init__(self, conn: Optional[PGconn] = None):
- ...
+ def __init__(self, conn: Optional[PGconn] = None): ...
- def escape_literal(self, data: Buffer) -> bytes:
- ...
+ def escape_literal(self, data: Buffer) -> bytes: ...
- def escape_identifier(self, data: Buffer) -> bytes:
- ...
+ def escape_identifier(self, data: Buffer) -> bytes: ...
- def escape_string(self, data: Buffer) -> bytes:
- ...
+ def escape_string(self, data: Buffer) -> bytes: ...
- def escape_bytea(self, data: Buffer) -> bytes:
- ...
+ def escape_bytea(self, data: Buffer) -> bytes: ...
- def unescape_bytea(self, data: Buffer) -> bytes:
- ...
+ def unescape_bytea(self, data: Buffer) -> bytes: ...
from typing import Any, Callable, Dict, List, Optional, NamedTuple, NoReturn
from typing import TYPE_CHECKING, Protocol, Sequence, Tuple, Type
from collections import namedtuple
-from typing_extensions import TypeAlias
from . import pq
from . import errors as e
-from ._compat import TypeVar
+from ._compat import TypeAlias, TypeVar
from ._encodings import _as_python_identifier
if TYPE_CHECKING:
Typically, `!RowMaker` functions are returned by `RowFactory`.
"""
- def __call__(self, __values: Sequence[Any]) -> Row:
- ...
+ def __call__(self, __values: Sequence[Any]) -> Row: ...
class RowFactory(Protocol[Row]):
use the values to create a dictionary for each record.
"""
- def __call__(self, __cursor: "Cursor[Any]") -> RowMaker[Row]:
- ...
+ def __call__(self, __cursor: "Cursor[Any]") -> RowMaker[Row]: ...
class AsyncRowFactory(Protocol[Row]):
Like `RowFactory`, taking an async cursor as argument.
"""
- def __call__(self, __cursor: "AsyncCursor[Any]") -> RowMaker[Row]:
- ...
+ def __call__(self, __cursor: "AsyncCursor[Any]") -> RowMaker[Row]: ...
class BaseRowFactory(Protocol[Row]):
Like `RowFactory`, taking either type of cursor as argument.
"""
- def __call__(self, __cursor: "BaseCursor[Any, Any]") -> RowMaker[Row]:
- ...
+ def __call__(self, __cursor: "BaseCursor[Any, Any]") -> RowMaker[Row]: ...
TupleRow: TypeAlias = Tuple[Any, ...]
return kwargs_row_
+def scalar_row(cursor: "BaseCursor[Any, Any]") -> "RowMaker[Any]":
+ """
+ Generate a row factory returning the first column as a scalar value.
+ """
+ res = cursor.pgresult
+ if not res:
+ return no_result
+
+ nfields = _get_nfields(res)
+ if nfields is None:
+ return no_result
+
+ if nfields < 1:
+ raise e.ProgrammingError("at least one column expected")
+
+ def scalar_row_(values: Sequence[Any]) -> Any:
+ return values[0]
+
+ return scalar_row_
+
+
def no_result(values: Sequence[Any]) -> NoReturn:
"""A `RowMaker` that always fail.
*,
scrollable: Optional[bool] = None,
withhold: bool = False,
- ):
- ...
+ ): ...
@overload
def __init__(
row_factory: RowFactory[Row],
scrollable: Optional[bool] = None,
withhold: bool = False,
- ):
- ...
+ ): ...
def __init__(
self,
*,
scrollable: Optional[bool] = None,
withhold: bool = False,
- ):
- ...
+ ): ...
@overload
def __init__(
row_factory: AsyncRowFactory[Row],
scrollable: Optional[bool] = None,
withhold: bool = False,
- ):
- ...
+ ): ...
def __init__(
self,
return f"{self.__class__.__name__}({self._obj!r})"
@abstractmethod
- def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+ def as_bytes(self, context: Optional[AdaptContext] = None) -> bytes:
"""
Return the value of the object as bytes.
"""
raise NotImplementedError
- def as_string(self, context: Optional[AdaptContext]) -> str:
+ def as_string(self, context: Optional[AdaptContext] = None) -> str:
"""
Return the value of the object as string.
seq = [obj if isinstance(obj, Composable) else Literal(obj) for obj in seq]
super().__init__(seq)
- def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+ def as_bytes(self, context: Optional[AdaptContext] = None) -> bytes:
return b"".join(obj.as_bytes(context) for obj in self._obj)
def __iter__(self) -> Iterator[Composable]:
if not isinstance(obj, str):
raise TypeError(f"SQL values must be strings, got {obj!r} instead")
- def as_string(self, context: Optional[AdaptContext]) -> str:
+ def as_string(self, context: Optional[AdaptContext] = None) -> str:
return self._obj
- def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+ def as_bytes(self, context: Optional[AdaptContext] = None) -> bytes:
conn = context.connection if context else None
enc = conn_encoding(conn)
return self._obj.encode(enc)
def __repr__(self) -> str:
return f"{self.__class__.__name__}({', '.join(map(repr, self._obj))})"
- def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+ def as_bytes(self, context: Optional[AdaptContext] = None) -> bytes:
conn = context.connection if context else None
- if not conn:
- raise ValueError("a connection is necessary for Identifier")
- esc = Escaping(conn.pgconn)
- enc = conn_encoding(conn)
- escs = [esc.escape_identifier(s.encode(enc)) for s in self._obj]
+ if conn:
+ esc = Escaping(conn.pgconn)
+ enc = conn_encoding(conn)
+ escs = [esc.escape_identifier(s.encode(enc)) for s in self._obj]
+ else:
+ escs = [self._escape_identifier(s.encode()) for s in self._obj]
return b".".join(escs)
+ def _escape_identifier(self, s: bytes) -> bytes:
+ """
+ Approximation of PQescapeIdentifier taking no connection.
+ """
+ return b'"' + s.replace(b'"', b'""') + b'"'
+
class Literal(Composable):
"""
"""
- def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+ def as_bytes(self, context: Optional[AdaptContext] = None) -> bytes:
tx = Transformer.from_context(context)
return tx.as_literal(self._obj)
return f"{self.__class__.__name__}({', '.join(parts)})"
- def as_string(self, context: Optional[AdaptContext]) -> str:
+ def as_string(self, context: Optional[AdaptContext] = None) -> str:
code = self._format.value
return f"%({self._obj}){code}" if self._obj else f"%{code}"
- def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+ def as_bytes(self, context: Optional[AdaptContext] = None) -> bytes:
conn = context.connection if context else None
enc = conn_encoding(conn)
return self.as_string(context).encode(enc)
"""
Adapters for the enum type.
"""
+
from enum import Enum
from typing import Any, Dict, Generic, Optional, Mapping, Sequence
from typing import Tuple, Type, Union, cast, TYPE_CHECKING
-from typing_extensions import TypeAlias
from .. import sql
from .. import postgres
from ..pq import Format
from ..abc import AdaptContext, Query
from ..adapt import Buffer, Dumper, Loader
-from .._compat import cache, TypeVar
+from .._compat import cache, TypeAlias, TypeVar
from .._encodings import conn_encoding
from .._typeinfo import TypeInfo
import re
from typing import Dict, List, Optional, Type
-from typing_extensions import TypeAlias
from .. import errors as e
from .. import postgres
from ..abc import Buffer, AdaptContext
from .._oids import TEXT_OID
from ..adapt import PyFormat, RecursiveDumper, RecursiveLoader
-from .._compat import cache
+from .._compat import cache, TypeAlias
from .._typeinfo import TypeInfo
_re_escape = re.compile(r'(["\\])')
return f"{{{', '.join(map(str, self._ranges))}}}"
@overload
- def __getitem__(self, index: int) -> Range[T]:
- ...
+ def __getitem__(self, index: int) -> Range[T]: ...
@overload
- def __getitem__(self, index: slice) -> "Multirange[T]":
- ...
+ def __getitem__(self, index: slice) -> "Multirange[T]": ...
def __getitem__(self, index: Union[int, slice]) -> "Union[Range[T],Multirange[T]]":
if isinstance(index, int):
return len(self._ranges)
@overload
- def __setitem__(self, index: int, value: Range[T]) -> None:
- ...
+ def __setitem__(self, index: int, value: Range[T]) -> None: ...
@overload
- def __setitem__(self, index: slice, value: Iterable[Range[T]]) -> None:
- ...
+ def __setitem__(self, index: slice, value: Iterable[Range[T]]) -> None: ...
def __setitem__(
self,
# Copyright (C) 2020 The Psycopg Team
from typing import Callable, Optional, Type, Union, TYPE_CHECKING
-from typing_extensions import TypeAlias
from .. import _oids
from ..pq import Format
from ..abc import AdaptContext
from ..adapt import Buffer, Dumper, Loader
+from .._compat import TypeAlias
if TYPE_CHECKING:
import ipaddress
_MixedNumericDumper.int_classes = int
@abstractmethod
- def dump(self, obj: Union[Decimal, int, "numpy.integer[Any]"]) -> Buffer:
- ...
+ def dump(self, obj: Union[Decimal, int, "numpy.integer[Any]"]) -> Buffer: ...
class NumericDumper(_MixedNumericDumper):
WAIT_R = Wait.R
WAIT_W = Wait.W
WAIT_RW = Wait.RW
+READY_NONE = Ready.NONE
READY_R = Ready.R
READY_W = Ready.W
READY_RW = Ready.RW
try:
s = next(gen)
with DefaultSelector() as sel:
+ sel.register(fileno, s)
while True:
- sel.register(fileno, s)
- rlist = None
- while not rlist:
- rlist = sel.select(timeout=timeout)
+ rlist = sel.select(timeout=timeout)
+ if not rlist:
+ gen.send(READY_NONE)
+ continue
+
sel.unregister(fileno)
- # note: this line should require a cast, but mypy doesn't complain
- ready: Ready = rlist[0][1]
- assert s & ready
+ ready = rlist[0][1]
s = gen.send(ready)
+ sel.register(fileno, s)
except StopIteration as ex:
rv: RV = ex.args[0] if ex.args else None
sel.unregister(fileno)
if not rlist:
raise e.ConnectionTimeout("connection timeout expired")
- ready: Ready = rlist[0][1] # type: ignore[assignment]
+ ready = rlist[0][1]
fileno, s = gen.send(ready)
except StopIteration as ex:
`Ready` values when it would block.
:param fileno: the file descriptor to wait on.
:param timeout: timeout (in seconds) to check for other interrupt, e.g.
- to allow Ctrl-C. If zero or None, wait indefinitely.
+ to allow Ctrl-C. If zero, wait indefinitely.
:return: whatever `!gen` returns on completion.
Behave like in `wait()`, but exposing an `asyncio` interface.
# Not sure this is the best implementation but it's a start.
ev = Event()
loop = get_event_loop()
- ready: Ready
+ ready: int
s: Wait
def wakeup(state: Ready) -> None:
nonlocal ready
- ready |= state # type: ignore[assignment]
+ ready |= state
ev.set()
try:
if not reader and not writer:
raise e.InternalError(f"bad poll status: {s}")
ev.clear()
- ready = 0 # type: ignore[assignment]
+ ready = 0
if reader:
loop.add_reader(fileno, wakeup, READY_R)
if writer:
loop.add_writer(fileno, wakeup, READY_W)
try:
- if timeout is None:
- await ev.wait()
- else:
+ if timeout is not None:
try:
await wait_for(ev.wait(), timeout)
except TimeoutError:
pass
+ else:
+ await ev.wait()
finally:
if reader:
loop.remove_reader(fileno)
loop.remove_writer(fileno)
s = gen.send(ready)
+ except OSError as ex:
+ # Assume the connection was closed
+ raise e.OperationalError(str(ex))
except StopIteration as ex:
rv: RV = ex.args[0] if ex.args else None
return rv
if wl:
ready |= READY_W
if not ready:
+ gen.send(READY_NONE)
continue
- # assert s & ready
- s = gen.send(ready) # type: ignore
+
+ s = gen.send(ready)
except StopIteration as ex:
rv: RV = ex.args[0] if ex.args else None
s = next(gen)
if timeout is None or timeout < 0:
- timeout = 0
- else:
- timeout = int(timeout * 1000.0)
+ timeout = 0.0
with select.epoll() as epoll:
evmask = _epoll_evmasks[s]
epoll.register(fileno, evmask)
while True:
- fileevs = None
- while not fileevs:
- fileevs = epoll.poll(timeout)
+ fileevs = epoll.poll(timeout)
+ if not fileevs:
+ gen.send(READY_NONE)
+ continue
ev = fileevs[0][1]
ready = 0
if ev & ~select.EPOLLOUT:
ready = READY_R
if ev & ~select.EPOLLIN:
ready |= READY_W
- # assert s & ready
s = gen.send(ready)
evmask = _epoll_evmasks[s]
epoll.modify(fileno, evmask)
evmask = _poll_evmasks[s]
poll.register(fileno, evmask)
while True:
- fileevs = None
- while not fileevs:
- fileevs = poll.poll(timeout)
+ fileevs = poll.poll(timeout)
+ if not fileevs:
+ gen.send(READY_NONE)
+ continue
+
ev = fileevs[0][1]
ready = 0
if ev & ~select.POLLOUT:
ready = READY_R
if ev & ~select.POLLIN:
ready |= READY_W
- # assert s & ready
s = gen.send(ready)
evmask = _poll_evmasks[s]
poll.modify(fileno, evmask)
pytest-randomly >= 3.5
dev =
ast-comments >= 1.1.2
- black >= 23.1.0
+ black >= 24.1.0
codespell >= 2.2
dnspython >= 2.1
flake8 >= 4.0
[flake8]
max-line-length = 88
-ignore = W503, E203
+ignore = W503, E203, E704
cdef object WAIT_W = Wait.W
cdef object WAIT_R = Wait.R
cdef object WAIT_RW = Wait.RW
+cdef object PY_READY_NONE = Ready.NONE
cdef object PY_READY_R = Ready.R
cdef object PY_READY_W = Ready.W
cdef object PY_READY_RW = Ready.RW
+cdef int READY_NONE = Ready.NONE
cdef int READY_R = Ready.R
cdef int READY_W = Ready.W
cdef int READY_RW = Ready.RW
to retrieve the results available.
"""
cdef libpq.PGconn *pgconn_ptr = pgconn._pgconn_ptr
- cdef int status
+ cdef int ready
cdef int cires
while True:
if pgconn.flush() == 0:
break
- status = yield WAIT_RW
- if status & READY_R:
+ while True:
+ ready = yield WAIT_RW
+ if ready:
+ break
+
+ if ready & READY_R:
with nogil:
# This call may read notifies which will be saved in the
# PGconn buffer and passed to Python later.
cdef libpq.PGconn *pgconn_ptr = pgconn._pgconn_ptr
cdef int cires, ibres
cdef libpq.PGresult *pgres
+ cdef object ready
with nogil:
ibres = libpq.PQisBusy(pgconn_ptr)
if ibres:
- yield WAIT_R
+ while True:
+ ready = yield WAIT_R
+ if ready:
+ break
+
while True:
with nogil:
cires = libpq.PQconsumeInput(pgconn_ptr)
f"consuming input failed: {error_message(pgconn)}")
if not ibres:
break
- yield WAIT_R
+ while True:
+ ready = yield WAIT_R
+ if ready:
+ break
_consume_notifies(pgconn)
cdef pq.PGresult r
while True:
- ready = yield WAIT_RW
+ while True:
+ ready = yield WAIT_RW
+ if ready:
+ break
if ready & READY_R:
with nogil:
wait_c_impl(int fileno, int wait, float timeout)
{
int select_rv;
- int rv = 0;
+ int rv = -1;
#if defined(HAVE_POLL) && !defined(HAVE_BROKEN_POLL)
goto retry_eintr;
}
- if (select_rv < 0) { goto error; }
if (PyErr_CheckSignals()) { goto finally; }
+ if (select_rv < 0) { goto error; } /* poll error */
- if (input_fd.revents & POLLIN) { rv |= SELECT_EV_READ; }
- if (input_fd.revents & POLLOUT) { rv |= SELECT_EV_WRITE; }
+ rv = 0; /* success, maybe with timeout */
+ if (select_rv > 0) {
+ if (input_fd.revents & POLLIN) { rv |= SELECT_EV_READ; }
+ if (input_fd.revents & POLLOUT) { rv |= SELECT_EV_WRITE; }
+ }
#else
goto retry_eintr;
}
- if (select_rv < 0) { goto error; }
if (PyErr_CheckSignals()) { goto finally; }
+ if (select_rv < 0) { goto error; } /* select error */
- if (FD_ISSET(fileno, &ifds)) { rv |= SELECT_EV_READ; }
- if (FD_ISSET(fileno, &ofds)) { rv |= SELECT_EV_WRITE; }
+ rv = 0;
+ if (select_rv > 0) {
+ if (FD_ISSET(fileno, &ifds)) { rv |= SELECT_EV_READ; }
+ if (FD_ISSET(fileno, &ofds)) { rv |= SELECT_EV_WRITE; }
+ }
#endif /* HAVE_POLL */
error:
+ rv = -1;
+
#ifdef MS_WINDOWS
if (select_rv == SOCKET_ERROR) {
PyErr_SetExcFromWindowsErr(PyExc_OSError, WSAGetLastError());
finally:
- return -1;
+ return rv;
}
"""
while True:
ready = wait_c_impl(fileno, wait, ctimeout)
- if ready == 0:
- continue
+ if ready == READY_NONE:
+ pyready = <PyObject *>PY_READY_NONE
elif ready == READY_R:
pyready = <PyObject *>PY_READY_R
elif ready == READY_RW:
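For reference, a Python model of the C shim's return convention, sketched under the assumption (consistent with the constants above) that ``READY_NONE``, ``READY_R`` and ``READY_W`` are 0, 1 and 2: -1 signals an error with a Python exception set, 0 means the timeout expired and is mapped to `!Ready.NONE`, any other value is a ready mask::

    import select

    SELECT_EV_READ, SELECT_EV_WRITE = 1, 2

    def wait_c_model(fileno, wait, timeout):
        # A negative timeout blocks indefinitely; Python raises OSError
        # on failure instead of returning -1 as the C code does.
        rfds = [fileno] if wait & SELECT_EV_READ else []
        wfds = [fileno] if wait & SELECT_EV_WRITE else []
        rl, wl, _ = select.select(rfds, wfds, [], None if timeout < 0 else timeout)
        rv = 0  # 0 = timed out, i.e. READY_NONE
        if rl:
            rv |= SELECT_EV_READ
        if wl:
            rv |= SELECT_EV_WRITE
        return rv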
# Copyright (C) 2021 The Psycopg Team
+from libc.stdint cimport int64_t
from libc.string cimport memset, strchr
from cpython cimport datetime as cdt
from cpython.dict cimport PyDict_GetItem
if length != 10:
self._error_date(data, "unexpected length")
- cdef int vals[3]
+ cdef int64_t vals[3]
memset(vals, 0, sizeof(vals))
cdef const char *ptr
cdef object cload(self, const char *data, size_t length):
- cdef int vals[3]
+ cdef int64_t vals[3]
memset(vals, 0, sizeof(vals))
cdef const char *ptr
cdef const char *end = data + length
cdef object cload(self, const char *data, size_t length):
- cdef int vals[3]
+ cdef int64_t vals[3]
memset(vals, 0, sizeof(vals))
cdef const char *ptr
cdef const char *end = data + length
if self._order == ORDER_PGDM or self._order == ORDER_PGMD:
return self._cload_pg(data, end)
- cdef int vals[6]
+ cdef int64_t vals[6]
memset(vals, 0, sizeof(vals))
cdef const char *ptr
raise _get_timestamp_load_error(self._pgconn, data, ex) from None
cdef object _cload_pg(self, const char *data, const char *end):
- cdef int vals[4]
+ cdef int64_t vals[4]
memset(vals, 0, sizeof(vals))
cdef const char *ptr
if end[-1] == b'C': # ends with BC
raise _get_timestamp_load_error(self._pgconn, data) from None
- cdef int vals[6]
+ cdef int64_t vals[6]
memset(vals, 0, sizeof(vals))
# Parse the first 6 groups of digits (date and time)
if self._style == INTERVALSTYLE_OTHERS:
return self._cload_notimpl(data, length)
- cdef int days = 0, secs = 0, us = 0
+ cdef int days = 0, us = 0
+ cdef int64_t secs = 0
cdef char sign
- cdef int val
+ cdef int64_t val
cdef const char *ptr = data
cdef const char *sep
cdef const char *end = ptr + length
break
# Parse the time part. Any sign was already consumed in the loop
- cdef int vals[3]
+ cdef int64_t vals[3]
memset(vals, 0, sizeof(vals))
if ptr != NULL:
ptr = _parse_date_values(ptr, end, vals, ARRAYSIZE(vals))
secs = vals[2] + 60 * (vals[1] + 60 * vals[0])
+ if secs > 86_400:
+ days += secs // 86_400
+ secs %= 86_400
+
if ptr[0] == b'.':
ptr = _parse_micros(ptr + 1, &us)
# Work only with positive values as the cdivision behaves differently
# with negative values, and cdivision=False adds overhead.
cdef int64_t aval = val if val >= 0 else -val
- cdef int us, ussecs, usdays
+ cdef int64_t us, ussecs, usdays
- # Group the micros in biggers stuff or timedelta_new might overflow
+ # Group the micros into bigger units or timedelta_new might overflow
with cython.cdivision(True):
- ussecs = <int>(aval // 1_000_000)
+ ussecs = <int64_t>(aval // 1_000_000)
us = aval % 1_000_000
usdays = ussecs // 86_400
cdef const char *_parse_date_values(
- const char *ptr, const char *end, int *vals, int nvals
+ const char *ptr, const char *end, int64_t *vals, int nvals
):
"""
Parse *nvals* numeric values separated by non-numeric chars.
cdef char sgn = ptr[0]
# Parse at most three groups of digits
- cdef int vals[3]
+ cdef int64_t vals[3]
memset(vals, 0, sizeof(vals))
ptr = _parse_date_values(ptr + 1, end, vals, ARRAYSIZE(vals))
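The int64_t changes above guard against 32-bit overflow when loading very large intervals: the seconds parsed from a value such as ``1993534:40:40`` exceed 2**31, so they are accumulated in 64 bits and whole days are folded out before the timedelta is built. The same arithmetic in Python (only a sketch: Python ints don't overflow, the fold just mirrors the C logic)::

    from datetime import timedelta

    def load_interval_time(hours, minutes, seconds, us):
        secs = seconds + 60 * (minutes + 60 * hours)  # may exceed 32 bits
        days, secs = divmod(secs, 86_400)  # keep each component small
        return timedelta(days=days, seconds=secs, microseconds=us)

    # PostgreSQL renders 83063 days 22:40:40.447 as "1993534:40:40.447"
    assert load_interval_time(1993534, 40, 40, 447_000) == timedelta(
        days=83063, seconds=81640, microseconds=447_000
    )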
[flake8]
max-line-length = 88
-ignore = W503, E203
+ignore = W503, E203, E704
import threading
from typing import Any, Callable, Coroutine, TYPE_CHECKING
-from typing_extensions import TypeAlias
-
-from ._compat import TypeVar
+from ._compat import TypeAlias, TypeVar
logger = logging.getLogger("psycopg.pool")
T = TypeVar("T")
else:
from typing import Counter, Deque
+if sys.version_info >= (3, 10):
+ from typing import TypeAlias
+else:
+ from typing_extensions import TypeAlias
+
if sys.version_info >= (3, 11):
from typing import Self
else:
"Counter",
"Deque",
"Self",
+ "TypeAlias",
"TypeVar",
]
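With the shim in place, modules import `!TypeAlias` from ``_compat`` instead of depending on ``typing_extensions`` directly; on Python 3.10+ it resolves to the stdlib. Usage mirrors the test helpers further below::

    import psycopg
    from psycopg.rows import TupleRow
    from psycopg._compat import TypeAlias

    Connection: TypeAlias = psycopg.Connection[TupleRow]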
from typing import Any, Awaitable, Callable, Union, TYPE_CHECKING
-from typing_extensions import TypeAlias
-
-from ._compat import TypeVar
+from ._compat import TypeAlias, TypeVar
if TYPE_CHECKING:
from .pool import ConnectionPool
class NullConnectionPool(_BaseNullConnectionPool, ConnectionPool[CT]):
+
def __init__(
self,
conninfo: str = "",
reconnect_failed: Optional[ConnectFailedCB] = None,
num_workers: int = 3,
): # Note: min_size default value changed to 0.
+
super().__init__(
conninfo,
open=open,
pool.run_task(self)
@abstractmethod
- def _run(self, pool: ConnectionPool[Any]) -> None:
- ...
+ def _run(self, pool: ConnectionPool[Any]) -> None: ...
class StopWorker(MaintenanceTask):
class AddConnection(MaintenanceTask):
+
def __init__(
self,
pool: ConnectionPool[Any],
pool.run_task(self)
@abstractmethod
- async def _run(self, pool: AsyncConnectionPool[Any]) -> None:
- ...
+ async def _run(self, pool: AsyncConnectionPool[Any]) -> None: ...
class StopWorker(MaintenanceTask):
class Scheduler:
+
def __init__(self) -> None:
self._queue: List[Task] = []
self._lock = Lock()
asyncio_options: Dict[str, Any] = {}
if sys.platform == "win32":
- asyncio_options[
- "loop_factory"
- ] = asyncio.WindowsSelectorEventLoopPolicy().new_event_loop
+ asyncio_options["loop_factory"] = (
+ asyncio.WindowsSelectorEventLoopPolicy().new_event_loop
+ )
@pytest.fixture(
pytest-randomly == 3.5.0
# From the 'dev' extra
-black == 23.1.0
+black == 24.1.0
dnspython == 2.1.0
flake8 == 4.0.0
types-setuptools == 57.4.0
-- Ian Bicking
'''
+from __future__ import annotations
+
__rcs_id__ = '$Id: dbapi20.py,v 1.11 2005/01/02 02:41:01 zenzen Exp $'
__version__ = '$Revision: 1.12 $'[11:-2]
__author__ = 'Stuart Bishop <stuart@stuartbishop.net>'
# method is to be found
driver: Any = None
connect_args = () # List of arguments to pass to connect
- connect_kw_args: Dict[str, Any] = {} # Keyword arguments for connect
+ connect_kw_args: Dict[Any, Any] = {} # Keyword arguments for connect
table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables
ddl1 = 'create table %sbooze (name varchar(20))' % table_prefix
continue
args[opt.keyword.decode()] = os.environ[opt.envvar.decode()]
- return make_conninfo(**args)
+ return make_conninfo("", **args)
@pytest.fixture(scope="session")
cdict = conninfo.conninfo_to_dict(server_dsn)
# Get server params
- host = cdict.get("host") or os.environ.get("PGHOST")
+ host = cdict.get("host") or os.environ.get("PGHOST", "")
+ assert isinstance(host, str)
self.server_host = host if host and not host.startswith("/") else "localhost"
self.server_port = cdict.get("port") or os.environ.get("PGPORT", "5432")
cdict["host"] = self.client_host
cdict["port"] = self.client_port
cdict["sslmode"] = "disable" # not supported by the proxy
- self.client_dsn = conninfo.make_conninfo(**cdict)
+ self.client_dsn = conninfo.make_conninfo("", **cdict)
# The running proxy process
self.proc = None
@pytest.fixture
-@pytest.mark.crdb_skip("2-phase commit")
def tpc(svcconn):
tpc = Tpc(svcconn)
tpc.check_tpc()
class MyRow(Dict[str, Any]):
- ...
+ pass
def test_generic_connection_type(dsn):
+
def configure(conn: psycopg.Connection[Any]) -> None:
set_autocommit(conn, True)
def test_non_generic_connection_type(dsn):
+
def configure(conn: psycopg.Connection[Any]) -> None:
set_autocommit(conn, True)
class MyConnection(psycopg.Connection[MyRow]):
+
def __init__(self, *args: Any, **kwargs: Any):
kwargs["row_factory"] = class_row(MyRow)
super().__init__(*args, **kwargs)
@pytest.mark.slow
@pytest.mark.timing
def test_resize(dsn):
+
def sampler():
sleep(0.05) # ensure sampling happens after shrink check
while True:
class MyRow(Dict[str, Any]):
- ...
+ pass
async def test_generic_connection_type(dsn):
def test_connection_class(pool_cls, dsn):
+
class MyConn(psycopg.Connection[Any]):
pass
@pytest.mark.timing
@pytest.mark.crdb_skip("backend pid")
def test_queue(pool_cls, dsn):
+
def worker(n):
t0 = time()
with p.connection() as conn:
@pytest.mark.slow
def test_queue_size(pool_cls, dsn):
+
def worker(t, ev=None):
try:
with p.connection():
@pytest.mark.timing
@pytest.mark.crdb_skip("backend pid")
def test_queue_timeout(pool_cls, dsn):
+
def worker(n):
t0 = time()
try:
@pytest.mark.slow
@pytest.mark.timing
def test_dead_client(pool_cls, dsn):
+
def worker(i, timeout):
try:
with p.connection(timeout=timeout) as conn:
@pytest.mark.timing
@pytest.mark.crdb_skip("backend pid")
def test_queue_timeout_override(pool_cls, dsn):
+
def worker(n):
t0 = time()
timeout = 0.25 if n == 3 else None
def test_closed_queue(pool_cls, dsn):
+
def w1():
with p.connection() as conn:
e1.set() # Tell w0 that w1 got a connection
@pytest.mark.slow
@pytest.mark.timing
def test_stats_measures(pool_cls, dsn):
+
def worker(n):
with p.connection() as conn:
conn.execute("select pg_sleep(0.2)")
@pytest.mark.slow
@pytest.mark.timing
def test_stats_usage(pool_cls, dsn):
+
def worker(n):
try:
with p.connection(timeout=0.3) as conn:
@pytest.mark.slow
def test_check_timeout(pool_cls, dsn):
+
def check(conn):
raise Exception()
class MyRow(Dict[str, Any]):
- ...
+ pass
def test_generic_connection_type(dsn):
+
def configure(conn: psycopg.Connection[Any]) -> None:
set_autocommit(conn, True)
def test_non_generic_connection_type(dsn):
+
def configure(conn: psycopg.Connection[Any]) -> None:
set_autocommit(conn, True)
class MyConnection(psycopg.Connection[MyRow]):
+
def __init__(self, *args: Any, **kwargs: Any):
kwargs["row_factory"] = class_row(MyRow)
super().__init__(*args, **kwargs)
class MyRow(Dict[str, Any]):
- ...
+ pass
async def test_generic_connection_type(dsn):
"""
A quick and rough performance comparison of text vs. binary Decimal adaptation
"""
+
from random import randrange
from decimal import Decimal
import psycopg
handled by execute() calls when pgconn socket is read-ready, which
happens when the output buffer is full.
"""
+
import argparse
import asyncio
import logging
import subprocess as sp
from asyncio import create_task
from asyncio.queues import Queue
-from typing import List, Tuple
+from typing import List
import pytest
assert time.time() - t0 < 0.8, "something broken in concurrency"
-@pytest.mark.slow
-@pytest.mark.timing
-@pytest.mark.crdb_skip("notify")
-async def test_notifies(aconn_cls, aconn, dsn):
- nconn = await aconn_cls.connect(dsn, autocommit=True)
- npid = nconn.pgconn.backend_pid
-
- async def notifier():
- cur = nconn.cursor()
- await asyncio.sleep(0.25)
- await cur.execute("notify foo, '1'")
- await asyncio.sleep(0.25)
- await cur.execute("notify foo, '2'")
- await nconn.close()
-
- async def receiver():
- await aconn.set_autocommit(True)
- cur = aconn.cursor()
- await cur.execute("listen foo")
- gen = aconn.notifies()
- async for n in gen:
- ns.append((n, time.time()))
- if len(ns) >= 2:
- await gen.aclose()
-
- ns: List[Tuple[psycopg.Notify, float]] = []
- t0 = time.time()
- workers = [notifier(), receiver()]
- await asyncio.gather(*workers)
- assert len(ns) == 2
-
- n, t1 = ns[0]
- assert n.pid == npid
- assert n.channel == "foo"
- assert n.payload == "1"
- assert t1 - t0 == pytest.approx(0.25, abs=0.05)
-
- n, t1 = ns[1]
- assert n.pid == npid
- assert n.channel == "foo"
- assert n.payload == "2"
- assert t1 - t0 == pytest.approx(0.5, abs=0.05)
-
-
async def canceller(aconn, errors):
try:
await asyncio.sleep(0.5)
from typing import Any, List
import psycopg
-from psycopg import Notify, pq, errors as e
+from psycopg import pq, errors as e
from psycopg.rows import tuple_row
from psycopg.conninfo import conninfo_to_dict, timeout_from_conninfo
def test_connect_str_subclass(conn_cls, dsn):
+
class MyString(str):
pass
],
)
def test_connect_badargs(conn_cls, monkeypatch, pgconn, args, kwargs, exctype):
+
def fake_connect(conninfo):
return pgconn
yield
conn.remove_notice_handler(cb1)
-@pytest.mark.crdb_skip("notify")
-def test_notify_handlers(conn):
- nots1 = []
- nots2 = []
-
- def cb1(n):
- nots1.append(n)
-
- conn.add_notify_handler(cb1)
- conn.add_notify_handler(lambda n: nots2.append(n))
-
- conn.set_autocommit(True)
- cur = conn.cursor()
- cur.execute("listen foo")
- cur.execute("notify foo, 'n1'")
-
- assert len(nots1) == 1
- n = nots1[0]
- assert n.channel == "foo"
- assert n.payload == "n1"
- assert n.pid == conn.pgconn.backend_pid
-
- assert len(nots2) == 1
- assert nots2[0] == nots1[0]
-
- conn.remove_notify_handler(cb1)
- cur.execute("notify foo, 'n2'")
-
- assert len(nots1) == 1
- assert len(nots2) == 2
- n = nots2[1]
- assert isinstance(n, Notify)
- assert n.channel == "foo"
- assert n.payload == "n2"
- assert n.pid == conn.pgconn.backend_pid
- assert hash(n)
-
- with pytest.raises(ValueError):
- conn.remove_notify_handler(cb1)
-
-
def test_execute(conn):
cur = conn.execute("select %s, %s", [10, 20])
assert cur.fetchone() == (10, 20)
def test_cursor_factory_connect(conn_cls, dsn):
+
class MyCursor(psycopg.Cursor[psycopg.rows.Row]):
pass
from typing import Any, List
import psycopg
-from psycopg import Notify, pq, errors as e
+from psycopg import pq, errors as e
from psycopg.rows import tuple_row
from psycopg.conninfo import conninfo_to_dict, timeout_from_conninfo
aconn.remove_notice_handler(cb1)
-@pytest.mark.crdb_skip("notify")
-async def test_notify_handlers(aconn):
- nots1 = []
- nots2 = []
-
- def cb1(n):
- nots1.append(n)
-
- aconn.add_notify_handler(cb1)
- aconn.add_notify_handler(lambda n: nots2.append(n))
-
- await aconn.set_autocommit(True)
- cur = aconn.cursor()
- await cur.execute("listen foo")
- await cur.execute("notify foo, 'n1'")
-
- assert len(nots1) == 1
- n = nots1[0]
- assert n.channel == "foo"
- assert n.payload == "n1"
- assert n.pid == aconn.pgconn.backend_pid
-
- assert len(nots2) == 1
- assert nots2[0] == nots1[0]
-
- aconn.remove_notify_handler(cb1)
- await cur.execute("notify foo, 'n2'")
-
- assert len(nots1) == 1
- assert len(nots2) == 2
- n = nots2[1]
- assert isinstance(n, Notify)
- assert n.channel == "foo"
- assert n.payload == "n2"
- assert n.pid == aconn.pgconn.backend_pid
- assert hash(n)
-
- with pytest.raises(ValueError):
- aconn.remove_notify_handler(cb1)
-
-
async def test_execute(aconn):
cur = await aconn.execute("select %s, %s", [10, 20])
assert await cur.fetchone() == (10, 20)
async with aconn.cursor() as cur:
assert isinstance(cur, MyCursor)
- async with (await aconn.execute("select 1")) as cur:
+ async with await aconn.execute("select 1") as cur:
assert isinstance(cur, MyCursor)
def test_conninfo_random_multi_ips(fake_resolve):
args = {"host": "alot.com"}
- hostaddrs = [att["hostaddr"] for att in conninfo_attempts(args)]
+ hostaddrs = [str(att["hostaddr"]) for att in conninfo_attempts(args)]
assert len(hostaddrs) == 20
assert hostaddrs == sorted(hostaddrs)
args["load_balance_hosts"] = "disable"
- hostaddrs = [att["hostaddr"] for att in conninfo_attempts(args)]
+ hostaddrs = [str(att["hostaddr"]) for att in conninfo_attempts(args)]
assert hostaddrs == sorted(hostaddrs)
args["load_balance_hosts"] = "random"
- hostaddrs = [att["hostaddr"] for att in conninfo_attempts(args)]
+ hostaddrs = [str(att["hostaddr"]) for att in conninfo_attempts(args)]
assert hostaddrs != sorted(hostaddrs)
async def test_conninfo_random_multi_ips(fake_resolve):
args = {"host": "alot.com"}
- hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)]
+ hostaddrs = [str(att["hostaddr"]) for att in await conninfo_attempts_async(args)]
assert len(hostaddrs) == 20
assert hostaddrs == sorted(hostaddrs)
args["load_balance_hosts"] = "disable"
- hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)]
+ hostaddrs = [str(att["hostaddr"]) for att in await conninfo_attempts_async(args)]
assert hostaddrs == sorted(hostaddrs)
args["load_balance_hosts"] = "random"
- hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)]
+ hostaddrs = [str(att["hostaddr"]) for att in await conninfo_attempts_async(args)]
assert hostaddrs != sorted(hostaddrs)
BaseDumper = StrBinaryDumper # type: ignore
class MyStrDumper(BaseDumper):
+
def dump(self, obj):
return super().dump(obj) * 2
def test_worker_error_propagated(conn, monkeypatch):
+
def copy_to_broken(pgconn, buffer):
raise ZeroDivisionError
yield
class DataGenerator:
+
def __init__(self, conn, nrecs, srec, offset=0, block_size=8192):
self.conn = conn
self.nrecs = nrecs
def test_bad_row_factory(conn):
+
def broken_factory(cur):
1 / 0
cur.execute("select 1")
def broken_maker(cur):
+
def make_row(seq):
1 / 0
except KeyError:
info = conninfo_to_dict(dsn)
del info["password"] # should not raise per check above.
- dsn = make_conninfo(**info)
+ dsn = make_conninfo("", **info)
gen = generators.connect(dsn)
with pytest.raises(
--- /dev/null
+# WARNING: this file is auto-generated by 'async_to_sync.py'
+# from the original file 'test_notify_async.py'
+# DO NOT CHANGE! Change the original file instead.
+from __future__ import annotations
+
+from time import time
+
+import pytest
+from psycopg import Notify
+
+from .acompat import sleep, gather, spawn
+
+pytestmark = pytest.mark.crdb_skip("notify")
+
+
+def test_notify_handlers(conn):
+ nots1 = []
+ nots2 = []
+
+ def cb1(n):
+ nots1.append(n)
+
+ conn.add_notify_handler(cb1)
+ conn.add_notify_handler(lambda n: nots2.append(n))
+
+ conn.set_autocommit(True)
+ conn.execute("listen foo")
+ conn.execute("notify foo, 'n1'")
+
+ assert len(nots1) == 1
+ n = nots1[0]
+ assert n.channel == "foo"
+ assert n.payload == "n1"
+ assert n.pid == conn.pgconn.backend_pid
+
+ assert len(nots2) == 1
+ assert nots2[0] == nots1[0]
+
+ conn.remove_notify_handler(cb1)
+ conn.execute("notify foo, 'n2'")
+
+ assert len(nots1) == 1
+ assert len(nots2) == 2
+ n = nots2[1]
+ assert isinstance(n, Notify)
+ assert n.channel == "foo"
+ assert n.payload == "n2"
+ assert n.pid == conn.pgconn.backend_pid
+ assert hash(n)
+
+ with pytest.raises(ValueError):
+ conn.remove_notify_handler(cb1)
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_notify(conn_cls, conn, dsn):
+ npid = None
+
+ def notifier():
+ with conn_cls.connect(dsn, autocommit=True) as nconn:
+ nonlocal npid
+ npid = nconn.pgconn.backend_pid
+
+ sleep(0.25)
+ nconn.execute("notify foo, '1'")
+ sleep(0.25)
+ nconn.execute("notify foo, '2'")
+
+ def receiver():
+ conn.set_autocommit(True)
+ cur = conn.cursor()
+ cur.execute("listen foo")
+ gen = conn.notifies()
+ for n in gen:
+ ns.append((n, time()))
+ if len(ns) >= 2:
+ gen.close()
+
+ ns: list[tuple[Notify, float]] = []
+ t0 = time()
+ workers = [spawn(notifier), spawn(receiver)]
+ gather(*workers)
+ assert len(ns) == 2
+
+ n, t1 = ns[0]
+ assert n.pid == npid
+ assert n.channel == "foo"
+ assert n.payload == "1"
+ assert t1 - t0 == pytest.approx(0.25, abs=0.05)
+
+ n, t1 = ns[1]
+ assert n.pid == npid
+ assert n.channel == "foo"
+ assert n.payload == "2"
+ assert t1 - t0 == pytest.approx(0.5, abs=0.05)
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_no_notify_timeout(conn):
+ conn.set_autocommit(True)
+ t0 = time()
+ for n in conn.notifies(timeout=0.5):
+ assert False
+ dt = time() - t0
+ assert 0.5 <= dt < 0.75
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_notify_timeout(conn_cls, conn, dsn):
+ conn.set_autocommit(True)
+ conn.execute("listen foo")
+
+ def notifier():
+ with conn_cls.connect(dsn, autocommit=True) as nconn:
+ sleep(0.25)
+ nconn.execute("notify foo, '1'")
+
+ worker = spawn(notifier)
+ try:
+ times = [time()]
+ for n in conn.notifies(timeout=0.5):
+ times.append(time())
+ times.append(time())
+ finally:
+ gather(worker)
+
+ assert len(times) == 3
+ assert times[1] - times[0] == pytest.approx(0.25, 0.1)
+ assert times[2] - times[1] == pytest.approx(0.25, 0.1)
+
+
+@pytest.mark.slow
+def test_notify_timeout_0(conn_cls, conn, dsn):
+ conn.set_autocommit(True)
+ conn.execute("listen foo")
+
+ ns = list(conn.notifies(timeout=0))
+ assert not ns
+
+ with conn_cls.connect(dsn, autocommit=True) as nconn:
+ nconn.execute("notify foo, '1'")
+ sleep(0.1)
+
+ ns = list(conn.notifies(timeout=0))
+ assert len(ns) == 1
+
+
+@pytest.mark.slow
+def test_stop_after(conn_cls, conn, dsn):
+ conn.set_autocommit(True)
+ conn.execute("listen foo")
+
+ def notifier():
+ with conn_cls.connect(dsn, autocommit=True) as nconn:
+ nconn.execute("notify foo, '1'")
+ sleep(0.1)
+ nconn.execute("notify foo, '2'")
+ sleep(0.1)
+ nconn.execute("notify foo, '3'")
+
+ worker = spawn(notifier)
+ try:
+ ns = list(conn.notifies(timeout=1.0, stop_after=2))
+ assert len(ns) == 2
+ assert ns[0].payload == "1"
+ assert ns[1].payload == "2"
+ finally:
+ gather(worker)
+
+ ns = list(conn.notifies(timeout=0.0))
+ assert len(ns) == 1
+ assert ns[0].payload == "3"
+
+
+def test_stop_after_batch(conn_cls, conn, dsn):
+ conn.set_autocommit(True)
+ conn.execute("listen foo")
+
+ def notifier():
+ with conn_cls.connect(dsn, autocommit=True) as nconn:
+ with nconn.transaction():
+ nconn.execute("notify foo, '1'")
+ nconn.execute("notify foo, '2'")
+
+ worker = spawn(notifier)
+ try:
+ ns = list(conn.notifies(timeout=1.0, stop_after=1))
+ assert len(ns) == 2
+ assert ns[0].payload == "1"
+ assert ns[1].payload == "2"
+ finally:
+ gather(worker)
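Taken together, these tests pin down the semantics of the new `!notifies()` parameters; as a usage sketch (assuming ``conn`` is already listening on a channel)::

    # Yield notifications for up to 5 seconds, stopping early after two.
    # Notifications delivered in the same batch may overshoot stop_after,
    # as test_stop_after_batch above shows.
    for n in conn.notifies(timeout=5.0, stop_after=2):
        print(n.channel, n.payload, n.pid)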
--- /dev/null
+from __future__ import annotations
+
+from time import time
+
+import pytest
+from psycopg import Notify
+
+from .acompat import alist, asleep, gather, spawn
+
+pytestmark = pytest.mark.crdb_skip("notify")
+
+
+async def test_notify_handlers(aconn):
+ nots1 = []
+ nots2 = []
+
+ def cb1(n):
+ nots1.append(n)
+
+ aconn.add_notify_handler(cb1)
+ aconn.add_notify_handler(lambda n: nots2.append(n))
+
+ await aconn.set_autocommit(True)
+ await aconn.execute("listen foo")
+ await aconn.execute("notify foo, 'n1'")
+
+ assert len(nots1) == 1
+ n = nots1[0]
+ assert n.channel == "foo"
+ assert n.payload == "n1"
+ assert n.pid == aconn.pgconn.backend_pid
+
+ assert len(nots2) == 1
+ assert nots2[0] == nots1[0]
+
+ aconn.remove_notify_handler(cb1)
+ await aconn.execute("notify foo, 'n2'")
+
+ assert len(nots1) == 1
+ assert len(nots2) == 2
+ n = nots2[1]
+ assert isinstance(n, Notify)
+ assert n.channel == "foo"
+ assert n.payload == "n2"
+ assert n.pid == aconn.pgconn.backend_pid
+ assert hash(n)
+
+ with pytest.raises(ValueError):
+ aconn.remove_notify_handler(cb1)
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_notify(aconn_cls, aconn, dsn):
+ npid = None
+
+ async def notifier():
+ async with await aconn_cls.connect(dsn, autocommit=True) as nconn:
+ nonlocal npid
+ npid = nconn.pgconn.backend_pid
+
+ await asleep(0.25)
+ await nconn.execute("notify foo, '1'")
+ await asleep(0.25)
+ await nconn.execute("notify foo, '2'")
+
+ async def receiver():
+ await aconn.set_autocommit(True)
+ cur = aconn.cursor()
+ await cur.execute("listen foo")
+ gen = aconn.notifies()
+ async for n in gen:
+ ns.append((n, time()))
+ if len(ns) >= 2:
+ await gen.aclose()
+
+ ns: list[tuple[Notify, float]] = []
+ t0 = time()
+ workers = [spawn(notifier), spawn(receiver)]
+ await gather(*workers)
+ assert len(ns) == 2
+
+ n, t1 = ns[0]
+ assert n.pid == npid
+ assert n.channel == "foo"
+ assert n.payload == "1"
+ assert t1 - t0 == pytest.approx(0.25, abs=0.05)
+
+ n, t1 = ns[1]
+ assert n.pid == npid
+ assert n.channel == "foo"
+ assert n.payload == "2"
+ assert t1 - t0 == pytest.approx(0.5, abs=0.05)
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_no_notify_timeout(aconn):
+ await aconn.set_autocommit(True)
+ t0 = time()
+ async for n in aconn.notifies(timeout=0.5):
+ assert False
+ dt = time() - t0
+ assert 0.5 <= dt < 0.75
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_notify_timeout(aconn_cls, aconn, dsn):
+ await aconn.set_autocommit(True)
+ await aconn.execute("listen foo")
+
+ async def notifier():
+ async with await aconn_cls.connect(dsn, autocommit=True) as nconn:
+ await asleep(0.25)
+ await nconn.execute("notify foo, '1'")
+
+ worker = spawn(notifier)
+ try:
+ times = [time()]
+ async for n in aconn.notifies(timeout=0.5):
+ times.append(time())
+ times.append(time())
+ finally:
+ await gather(worker)
+
+ assert len(times) == 3
+ assert times[1] - times[0] == pytest.approx(0.25, 0.1)
+ assert times[2] - times[1] == pytest.approx(0.25, 0.1)
+
+
+@pytest.mark.slow
+async def test_notify_timeout_0(aconn_cls, aconn, dsn):
+ await aconn.set_autocommit(True)
+ await aconn.execute("listen foo")
+
+ ns = await alist(aconn.notifies(timeout=0))
+ assert not ns
+
+ async with await aconn_cls.connect(dsn, autocommit=True) as nconn:
+ await nconn.execute("notify foo, '1'")
+ await asleep(0.1)
+
+ ns = await alist(aconn.notifies(timeout=0))
+ assert len(ns) == 1
+
+
+@pytest.mark.slow
+async def test_stop_after(aconn_cls, aconn, dsn):
+ await aconn.set_autocommit(True)
+ await aconn.execute("listen foo")
+
+ async def notifier():
+ async with await aconn_cls.connect(dsn, autocommit=True) as nconn:
+ await nconn.execute("notify foo, '1'")
+ await asleep(0.1)
+ await nconn.execute("notify foo, '2'")
+ await asleep(0.1)
+ await nconn.execute("notify foo, '3'")
+
+ worker = spawn(notifier)
+ try:
+ ns = await alist(aconn.notifies(timeout=1.0, stop_after=2))
+ assert len(ns) == 2
+ assert ns[0].payload == "1"
+ assert ns[1].payload == "2"
+ finally:
+ await gather(worker)
+
+ ns = await alist(aconn.notifies(timeout=0.0))
+ assert len(ns) == 1
+ assert ns[0].payload == "3"
+
+
+async def test_stop_after_batch(aconn_cls, aconn, dsn):
+ await aconn.set_autocommit(True)
+ await aconn.execute("listen foo")
+
+ async def notifier():
+ async with await aconn_cls.connect(dsn, autocommit=True) as nconn:
+ async with nconn.transaction():
+ await nconn.execute("notify foo, '1'")
+ await nconn.execute("notify foo, '2'")
+
+ worker = spawn(notifier)
+ try:
+ ns = await alist(aconn.notifies(timeout=1.0, stop_after=1))
+ assert len(ns) == 2
+ assert ns[0].payload == "1"
+ assert ns[1].payload == "2"
+ finally:
+ await gather(worker)
class PsycopgTests(dbapi20.DatabaseAPI20Test):
driver = psycopg
# connect_args = () # set by the fixture
- connect_kw_args: Dict[str, Any] = {}
+ connect_kw_args: Dict[Any, Any] = {}
def test_nextset(self):
# tested elsewhere
assert p.age == 42
+def test_scalar_row(conn):
+ cur = conn.cursor(row_factory=rows.scalar_row)
+ cur.execute("select 1")
+ assert cur.fetchone() == 1
+ cur.execute("select 1, 2")
+ assert cur.fetchone() == 1
+ with pytest.raises(psycopg.ProgrammingError):
+ cur.execute("select")
+
+
@pytest.mark.parametrize(
"factory",
"tuple_row dict_row namedtuple_row class_row args_row kwargs_row".split(),
assert sql.Identifier("foo") != "foo"
assert sql.Identifier("foo") != sql.SQL("foo")
- @pytest.mark.parametrize(
- "args, want",
- [
- (("foo",), '"foo"'),
- (("foo", "bar"), '"foo"."bar"'),
- (("fo'o", 'ba"r'), '"fo\'o"."ba""r"'),
- ],
- )
+ _as_string_params = [
+ (("foo",), '"foo"'),
+ (("foo", "bar"), '"foo"."bar"'),
+ (("fo'o", 'ba"r'), '"fo\'o"."ba""r"'),
+ ]
+
+ @pytest.mark.parametrize("args, want", _as_string_params)
def test_as_string(self, conn, args, want):
assert sql.Identifier(*args).as_string(conn) == want
- @pytest.mark.parametrize(
- "args, want, enc",
- [
- crdb_encoding(("foo",), '"foo"', "ascii"),
- crdb_encoding(("foo", "bar"), '"foo"."bar"', "ascii"),
- crdb_encoding(("fo'o", 'ba"r'), '"fo\'o"."ba""r"', "ascii"),
- (("foo", eur), f'"foo"."{eur}"', "utf8"),
- crdb_encoding(("foo", eur), f'"foo"."{eur}"', "latin9"),
- ],
- )
+ @pytest.mark.parametrize("args, want", _as_string_params)
+ def test_as_string_no_conn(self, args, want):
+ assert sql.Identifier(*args).as_string(None) == want
+ assert sql.Identifier(*args).as_string() == want
+
+ _as_bytes_params = [
+ crdb_encoding(("foo",), '"foo"', "ascii"),
+ crdb_encoding(("foo", "bar"), '"foo"."bar"', "ascii"),
+ crdb_encoding(("fo'o", 'ba"r'), '"fo\'o"."ba""r"', "ascii"),
+ (("foo", eur), f'"foo"."{eur}"', "utf8"),
+ crdb_encoding(("foo", eur), f'"foo"."{eur}"', "latin9"),
+ ]
+
+ @pytest.mark.parametrize("args, want, enc", _as_bytes_params)
def test_as_bytes(self, conn, args, want, enc):
want = want.encode(enc)
conn.execute(f"set client_encoding to {py2pgenc(enc).decode()}")
assert sql.Identifier(*args).as_bytes(conn) == want
+ @pytest.mark.parametrize("args, want, enc", _as_bytes_params)
+ def test_as_bytes_no_conn(self, conn, args, want, enc):
+ want = want.encode()
+ assert sql.Identifier(*args).as_bytes(None) == want
+ assert sql.Identifier(*args).as_bytes() == want
+
def test_join(self):
assert not hasattr(sql.Identifier("foo"), "join")
+ def test_escape_no_conn(self, conn):
+ conn.execute("set client_encoding to 'utf8'")
+ for c in range(1, 128):
+ s = chr(c)
+ want = sql.Identifier(s).as_bytes(conn)
+ assert want == sql.Identifier(s).as_bytes(None)
+
class TestLiteral:
def test_class(self):
assert repr(sql.Literal("foo")) == "Literal('foo')"
assert str(sql.Literal("foo")) == "Literal('foo')"
- def test_as_string(self, conn):
- assert sql.Literal(None).as_string(conn) == "NULL"
- assert no_e(sql.Literal("foo").as_string(conn)) == "'foo'"
- assert sql.Literal(42).as_string(conn) == "42"
- assert sql.Literal(dt.date(2017, 1, 1)).as_string(conn) == "'2017-01-01'::date"
-
- def test_as_bytes(self, conn):
- assert sql.Literal(None).as_bytes(conn) == b"NULL"
- assert no_e(sql.Literal("foo").as_bytes(conn)) == b"'foo'"
- assert sql.Literal(42).as_bytes(conn) == b"42"
- assert sql.Literal(dt.date(2017, 1, 1)).as_bytes(conn) == b"'2017-01-01'::date"
+ _params = [
+ (None, "NULL"),
+ ("foo", "'foo'"),
+ (42, "42"),
+ (dt.date(2017, 1, 1), "'2017-01-01'::date"),
+ ]
+
+ @pytest.mark.parametrize("obj, want", _params)
+ def test_as_string(self, conn, obj, want):
+ got = sql.Literal(obj).as_string(conn)
+ if isinstance(obj, str):
+ got = no_e(got)
+ assert got == want
+
+ @pytest.mark.parametrize("obj, want", _params)
+ def test_as_bytes(self, conn, obj, want):
+ got = sql.Literal(obj).as_bytes(conn)
+ if isinstance(obj, str):
+ got = no_e(got)
+ assert got == want.encode()
+
+ @pytest.mark.parametrize("obj, want", _params)
+ def test_as_string_no_conn(self, obj, want):
+ got = sql.Literal(obj).as_string()
+ if isinstance(obj, str):
+ got = no_e(got)
+ assert got == want
+
+ @pytest.mark.parametrize("obj, want", _params)
+ def test_as_bytes_no_conn(self, obj, want):
+ got = sql.Literal(obj).as_bytes()
+ if isinstance(obj, str):
+ got = no_e(got)
+ assert got == want.encode()
@pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
def test_as_bytes_encoding(self, conn, encoding):
def test_as_string(self, conn):
assert sql.SQL("foo").as_string(conn) == "foo"
+ assert sql.SQL("foo").as_string() == "foo"
@pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
def test_as_bytes(self, conn, encoding):
assert sql.SQL(eur).as_bytes(conn) == eur.encode(encoding)
+ def test_no_conn(self):
+ assert sql.SQL(eur).as_string() == eur
+ assert sql.SQL(eur).as_bytes() == eur.encode()
+
class TestComposed:
def test_class(self):
def test_as_string(self, conn):
obj = sql.Composed([sql.SQL("foo"), sql.SQL("bar")])
assert obj.as_string(conn) == "foobar"
+ assert obj.as_string() == "foobar"
def test_as_bytes(self, conn):
obj = sql.Composed([sql.SQL("foo"), sql.SQL("bar")])
assert obj.as_bytes(conn) == b"foobar"
+ assert obj.as_bytes() == b"foobar"
@pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
def test_as_bytes_encoding(self, conn, encoding):
def test_as_string(self, conn, format):
ph = sql.Placeholder(format=format)
assert ph.as_string(conn) == f"%{format.value}"
+ assert ph.as_string() == f"%{format.value}"
ph = sql.Placeholder(name="foo", format=format)
assert ph.as_string(conn) == f"%(foo){format.value}"
+ assert ph.as_string() == f"%(foo){format.value}"
@pytest.mark.parametrize("format", PyFormat)
def test_as_bytes(self, conn, format):
ph = sql.Placeholder(format=format)
- assert ph.as_bytes(conn) == f"%{format.value}".encode("ascii")
+ assert ph.as_bytes(conn) == f"%{format.value}".encode()
+ assert ph.as_bytes() == f"%{format.value}".encode()
ph = sql.Placeholder(name="foo", format=format)
- assert ph.as_bytes(conn) == f"%(foo){format.value}".encode("ascii")
+ assert ph.as_bytes(conn) == f"%(foo){format.value}".encode()
+ assert ph.as_bytes() == f"%(foo){format.value}".encode()
class TestValues:
def no_e(s):
"""Drop an eventual E from E'' quotes"""
- if isinstance(s, memoryview):
+ if isinstance(s, (memoryview, bytearray)):
s = bytes(s)
if isinstance(s, str):
class TestTPC:
+
def test_tpc_commit(self, conn, tpc):
xid = conn.xid(1, "gtrid", "bqual")
assert conn.info.transaction_status == TransactionStatus.IDLE
+import sys
+import time
import select # noqa: used in pytest.mark.skipif
import socket
-import sys
import pytest
pytest.param("wait_c", marks=pytest.mark.skipif("not psycopg._cmodule._psycopg")),
]
+events = ["R", "W", "RW"]
timeouts = [pytest.param({}, id="blank")]
timeouts += [pytest.param({"timeout": x}, id=str(x)) for x in [None, 0, 0.2, 10]]
@pytest.mark.parametrize("waitfn", waitfns)
-@pytest.mark.parametrize("wait, ready", zip(waiting.Wait, waiting.Ready))
+@pytest.mark.parametrize("event", events)
@skip_if_not_linux
-def test_wait_ready(waitfn, wait, ready):
+def test_wait_ready(waitfn, event):
+ wait = getattr(waiting.Wait, event)
+ ready = getattr(waiting.Ready, event)
waitfn = getattr(waiting, waitfn)
def gen():
waitfn(gen, pgconn.socket)
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.parametrize("waitfn", waitfns)
+def test_wait_timeout(pgconn, waitfn):
+ waitfn = getattr(waiting, waitfn)
+
+ pgconn.send_query(b"select pg_sleep(0.5)")
+ gen = generators.execute(pgconn)
+
+ ts = [time.time()]
+
+ def gen_wrapper():
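+ # Forward every value yielded by the inner generator, timestamping
+ # each value the wait function sends back (mostly the READY_NONE
+ # sent on every 0.1 s timeout) before passing it on.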
+ try:
+ for x in gen:
+ res = yield x
+ ts.append(time.time())
+ gen.send(res)
+ except StopIteration as ex:
+ return ex.value
+
+ (res,) = waitfn(gen_wrapper(), pgconn.socket, timeout=0.1)
+ assert res.status == ExecStatus.TUPLES_OK
+ ds = [t1 - t0 for t0, t1 in zip(ts[:-1], ts[1:])]
+ assert len(ds) >= 5
+ for d in ds[:5]:
+ assert d == pytest.approx(0.1, 0.05)
+
+
@pytest.mark.slow
@pytest.mark.skipif(
"sys.platform == 'win32'", reason="win32 works ok, but FDs are mysterious"
@pytest.mark.anyio
-@pytest.mark.parametrize("wait, ready", zip(waiting.Wait, waiting.Ready))
+@pytest.mark.parametrize("event", events)
@skip_if_not_linux
-async def test_wait_ready_async(wait, ready):
+async def test_wait_ready_async(event):
+ wait = getattr(waiting.Wait, event)
+ ready = getattr(waiting.Ready, event)
+
def gen():
r = yield wait
return r
("-90d", "-3 month"),
("186d", "6 mons 6 days"),
("736d", "2 years 6 days"),
+ ("83063d,81640s,447000m", "1993534:40:40.447"),
],
)
@pytest.mark.parametrize("fmt_out", pq.Format)
@pytest.mark.parametrize(
"fmt_in",
[
- f
- if f != PyFormat.BINARY
- else pytest.param(f, marks=pytest.mark.crdb_skip("binary decimal"))
+ (
+ f
+ if f != PyFormat.BINARY
+ else pytest.param(f, marks=pytest.mark.crdb_skip("binary decimal"))
+ )
for f in PyFormat
],
)
tests/test_cursor_common_async.py
tests/test_cursor_raw_async.py
tests/test_cursor_server_async.py
+ tests/test_notify_async.py
tests/test_pipeline_async.py
tests/test_prepared_async.py
tests/test_tpc_async.py
import subprocess as sp
from typing import List
from pathlib import Path
-from typing_extensions import TypeAlias
import psycopg
from psycopg.rows import TupleRow
from psycopg.crdb import CrdbConnection
+from psycopg._compat import TypeAlias
Connection: TypeAlias = psycopg.Connection[TupleRow]