]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
chore: update code to master
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 4 Feb 2024 14:36:54 +0000 (14:36 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 4 Feb 2024 18:15:57 +0000 (18:15 +0000)
93 files changed:
.flake8
.github/workflows/3rd-party-tests.yml
.github/workflows/lint.yml
.github/workflows/packages-bin.yml
.github/workflows/packages-pool.yml
.github/workflows/packages-src.yml
.github/workflows/tests.yml
docs/advanced/async.rst
docs/api/connections.rst
docs/api/rows.rst
docs/api/sql.rst
docs/news.rst
psycopg/.flake8
psycopg/psycopg/_acompat.py
psycopg/psycopg/_compat.py
psycopg/psycopg/_connection_base.py
psycopg/psycopg/_conninfo_attempts.py
psycopg/psycopg/_conninfo_attempts_async.py
psycopg/psycopg/_conninfo_utils.py
psycopg/psycopg/_copy_base.py
psycopg/psycopg/_encodings.py
psycopg/psycopg/_enums.py
psycopg/psycopg/_pipeline.py
psycopg/psycopg/_preparing.py
psycopg/psycopg/_py_transformer.py
psycopg/psycopg/_queries.py
psycopg/psycopg/_struct.py
psycopg/psycopg/_typeinfo.py
psycopg/psycopg/abc.py
psycopg/psycopg/adapt.py
psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py
psycopg/psycopg/conninfo.py
psycopg/psycopg/cursor.py
psycopg/psycopg/cursor_async.py
psycopg/psycopg/errors.py
psycopg/psycopg/generators.py
psycopg/psycopg/pq/_pq_ctypes.py
psycopg/psycopg/pq/abc.py
psycopg/psycopg/rows.py
psycopg/psycopg/server_cursor.py
psycopg/psycopg/sql.py
psycopg/psycopg/types/enum.py
psycopg/psycopg/types/hstore.py
psycopg/psycopg/types/multirange.py
psycopg/psycopg/types/net.py
psycopg/psycopg/types/numeric.py
psycopg/psycopg/waiting.py
psycopg/setup.cfg
psycopg_c/.flake8
psycopg_c/psycopg_c/_psycopg/generators.pyx
psycopg_c/psycopg_c/_psycopg/waiting.pyx
psycopg_c/psycopg_c/types/datetime.pyx
psycopg_pool/.flake8
psycopg_pool/psycopg_pool/_acompat.py
psycopg_pool/psycopg_pool/_compat.py
psycopg_pool/psycopg_pool/abc.py
psycopg_pool/psycopg_pool/null_pool.py
psycopg_pool/psycopg_pool/pool.py
psycopg_pool/psycopg_pool/pool_async.py
psycopg_pool/psycopg_pool/sched.py
tests/conftest.py
tests/constraints.txt
tests/dbapi20.py
tests/fix_db.py
tests/fix_proxy.py
tests/fix_psycopg.py
tests/pool/test_pool.py
tests/pool/test_pool_async.py
tests/pool/test_pool_common.py
tests/pool/test_pool_null.py
tests/pool/test_pool_null_async.py
tests/scripts/dectest.py
tests/scripts/pipeline-demo.py
tests/test_concurrency_async.py
tests/test_connection.py
tests/test_connection_async.py
tests/test_conninfo_attempts.py
tests/test_conninfo_attempts_async.py
tests/test_copy.py
tests/test_cursor_common.py
tests/test_generators.py
tests/test_notify.py [new file with mode: 0644]
tests/test_notify_async.py [new file with mode: 0644]
tests/test_psycopg_dbapi20.py
tests/test_rows.py
tests/test_sql.py
tests/test_tpc.py
tests/test_waiting.py
tests/types/test_datetime.py
tests/types/test_numeric.py
tools/async_to_sync.py
tools/update_oids.py

diff --git a/.flake8 b/.flake8
index ec4053fb2be8d6fa8f90aeb773b09aab0c3e2579..d2473a1ae417a99c6d018cb6508ddabf37c63317 100644 (file)
--- a/.flake8
+++ b/.flake8
@@ -1,6 +1,6 @@
 [flake8]
 max-line-length = 88
-ignore = W503, E203
+ignore = W503, E203, E704
 extend-exclude = .venv build
 per-file-ignores =
     # Autogenerated section
index 89f948d0ea63d41c383cbbbe2e35cdc2121d2aa5..26c9f270c6f08fd8468f12c3bfad7c538577499c 100644 (file)
@@ -61,7 +61,7 @@ jobs:
           --health-retries 5
 
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
 
       - uses: actions/setup-python@v5
         with:
@@ -158,7 +158,7 @@ jobs:
           --health-retries 5
 
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
 
       - uses: actions/setup-python@v5
         with:
index b615fb28f2cc1e283d4568365592e42cd664a0f3..b86c53e75fcc7f47c3b13186c04fe376ec1a9288 100644 (file)
@@ -18,7 +18,7 @@ jobs:
     if: true
 
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
 
       - uses: actions/setup-python@v5
         with:
index 975fe4c5a5f80f17dc1161d5129609082fcb3acb..ec5e1757a62160d733dae12364b948444b3fee92 100644 (file)
@@ -23,7 +23,7 @@ jobs:
         platform: [manylinux, musllinux]
 
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
 
       - name: Set up QEMU for multi-arch build
         # Check https://github.com/docker/setup-qemu-action for newer versions.
@@ -43,7 +43,7 @@ jobs:
         run: python3 ./tools/build/copy_to_binary.py
 
       - name: Build wheels
-        uses: pypa/cibuildwheel@v2.16.0
+        uses: pypa/cibuildwheel@v2.16.5
         with:
           package-dir: psycopg_binary
         env:
@@ -104,13 +104,13 @@ jobs:
         pyver: [cp38, cp39, cp310, cp311, cp312]
 
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
 
       - name: Create the binary package source tree
         run: python3 ./tools/build/copy_to_binary.py
 
       - name: Build wheels
-        uses: pypa/cibuildwheel@v2.16.0
+        uses: pypa/cibuildwheel@v2.16.5
         with:
           package-dir: psycopg_binary
         env:
@@ -147,7 +147,7 @@ jobs:
         pyver: [cp38, cp39, cp310, cp311, cp312]
 
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
 
       - name: Start PostgreSQL service for test
         run: |
@@ -159,7 +159,7 @@ jobs:
         run: python3 ./tools/build/copy_to_binary.py
 
       - name: Build wheels
-        uses: pypa/cibuildwheel@v2.16.0
+        uses: pypa/cibuildwheel@v2.16.5
         with:
           package-dir: psycopg_binary
         env:
index 08c334825e16ca275beefdb2dd0aecc39eec91e9..db79ec13305b72832efae9d00d2c9af8719e6a32 100644 (file)
@@ -18,7 +18,7 @@ jobs:
           - {package: psycopg_pool, format: wheel}
 
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
 
       - uses: actions/setup-python@v5
         with:
index 52db4858921a9e092971054c28e468b88bbba714..6b4f911dd234bbe3240f7daa931a833a8c5dbcbe 100644 (file)
@@ -20,7 +20,7 @@ jobs:
           - {package: psycopg_c, format: sdist, impl: c}
 
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
 
       - uses: actions/setup-python@v5
         with:
index f97e64ba8844472855666eec96f01d0e4504a489..9fb9421a8fae0795757705357cfc5f248f322c77 100644 (file)
@@ -59,7 +59,7 @@ jobs:
       MARKERS: ""
 
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
 
       - uses: actions/setup-python@v5
         with:
@@ -155,7 +155,7 @@ jobs:
       NOT_MARKERS: "timing proxy mypy"
 
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
 
       - uses: actions/setup-python@v5
         with:
@@ -212,7 +212,7 @@ jobs:
         shell: bash
 
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
 
       - uses: actions/setup-python@v5
         with:
@@ -267,7 +267,7 @@ jobs:
       PSYCOPG_TEST_DSN: "host=127.0.0.1 port=26257 user=root dbname=defaultdb"
 
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
 
       - uses: actions/setup-python@v5
         with:
index bf75260711d02d2004e4975d4e77c15655b8f215..ef1f6c1513ba3fe6f1c0bbedcc617e19f519b343 100644 (file)
@@ -334,7 +334,8 @@ mode if you wish to receive or send notifications in a timely manner.
 Notifications are received as instances of `Notify`. If you are reserving a
 connection only to receive notifications, the simplest way is to consume the
 `Connection.notifies` generator. The generator can be stopped using
-`!close()`.
+`!close()`. Starting from Psycopg 3.2, the method supports options to receive
+notifications only for a certain time or up to a certain number.
 
 .. note::
 
index a607f07fa5191676d361f97f5fe14d5a34d6806e..898a8470d4da756a3567b753b85571437cad1bb7 100644 (file)
@@ -286,6 +286,10 @@ The `!Connection` class
         any sessions in the database generates a :sql:`NOTIFY` on one of the
         listened channels.
 
+        .. versionchanged:: 3.2
+
+            Added `!timeout` and `!stop_after` parameters.
+
     .. automethod:: add_notify_handler
 
         See :ref:`async-notify` for details.
@@ -494,6 +498,11 @@ The `!AsyncConnection` class
                     ...
 
     .. automethod:: notifies
+
+        .. versionchanged:: 3.2
+
+            Added `!timeout` and `!stop_after` parameters.
+
     .. automethod:: set_autocommit
     .. automethod:: set_isolation_level
     .. automethod:: set_read_only
index d4c438242541908d7ef608f083a29a1af3ba8478..15dfd3cadcea60845a1aea0c846c5fd3cfd27e67 100644 (file)
@@ -14,6 +14,10 @@ Check out :ref:`row-factory-create` for information about how to use these objec
 .. autofunction:: tuple_row
 .. autofunction:: dict_row
 .. autofunction:: namedtuple_row
+.. autofunction:: scalar_row
+
+        .. versionadded:: 3.2
+
 .. autofunction:: class_row
 
     This is not a row factory, but rather a factory of row factories.
index 6959fee4dba87df92ec7f2c74a0fe9f236f21a16..5e7000b269be7d3b954fbfdf6f228cc9425b0535 100644 (file)
@@ -108,9 +108,25 @@ The `!sql` objects are in the following inheritance hierarchy:
 
 .. autoclass:: Composable()
 
-    .. automethod:: as_bytes
     .. automethod:: as_string
 
+    .. versionchanged:: 3.2
+
+        The `!context` parameter is optional.
+
+        .. warning::
+
+            If a context is not specified, the results are "generic" and not
+            tailored for a specific target connection. Details such as the
+            connection encoding and escaping style will not be taken into
+            account.
+
+    .. automethod:: as_bytes
+
+    .. versionchanged:: 3.2
+
+        The `!context` parameter is optional. See `as_string` for details.
+
 
 .. autoclass:: SQL
 
index db3234f4ca94efc93c55299eb6655bef99d89877..aa87171a5198fe86525600a93c66be173809386e 100644 (file)
@@ -15,12 +15,17 @@ Psycopg 3.2 (unreleased)
 
 - Add support for integer, floating point, boolean `NumPy scalar types`__
   (:ticket:`#332`).
+- Add `!timeout` and `!stop_after` parameters to `Connection.notifies()`
+  (:ticket:`340`).
 - Add :ref:`raw-query-cursors` to execute queries using placeholders in
   PostgreSQL format (`$1`, `$2`...) (:ticket:`#560`).
+- Add `~rows.scalar_row` to return scalar values from a query (:ticket:`#723`).
 - Add `~Connection.set_autocommit()` on sync connections, and similar
   transaction control methods available on the async connections.
 - Add support for libpq functions to close prepared statements and portals
   introduced in libpq v17 (:ticket:`#603`).
+- The `!context` parameter of `sql` objects `~sql.Composable.as_string()` and
+  `~sql.Composable.as_bytes()` methods is not optional (:ticket:`#716`).
 - Disable receiving more than one result on the same cursor in pipeline mode,
   to iterate through `~Cursor.nextset()`. The behaviour was different than
   in non-pipeline mode and not totally reliable (:ticket:`#604`).
@@ -30,18 +35,29 @@ Psycopg 3.2 (unreleased)
 .. __: https://numpy.org/doc/stable/reference/arrays.scalars.html#built-in-scalar-types
 
 
-Psycopg 3.1.17 (unreleased)
+Psycopg 3.1.18 (unreleased)
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-- Fix multiple connection attempts when a host name resolve to multiple
-  IP addresses (:ticket:`699`).
-- Use `typing.Self` as a more correct return value annotation of context
-  managers and other self-returning methods (see :ticket:`708`).
+- Fix possible deadlock on pipeline exit (:ticket:`#685`).
+- Fix overflow loading large intervals in C module (:ticket:`#719`).
+- Fix compatibility with musl libc distributions affected by `CPython issue
+  #65821`__ (:ticket:`#725`).
+
+.. __: https://github.com/python/cpython/issues/65821
 
 
 Current release
 ---------------
 
+Psycopg 3.1.17
+^^^^^^^^^^^^^^
+
+- Fix multiple connection attempts when a host name resolve to multiple
+  IP addresses (:ticket:`#699`).
+- Use `typing.Self` as a more correct return value annotation of context
+  managers and other self-returning methods (see :ticket:`#708`).
+
+
 Psycopg 3.1.16
 ^^^^^^^^^^^^^^
 
index 67fb0245c38715e659df23790d909f6f59dce3d1..33b08d768ceb0002ba839d2c779d8f3c2320287d 100644 (file)
@@ -1,6 +1,6 @@
 [flake8]
 max-line-length = 88
-ignore = W503, E203
+ignore = W503, E203, E704
 per-file-ignores =
     # Autogenerated section
     psycopg/errors.py: E125, E128, E302
index cf106c5ba9d869fdafa6f304e348a5eb30ac0248..d7290889d7119f9c539aeffaadb0e62f8bb9df75 100644 (file)
@@ -15,9 +15,7 @@ import asyncio
 import threading
 from typing import Any, Callable, Coroutine, TYPE_CHECKING
 
-from typing_extensions import TypeAlias
-
-from ._compat import TypeVar
+from ._compat import TypeAlias, TypeVar
 
 Worker: TypeAlias = threading.Thread
 AWorker: TypeAlias = "asyncio.Task[None]"
index 1e1130486599e6f60968462beec925aa998600c2..68d689a2d413ae2c6d32d3eeffb643f7b04632d3 100644 (file)
@@ -18,9 +18,9 @@ else:
     cache = lru_cache(maxsize=None)
 
 if sys.version_info >= (3, 10):
-    from typing import TypeGuard
+    from typing import TypeGuard, TypeAlias
 else:
-    from typing_extensions import TypeGuard
+    from typing_extensions import TypeGuard, TypeAlias
 
 if sys.version_info >= (3, 11):
     from typing import LiteralString, Self
@@ -37,6 +37,7 @@ __all__ = [
     "Deque",
     "LiteralString",
     "Self",
+    "TypeAlias",
     "TypeGuard",
     "TypeVar",
     "ZoneInfo",
index 4dc695ce6e0026cb6dfb50020cb110604e4f3bdc..39e00002b41eba5312a6826dbb516d878e32c48f 100644 (file)
@@ -11,7 +11,6 @@ from typing import TYPE_CHECKING
 from weakref import ref, ReferenceType
 from warnings import warn
 from functools import partial
-from typing_extensions import TypeAlias
 
 from . import pq
 from . import errors as e
@@ -23,7 +22,7 @@ from ._tpc import Xid
 from .rows import Row
 from .adapt import AdaptersMap
 from ._enums import IsolationLevel
-from ._compat import LiteralString, Self, TypeVar
+from ._compat import LiteralString, Self, TypeAlias, TypeVar
 from .pq.misc import connection_summary
 from ._pipeline import BasePipeline
 from ._encodings import pgconn_encoding
index 4fc0f792a33ced351583efa27218a53f7c83603c..6f64f4ba1cc19d2793bea75142ac3c52148f8c04 100644 (file)
@@ -14,14 +14,15 @@ import logging
 from random import shuffle
 
 from . import errors as e
-from ._conninfo_utils import ConnDict, get_param, is_ip_address, get_param_def
+from .abc import ConnDict, ConnMapping
+from ._conninfo_utils import get_param, is_ip_address, get_param_def
 from ._conninfo_utils import split_attempts
 
 
 logger = logging.getLogger("psycopg")
 
 
-def conninfo_attempts(params: ConnDict) -> list[ConnDict]:
+def conninfo_attempts(params: ConnMapping) -> list[ConnDict]:
     """Split a set of connection params on the single attempts to perform.
 
     A connection param can perform more than one attempt more than one ``host``
index 6aca4ee3adbf3a6a402eca6f8d414098db18f537..a549081e9578e1b2b6bc00cc26b0aaf65991de12 100644 (file)
@@ -11,7 +11,8 @@ import logging
 from random import shuffle
 
 from . import errors as e
-from ._conninfo_utils import ConnDict, get_param, is_ip_address, get_param_def
+from .abc import ConnDict, ConnMapping
+from ._conninfo_utils import get_param, is_ip_address, get_param_def
 from ._conninfo_utils import split_attempts
 
 if True:  # ASYNC:
@@ -20,7 +21,7 @@ if True:  # ASYNC:
 logger = logging.getLogger("psycopg")
 
 
-async def conninfo_attempts_async(params: ConnDict) -> list[ConnDict]:
+async def conninfo_attempts_async(params: ConnMapping) -> list[ConnDict]:
     """Split a set of connection params on the single attempts to perform.
 
     A connection param can perform more than one attempt more than one ``host``
index 8940c937b379d758be8972ed86eb7550efd6e2f6..a342987a09e279a41f424a20ec3f74c0c7fd901f 100644 (file)
@@ -7,19 +7,20 @@ Internal utilities to manipulate connection strings
 from __future__ import annotations
 
 import os
-from typing import Any
+from typing import TYPE_CHECKING
 from functools import lru_cache
 from ipaddress import ip_address
 from dataclasses import dataclass
-from typing_extensions import TypeAlias
 
 from . import pq
+from .abc import ConnDict, ConnMapping
 from . import errors as e
 
-ConnDict: TypeAlias = "dict[str, Any]"
+if TYPE_CHECKING:
+    from typing import Any  # noqa: F401
 
 
-def split_attempts(params: ConnDict) -> list[ConnDict]:
+def split_attempts(params: ConnMapping) -> list[ConnDict]:
     """
     Split connection parameters with a sequence of hosts into separate attempts.
     """
@@ -47,7 +48,7 @@ def split_attempts(params: ConnDict) -> list[ConnDict]:
 
     # A single attempt to make. Don't mangle the conninfo string.
     if nhosts <= 1:
-        return [params]
+        return [{**params}]
 
     if len(ports) == 1:
         ports *= nhosts
@@ -55,7 +56,7 @@ def split_attempts(params: ConnDict) -> list[ConnDict]:
     # Now all lists are either empty or have the same length
     rv = []
     for i in range(nhosts):
-        attempt = params.copy()
+        attempt = {**params}
         if hosts:
             attempt["host"] = hosts[i]
         if hostaddrs:
@@ -67,7 +68,7 @@ def split_attempts(params: ConnDict) -> list[ConnDict]:
     return rv
 
 
-def get_param(params: ConnDict, name: str) -> str | None:
+def get_param(params: ConnMapping, name: str) -> str | None:
     """
     Return a value from a connection string.
 
index 140744ff1c1e5e375d26c144dcec3f96576f914c..9194b266bca5c66434561a3027ca2b4da5507494 100644 (file)
@@ -210,20 +210,16 @@ class Formatter(ABC):
         self._row_mode = False  # true if the user is using write_row()
 
     @abstractmethod
-    def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]:
-        ...
+    def parse_row(self, data: Buffer) -> Optional[Tuple[Any, ...]]: ...
 
     @abstractmethod
-    def write(self, buffer: Union[Buffer, str]) -> Buffer:
-        ...
+    def write(self, buffer: Union[Buffer, str]) -> Buffer: ...
 
     @abstractmethod
-    def write_row(self, row: Sequence[Any]) -> Buffer:
-        ...
+    def write_row(self, row: Sequence[Any]) -> Buffer: ...
 
     @abstractmethod
-    def end(self) -> Buffer:
-        ...
+    def end(self) -> Buffer: ...
 
 
 class TextFormatter(Formatter):
index 2382710933d89d82ef9c786e99301968bbfc33b1..4949e26c683da38f294219130bbe4b0e8739a7b5 100644 (file)
@@ -116,7 +116,7 @@ def conninfo_encoding(conninfo: str) -> str:
     pgenc = params.get("client_encoding")
     if pgenc:
         try:
-            return pg2pyenc(pgenc.encode())
+            return pg2pyenc(str(pgenc).encode())
         except NotSupportedError:
             pass
 
index a7cb78df4c2123c008f81fc13b849dbf47a332ad..1975650c6654138862ec4f8542c7d8e88994befd 100644 (file)
@@ -20,6 +20,7 @@ class Wait(IntEnum):
 
 
 class Ready(IntEnum):
+    NONE = 0
     R = EVENT_READ
     W = EVENT_WRITE
     RW = EVENT_READ | EVENT_WRITE
index 72ac97ddd248d6ef0ee7f897223d6251628b630f..05d0beb64409f491ccc4a57efdcc5031bfa55b83 100644 (file)
@@ -7,12 +7,11 @@ commands pipeline management
 import logging
 from types import TracebackType
 from typing import Any, List, Optional, Union, Tuple, Type, TYPE_CHECKING
-from typing_extensions import TypeAlias
 
 from . import pq
 from . import errors as e
 from .abc import PipelineCommand, PQGen
-from ._compat import Deque, Self
+from ._compat import Deque, Self, TypeAlias
 from .pq.misc import connection_summary
 from ._encodings import pgconn_encoding
 from ._preparing import Key, Prepare
@@ -133,8 +132,7 @@ class BasePipeline:
             self._enqueue_sync()
             yield from self._communicate_gen()
         finally:
-            # No need to force flush since we emitted a sync just before.
-            yield from self._fetch_gen(flush=False)
+            yield from self._fetch_gen(flush=True)
 
     def _communicate_gen(self) -> PQGen[None]:
         """Communicate with pipeline to send commands and possibly fetch
index 158552ba55785709a802b53e449361317643359a..465de53a4899a2a4c7f13b814722e811d080272c 100644 (file)
@@ -7,10 +7,9 @@ Support for prepared statements
 from enum import IntEnum, auto
 from typing import Iterator, Optional, Sequence, Tuple, TYPE_CHECKING
 from collections import OrderedDict
-from typing_extensions import TypeAlias
 
 from . import pq
-from ._compat import Deque
+from ._compat import Deque, TypeAlias
 from ._queries import PostgresQuery
 
 if TYPE_CHECKING:
index 17f21c079e14be17484bc724a5b7563d9c868eb5..dd7f54759e29fabbeef4f7be25f9d3435bceba6c 100644 (file)
@@ -12,7 +12,6 @@ dependencies problems).
 from typing import Any, Dict, List, Optional, Sequence, Tuple
 from typing import DefaultDict, TYPE_CHECKING
 from collections import defaultdict
-from typing_extensions import TypeAlias
 
 from . import pq
 from . import abc
@@ -20,6 +19,7 @@ from . import errors as e
 from .abc import Buffer, LoadFunc, AdaptContext, PyFormat, DumperKey, NoneType
 from .rows import Row, RowMaker
 from ._oids import INVALID_OID, TEXT_OID
+from ._compat import TypeAlias
 from ._encodings import conn_encoding
 
 if TYPE_CHECKING:
index 376012aec7bd9a4c53e22c193e18aea71d8afedd..dc2e5a67e29d990aa621cbce4798594f38aaf14c 100644 (file)
@@ -8,14 +8,13 @@ import re
 from typing import Any, Callable, Dict, List, Mapping, Match, NamedTuple, Optional
 from typing import Sequence, Tuple, Union, TYPE_CHECKING
 from functools import lru_cache
-from typing_extensions import TypeAlias
 
 from . import pq
 from . import errors as e
 from .sql import Composable
 from .abc import Buffer, Query, Params
 from ._enums import PyFormat
-from ._compat import TypeGuard
+from ._compat import TypeAlias, TypeGuard
 from ._encodings import conn_encoding
 
 if TYPE_CHECKING:
index bce427c808d1737b5844dd8c3c7bede9f38c6cbd..7232a20bd90b309ccbd17de66344ecce0410e18d 100644 (file)
@@ -6,10 +6,10 @@ Utility functions to deal with binary structs.
 
 import struct
 from typing import Callable, cast, Optional, Protocol, Tuple
-from typing_extensions import TypeAlias
 
-from .abc import Buffer
 from . import errors as e
+from .abc import Buffer
+from ._compat import TypeAlias
 
 PackInt: TypeAlias = Callable[[int], bytes]
 UnpackInt: TypeAlias = Callable[[Buffer], Tuple[int]]
@@ -18,8 +18,7 @@ UnpackFloat: TypeAlias = Callable[[Buffer], Tuple[float]]
 
 
 class UnpackLen(Protocol):
-    def __call__(self, data: Buffer, start: Optional[int]) -> Tuple[int]:
-        ...
+    def __call__(self, data: Buffer, start: Optional[int]) -> Tuple[int]: ...
 
 
 pack_int2 = cast(PackInt, struct.Struct("!h").pack)
index bfa740ff9f82602897ebbda04d345413b63a4386..fc170492a5683340fe19fc8223f6040bdaa56fda 100644 (file)
@@ -9,13 +9,12 @@ information to the adapters if needed.
 
 from typing import Any, Dict, Iterator, Optional, overload
 from typing import Sequence, Tuple, Type, Union, TYPE_CHECKING
-from typing_extensions import TypeAlias
 
 from . import sql
 from . import errors as e
 from .abc import AdaptContext, Query
 from .rows import dict_row
-from ._compat import TypeVar
+from ._compat import TypeAlias, TypeVar
 from ._encodings import conn_encoding
 
 if TYPE_CHECKING:
@@ -59,15 +58,13 @@ class TypeInfo:
     @classmethod
     def fetch(
         cls: Type[T], conn: "Connection[Any]", name: Union[str, sql.Identifier]
-    ) -> Optional[T]:
-        ...
+    ) -> Optional[T]: ...
 
     @overload
     @classmethod
     async def fetch(
         cls: Type[T], conn: "AsyncConnection[Any]", name: Union[str, sql.Identifier]
-    ) -> Optional[T]:
-        ...
+    ) -> Optional[T]: ...
 
     @classmethod
     def fetch(
@@ -239,12 +236,10 @@ class TypesRegistry:
                 yield t
 
     @overload
-    def __getitem__(self, key: Union[str, int]) -> TypeInfo:
-        ...
+    def __getitem__(self, key: Union[str, int]) -> TypeInfo: ...
 
     @overload
-    def __getitem__(self, key: Tuple[Type[T], int]) -> T:
-        ...
+    def __getitem__(self, key: Tuple[Type[T], int]) -> T: ...
 
     def __getitem__(self, key: RegistryKey) -> TypeInfo:
         """
@@ -265,12 +260,10 @@ class TypesRegistry:
             raise KeyError(f"couldn't find the type {key!r} in the types registry")
 
     @overload
-    def get(self, key: Union[str, int]) -> Optional[TypeInfo]:
-        ...
+    def get(self, key: Union[str, int]) -> Optional[TypeInfo]: ...
 
     @overload
-    def get(self, key: Tuple[Type[T], int]) -> Optional[T]:
-        ...
+    def get(self, key: Tuple[Type[T], int]) -> Optional[T]: ...
 
     def get(self, key: RegistryKey) -> Optional[TypeInfo]:
         """
index 0952e8d0b5b4ba6e50e9d88f250c9dbc91ec138d..ad4a96646c13643139220adab2a561305312d78a 100644 (file)
@@ -4,14 +4,13 @@ Protocol objects representing different implementations of the same classes.
 
 # Copyright (C) 2020 The Psycopg Team
 
-from typing import Any, Callable, Generator, Mapping
+from typing import Any, Dict, Callable, Generator, Mapping
 from typing import List, Optional, Protocol, Sequence, Tuple, Union
 from typing import TYPE_CHECKING
-from typing_extensions import TypeAlias
 
 from . import pq
 from ._enums import PyFormat as PyFormat
-from ._compat import LiteralString, TypeVar
+from ._compat import LiteralString, TypeAlias, TypeVar
 
 if TYPE_CHECKING:
     from . import sql
@@ -31,18 +30,22 @@ Params: TypeAlias = Union[Sequence[Any], Mapping[str, Any]]
 ConnectionType = TypeVar("ConnectionType", bound="BaseConnection[Any]")
 PipelineCommand: TypeAlias = Callable[[], None]
 DumperKey: TypeAlias = Union[type, Tuple["DumperKey", ...]]
+ConnParam: TypeAlias = Union[str, int, None]
+ConnDict: TypeAlias = Dict[str, ConnParam]
+ConnMapping: TypeAlias = Mapping[str, ConnParam]
+
 
 # Waiting protocol types
 
 RV = TypeVar("RV")
 
-PQGenConn: TypeAlias = Generator[Tuple[int, "Wait"], "Ready", RV]
+PQGenConn: TypeAlias = Generator[Tuple[int, "Wait"], Union["Ready", int], RV]
 """Generator for processes where the connection file number can change.
 
 This can happen in connection and reset, but not in normal querying.
 """
 
-PQGen: TypeAlias = Generator["Wait", "Ready", RV]
+PQGen: TypeAlias = Generator["Wait", Union["Ready", int], RV]
 """Generator for processes where the connection file number won't change.
 """
 
@@ -54,8 +57,7 @@ class WaitFunc(Protocol):
 
     def __call__(
         self, gen: PQGen[RV], fileno: int, timeout: Optional[float] = None
-    ) -> RV:
-        ...
+    ) -> RV: ...
 
 
 # Adaptation types
@@ -106,8 +108,7 @@ class Dumper(Protocol):
     oid: int
     """The oid to pass to the server, if known; 0 otherwise (class attribute)."""
 
-    def __init__(self, cls: type, context: Optional[AdaptContext] = None):
-        ...
+    def __init__(self, cls: type, context: Optional[AdaptContext] = None): ...
 
     def dump(self, obj: Any) -> Buffer:
         """Convert the object `!obj` to PostgreSQL representation.
@@ -187,8 +188,7 @@ class Loader(Protocol):
     This is a class attribute.
     """
 
-    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
-        ...
+    def __init__(self, oid: int, context: Optional[AdaptContext] = None): ...
 
     def load(self, data: Buffer) -> Any:
         """
@@ -203,28 +203,22 @@ class Transformer(Protocol):
     types: Optional[Tuple[int, ...]]
     formats: Optional[List[pq.Format]]
 
-    def __init__(self, context: Optional[AdaptContext] = None):
-        ...
+    def __init__(self, context: Optional[AdaptContext] = None): ...
 
     @classmethod
-    def from_context(cls, context: Optional[AdaptContext]) -> "Transformer":
-        ...
+    def from_context(cls, context: Optional[AdaptContext]) -> "Transformer": ...
 
     @property
-    def connection(self) -> Optional["BaseConnection[Any]"]:
-        ...
+    def connection(self) -> Optional["BaseConnection[Any]"]: ...
 
     @property
-    def encoding(self) -> str:
-        ...
+    def encoding(self) -> str: ...
 
     @property
-    def adapters(self) -> "AdaptersMap":
-        ...
+    def adapters(self) -> "AdaptersMap": ...
 
     @property
-    def pgresult(self) -> Optional["PGresult"]:
-        ...
+    def pgresult(self) -> Optional["PGresult"]: ...
 
     def set_pgresult(
         self,
@@ -232,34 +226,26 @@ class Transformer(Protocol):
         *,
         set_loaders: bool = True,
         format: Optional[pq.Format] = None
-    ) -> None:
-        ...
+    ) -> None: ...
 
-    def set_dumper_types(self, types: Sequence[int], format: pq.Format) -> None:
-        ...
+    def set_dumper_types(self, types: Sequence[int], format: pq.Format) -> None: ...
 
-    def set_loader_types(self, types: Sequence[int], format: pq.Format) -> None:
-        ...
+    def set_loader_types(self, types: Sequence[int], format: pq.Format) -> None: ...
 
     def dump_sequence(
         self, params: Sequence[Any], formats: Sequence[PyFormat]
-    ) -> Sequence[Optional[Buffer]]:
-        ...
+    ) -> Sequence[Optional[Buffer]]: ...
 
-    def as_literal(self, obj: Any) -> bytes:
-        ...
+    def as_literal(self, obj: Any) -> bytes: ...
 
-    def get_dumper(self, obj: Any, format: PyFormat) -> Dumper:
-        ...
+    def get_dumper(self, obj: Any, format: PyFormat) -> Dumper: ...
 
-    def load_rows(self, row0: int, row1: int, make_row: "RowMaker[Row]") -> List["Row"]:
-        ...
+    def load_rows(
+        self, row0: int, row1: int, make_row: "RowMaker[Row]"
+    ) -> List["Row"]: ...
 
-    def load_row(self, row: int, make_row: "RowMaker[Row]") -> Optional["Row"]:
-        ...
+    def load_row(self, row: int, make_row: "RowMaker[Row]") -> Optional["Row"]: ...
 
-    def load_sequence(self, record: Sequence[Optional[Buffer]]) -> Tuple[Any, ...]:
-        ...
+    def load_sequence(self, record: Sequence[Optional[Buffer]]) -> Tuple[Any, ...]: ...
 
-    def get_loader(self, oid: int, format: pq.Format) -> Loader:
-        ...
+    def get_loader(self, oid: int, format: pq.Format) -> Loader: ...
index 31a7104296d660bd8f3b834b19afa70de78056f0..7d6a191d8d68df009a9573a393bca0ad733259e1 100644 (file)
@@ -46,8 +46,7 @@ class Dumper(abc.Dumper, ABC):
         )
 
     @abstractmethod
-    def dump(self, obj: Any) -> Buffer:
-        ...
+    def dump(self, obj: Any) -> Buffer: ...
 
     def quote(self, obj: Any) -> Buffer:
         """
index dc02ce3815c359a6759710a174e7a145db596a20..cb0244aa504d5c7dc617dc16a7bcfb057a6218ba 100644 (file)
@@ -10,6 +10,7 @@ Psycopg connection object (sync version)
 from __future__ import annotations
 
 import logging
+from time import monotonic
 from types import TracebackType
 from typing import Any, Generator, Iterator, List, Optional
 from typing import Type, Union, cast, overload, TYPE_CHECKING
@@ -18,13 +19,13 @@ from contextlib import contextmanager
 from . import pq
 from . import errors as e
 from . import waiting
-from .abc import AdaptContext, Params, PQGen, PQGenConn, Query, RV
+from .abc import AdaptContext, ConnDict, ConnParam, Params, PQGen, PQGenConn, Query, RV
 from ._tpc import Xid
 from .rows import Row, RowFactory, tuple_row, args_row
 from .adapt import AdaptersMap
 from ._enums import IsolationLevel
 from ._compat import Self
-from .conninfo import ConnDict, make_conninfo, conninfo_to_dict
+from .conninfo import make_conninfo, conninfo_to_dict
 from .conninfo import conninfo_attempts, timeout_from_conninfo
 from ._pipeline import Pipeline
 from ._encodings import pgconn_encoding
@@ -39,10 +40,13 @@ from threading import Lock
 if TYPE_CHECKING:
     from .pq.abc import PGconn
 
+_WAIT_INTERVAL = 0.1
+
 TEXT = pq.Format.TEXT
 BINARY = pq.Format.BINARY
 
 IDLE = pq.TransactionStatus.IDLE
+ACTIVE = pq.TransactionStatus.ACTIVE
 INTRANS = pq.TransactionStatus.INTRANS
 
 _INTERRUPTED = KeyboardInterrupt
@@ -83,7 +87,7 @@ class Connection(BaseConnection[Row]):
         context: Optional[AdaptContext] = None,
         row_factory: Optional[RowFactory[Row]] = None,
         cursor_factory: Optional[Type[Cursor[Row]]] = None,
-        **kwargs: Any,
+        **kwargs: ConnParam,
     ) -> Self:
         """
         Connect to a database server and return a new `Connection` instance.
@@ -95,7 +99,7 @@ class Connection(BaseConnection[Row]):
         attempts = conninfo_attempts(params)
         for attempt in attempts:
             try:
-                conninfo = make_conninfo(**attempt)
+                conninfo = make_conninfo("", **attempt)
                 rv = cls._wait_conn(cls._connect_gen(conninfo), timeout=timeout)
                 break
             except e._NO_TRACEBACK as ex:
@@ -165,14 +169,12 @@ class Connection(BaseConnection[Row]):
         self.pgconn.finish()
 
     @overload
-    def cursor(self, *, binary: bool = False) -> Cursor[Row]:
-        ...
+    def cursor(self, *, binary: bool = False) -> Cursor[Row]: ...
 
     @overload
     def cursor(
         self, *, binary: bool = False, row_factory: RowFactory[CursorRow]
-    ) -> Cursor[CursorRow]:
-        ...
+    ) -> Cursor[CursorRow]: ...
 
     @overload
     def cursor(
@@ -182,8 +184,7 @@ class Connection(BaseConnection[Row]):
         binary: bool = False,
         scrollable: Optional[bool] = None,
         withhold: bool = False,
-    ) -> ServerCursor[Row]:
-        ...
+    ) -> ServerCursor[Row]: ...
 
     @overload
     def cursor(
@@ -194,8 +195,7 @@ class Connection(BaseConnection[Row]):
         row_factory: RowFactory[CursorRow],
         scrollable: Optional[bool] = None,
         withhold: bool = False,
-    ) -> ServerCursor[CursorRow]:
-        ...
+    ) -> ServerCursor[CursorRow]: ...
 
     def cursor(
         self,
@@ -280,20 +280,56 @@ class Connection(BaseConnection[Row]):
             with tx:
                 yield tx
 
-    def notifies(self) -> Generator[Notify, None, None]:
+    def notifies(
+        self, *, timeout: Optional[float] = None, stop_after: Optional[int] = None
+    ) -> Generator[Notify, None, None]:
         """
         Yield `Notify` objects as soon as they are received from the database.
+
+        :param timeout: maximum amount of time to wait for notifications.
+            `!None` means no timeout.
+        :param stop_after: stop after receiving this number of notifications.
+            You might actually receive more than this number if more than one
+            notifications arrives in the same packet.
         """
+        # Allow interrupting the wait with a signal by reducing a long timeout
+        # into shorter interval.
+        if timeout is not None:
+            deadline = monotonic() + timeout
+            timeout = min(timeout, _WAIT_INTERVAL)
+        else:
+            deadline = None
+            timeout = _WAIT_INTERVAL
+
+        nreceived = 0
+
         while True:
-            with self.lock:
-                try:
-                    ns = self.wait(notifies(self.pgconn))
-                except e._NO_TRACEBACK as ex:
-                    raise ex.with_traceback(None)
-            enc = pgconn_encoding(self.pgconn)
+            # Collect notifications. Also get the connection encoding if any
+            # notification is received to makes sure that they are consistent.
+            try:
+                with self.lock:
+                    ns = self.wait(notifies(self.pgconn), timeout=timeout)
+                    if ns:
+                        enc = pgconn_encoding(self.pgconn)
+            except e._NO_TRACEBACK as ex:
+                raise ex.with_traceback(None)
+
+            # Emit the notifications received.
             for pgn in ns:
                 n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid)
                 yield n
+                nreceived += 1
+
+            # Stop if we have received enough notifications.
+            if stop_after is not None and nreceived >= stop_after:
+                break
+
+            # Check the deadline after the loop to ensure that timeout=0
+            # polls at least once.
+            if deadline:
+                timeout = min(_WAIT_INTERVAL, deadline - monotonic())
+                if timeout < 0.0:
+                    break
 
     @contextmanager
     def pipeline(self) -> Iterator[Pipeline]:
@@ -315,7 +351,7 @@ class Connection(BaseConnection[Row]):
                     assert pipeline is self._pipeline
                     self._pipeline = None
 
-    def wait(self, gen: PQGen[RV], timeout: Optional[float] = 0.1) -> RV:
+    def wait(self, gen: PQGen[RV], timeout: Optional[float] = _WAIT_INTERVAL) -> RV:
         """
         Consume a generator operating on the connection.
 
@@ -325,13 +361,14 @@ class Connection(BaseConnection[Row]):
         try:
             return waiting.wait(gen, self.pgconn.socket, timeout=timeout)
         except _INTERRUPTED:
-            # On Ctrl-C, try to cancel the query in the server, otherwise
-            # the connection will remain stuck in ACTIVE state.
-            self._try_cancel(self.pgconn)
-            try:
-                waiting.wait(gen, self.pgconn.socket, timeout=timeout)
-            except e.QueryCanceled:
-                pass  # as expected
+            if self.pgconn.transaction_status == ACTIVE:
+                # On Ctrl-C, try to cancel the query in the server, otherwise
+                # the connection will remain stuck in ACTIVE state.
+                self._try_cancel(self.pgconn)
+                try:
+                    waiting.wait(gen, self.pgconn.socket, timeout=timeout)
+                except e.QueryCanceled:
+                    pass  # as expected
             raise
 
     @classmethod
index 46269fda4f6fc43bc2b0684a3bbf6937c01fd87b..d810d45b29850dcd8cead444a863f25e85a51adc 100644 (file)
@@ -7,6 +7,7 @@ Psycopg connection object (async version)
 from __future__ import annotations
 
 import logging
+from time import monotonic
 from types import TracebackType
 from typing import Any, AsyncGenerator, AsyncIterator, List, Optional
 from typing import Type, Union, cast, overload, TYPE_CHECKING
@@ -15,13 +16,13 @@ from contextlib import asynccontextmanager
 from . import pq
 from . import errors as e
 from . import waiting
-from .abc import AdaptContext, Params, PQGen, PQGenConn, Query, RV
+from .abc import AdaptContext, ConnDict, ConnParam, Params, PQGen, PQGenConn, Query, RV
 from ._tpc import Xid
 from .rows import Row, AsyncRowFactory, tuple_row, args_row
 from .adapt import AdaptersMap
 from ._enums import IsolationLevel
 from ._compat import Self
-from .conninfo import ConnDict, make_conninfo, conninfo_to_dict
+from .conninfo import make_conninfo, conninfo_to_dict
 from .conninfo import conninfo_attempts_async, timeout_from_conninfo
 from ._pipeline import AsyncPipeline
 from ._encodings import pgconn_encoding
@@ -41,10 +42,13 @@ else:
 if TYPE_CHECKING:
     from .pq.abc import PGconn
 
+_WAIT_INTERVAL = 0.1
+
 TEXT = pq.Format.TEXT
 BINARY = pq.Format.BINARY
 
 IDLE = pq.TransactionStatus.IDLE
+ACTIVE = pq.TransactionStatus.ACTIVE
 INTRANS = pq.TransactionStatus.INTRANS
 
 if True:  # ASYNC
@@ -88,7 +92,7 @@ class AsyncConnection(BaseConnection[Row]):
         context: Optional[AdaptContext] = None,
         row_factory: Optional[AsyncRowFactory[Row]] = None,
         cursor_factory: Optional[Type[AsyncCursor[Row]]] = None,
-        **kwargs: Any,
+        **kwargs: ConnParam,
     ) -> Self:
         """
         Connect to a database server and return a new `AsyncConnection` instance.
@@ -110,7 +114,7 @@ class AsyncConnection(BaseConnection[Row]):
         attempts = await conninfo_attempts_async(params)
         for attempt in attempts:
             try:
-                conninfo = make_conninfo(**attempt)
+                conninfo = make_conninfo("", **attempt)
                 rv = await cls._wait_conn(cls._connect_gen(conninfo), timeout=timeout)
                 break
             except e._NO_TRACEBACK as ex:
@@ -180,14 +184,12 @@ class AsyncConnection(BaseConnection[Row]):
         self.pgconn.finish()
 
     @overload
-    def cursor(self, *, binary: bool = False) -> AsyncCursor[Row]:
-        ...
+    def cursor(self, *, binary: bool = False) -> AsyncCursor[Row]: ...
 
     @overload
     def cursor(
         self, *, binary: bool = False, row_factory: AsyncRowFactory[CursorRow]
-    ) -> AsyncCursor[CursorRow]:
-        ...
+    ) -> AsyncCursor[CursorRow]: ...
 
     @overload
     def cursor(
@@ -197,8 +199,7 @@ class AsyncConnection(BaseConnection[Row]):
         binary: bool = False,
         scrollable: Optional[bool] = None,
         withhold: bool = False,
-    ) -> AsyncServerCursor[Row]:
-        ...
+    ) -> AsyncServerCursor[Row]: ...
 
     @overload
     def cursor(
@@ -209,8 +210,7 @@ class AsyncConnection(BaseConnection[Row]):
         row_factory: AsyncRowFactory[CursorRow],
         scrollable: Optional[bool] = None,
         withhold: bool = False,
-    ) -> AsyncServerCursor[CursorRow]:
-        ...
+    ) -> AsyncServerCursor[CursorRow]: ...
 
     def cursor(
         self,
@@ -296,20 +296,56 @@ class AsyncConnection(BaseConnection[Row]):
             async with tx:
                 yield tx
 
-    async def notifies(self) -> AsyncGenerator[Notify, None]:
+    async def notifies(
+        self, *, timeout: Optional[float] = None, stop_after: Optional[int] = None
+    ) -> AsyncGenerator[Notify, None]:
         """
         Yield `Notify` objects as soon as they are received from the database.
+
+        :param timeout: maximum amount of time to wait for notifications.
+            `!None` means no timeout.
+        :param stop_after: stop after receiving this number of notifications.
+            You might actually receive more than this number if more than one
+            notifications arrives in the same packet.
         """
+        # Allow interrupting the wait with a signal by reducing a long timeout
+        # into shorter interval.
+        if timeout is not None:
+            deadline = monotonic() + timeout
+            timeout = min(timeout, _WAIT_INTERVAL)
+        else:
+            deadline = None
+            timeout = _WAIT_INTERVAL
+
+        nreceived = 0
+
         while True:
-            async with self.lock:
-                try:
-                    ns = await self.wait(notifies(self.pgconn))
-                except e._NO_TRACEBACK as ex:
-                    raise ex.with_traceback(None)
-            enc = pgconn_encoding(self.pgconn)
+            # Collect notifications. Also get the connection encoding if any
+            # notification is received to makes sure that they are consistent.
+            try:
+                async with self.lock:
+                    ns = await self.wait(notifies(self.pgconn), timeout=timeout)
+                    if ns:
+                        enc = pgconn_encoding(self.pgconn)
+            except e._NO_TRACEBACK as ex:
+                raise ex.with_traceback(None)
+
+            # Emit the notifications received.
             for pgn in ns:
                 n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid)
                 yield n
+                nreceived += 1
+
+            # Stop if we have received enough notifications.
+            if stop_after is not None and nreceived >= stop_after:
+                break
+
+            # Check the deadline after the loop to ensure that timeout=0
+            # polls at least once.
+            if deadline:
+                timeout = min(_WAIT_INTERVAL, deadline - monotonic())
+                if timeout < 0.0:
+                    break
 
     @asynccontextmanager
     async def pipeline(self) -> AsyncIterator[AsyncPipeline]:
@@ -331,7 +367,9 @@ class AsyncConnection(BaseConnection[Row]):
                     assert pipeline is self._pipeline
                     self._pipeline = None
 
-    async def wait(self, gen: PQGen[RV], timeout: Optional[float] = 0.1) -> RV:
+    async def wait(
+        self, gen: PQGen[RV], timeout: Optional[float] = _WAIT_INTERVAL
+    ) -> RV:
         """
         Consume a generator operating on the connection.
 
@@ -341,13 +379,14 @@ class AsyncConnection(BaseConnection[Row]):
         try:
             return await waiting.wait_async(gen, self.pgconn.socket, timeout=timeout)
         except _INTERRUPTED:
-            # On Ctrl-C, try to cancel the query in the server, otherwise
-            # the connection will remain stuck in ACTIVE state.
-            self._try_cancel(self.pgconn)
-            try:
-                await waiting.wait_async(gen, self.pgconn.socket, timeout=timeout)
-            except e.QueryCanceled:
-                pass  # as expected
+            if self.pgconn.transaction_status == ACTIVE:
+                # On Ctrl-C, try to cancel the query in the server, otherwise
+                # the connection will remain stuck in ACTIVE state.
+                self._try_cancel(self.pgconn)
+                try:
+                    await waiting.wait_async(gen, self.pgconn.socket, timeout=timeout)
+                except e.QueryCanceled:
+                    pass  # as expected
             raise
 
     @classmethod
index 82da5882259057c9b668ff198d6f42a6aa4114b0..1401426b2ebeab32711aceba4db3643754aaa3d6 100644 (file)
@@ -7,17 +7,15 @@ Functions to manipulate conninfo strings
 from __future__ import annotations
 
 import re
-from typing import Any
 
 from . import pq
 from . import errors as e
-
 from . import _conninfo_utils
 from . import _conninfo_attempts
 from . import _conninfo_attempts_async
+from .abc import ConnParam, ConnDict
 
 # re-exoprts
-ConnDict = _conninfo_utils.ConnDict
 conninfo_attempts = _conninfo_attempts.conninfo_attempts
 conninfo_attempts_async = _conninfo_attempts_async.conninfo_attempts_async
 
@@ -27,7 +25,7 @@ conninfo_attempts_async = _conninfo_attempts_async.conninfo_attempts_async
 _DEFAULT_CONNECT_TIMEOUT = 130
 
 
-def make_conninfo(conninfo: str = "", **kwargs: Any) -> str:
+def make_conninfo(conninfo: str = "", **kwargs: ConnParam) -> str:
     """
     Merge a string and keyword params into a single conninfo string.
 
@@ -68,7 +66,7 @@ def make_conninfo(conninfo: str = "", **kwargs: Any) -> str:
     return conninfo
 
 
-def conninfo_to_dict(conninfo: str = "", **kwargs: Any) -> ConnDict:
+def conninfo_to_dict(conninfo: str = "", **kwargs: ConnParam) -> ConnDict:
     """
     Convert the `!conninfo` string into a dictionary of parameters.
 
@@ -84,7 +82,9 @@ def conninfo_to_dict(conninfo: str = "", **kwargs: Any) -> ConnDict:
            #LIBPQ-CONNSTRING
     """
     opts = _parse_conninfo(conninfo)
-    rv = {opt.keyword.decode(): opt.val.decode() for opt in opts if opt.val is not None}
+    rv: ConnDict = {
+        opt.keyword.decode(): opt.val.decode() for opt in opts if opt.val is not None
+    }
     for k, v in kwargs.items():
         if v is not None:
             rv[k] = v
index 10741c95fd2eec7c4dd8f408c8a7c236986b8d98..6b48929bce9348de6c5f557e3179f04eb8eacd44 100644 (file)
@@ -34,12 +34,12 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
     __slots__ = ()
 
     @overload
-    def __init__(self, connection: Connection[Row]):
-        ...
+    def __init__(self, connection: Connection[Row]): ...
 
     @overload
-    def __init__(self, connection: Connection[Any], *, row_factory: RowFactory[Row]):
-        ...
+    def __init__(
+        self, connection: Connection[Any], *, row_factory: RowFactory[Row]
+    ): ...
 
     def __init__(
         self,
index 603560155c2800b6cda9206a38da7df9ab33d044..55dc9a5c2616f29e38aa7ea68e9b5bc09e61cad8 100644 (file)
@@ -31,14 +31,12 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
     __slots__ = ()
 
     @overload
-    def __init__(self, connection: AsyncConnection[Row]):
-        ...
+    def __init__(self, connection: AsyncConnection[Row]): ...
 
     @overload
     def __init__(
         self, connection: AsyncConnection[Any], *, row_factory: AsyncRowFactory[Row]
-    ):
-        ...
+    ): ...
 
     def __init__(
         self,
index d2cd8120724e7c496021ccbcc8ed969b0fdc41f1..d2e2a955e6d500cb30ea79c56ebf843b369d37eb 100644 (file)
@@ -21,12 +21,11 @@ DBAPI-defined Exceptions are defined in the following hierarchy::
 from dataclasses import dataclass, field, fields
 from typing import Any, Callable, Dict, List, NoReturn, Optional, Sequence, Tuple, Type
 from typing import Union, TYPE_CHECKING
-from typing_extensions import TypeAlias
 from asyncio import CancelledError
 
 from .pq.abc import PGconn, PGresult
 from .pq._enums import ConnStatus, DiagnosticField, PipelineStatus, TransactionStatus
-from ._compat import TypeGuard
+from ._compat import TypeAlias, TypeGuard
 
 if TYPE_CHECKING:
     from .pq.misc import PGnotify, ConninfoOption
index 4f2ec878bb9cc70d06c392f884cba07962b9669e..2e463196e6e5eee462f3bdbde8839f6a3ef712a9 100644 (file)
@@ -7,10 +7,15 @@ the operations, yielding a polling state whenever there is to wait. The
 functions in the `waiting` module are the ones who wait more or less
 cooperatively for the socket to be ready and make these generators continue.
 
-All these generators yield pairs (fileno, `Wait`) whenever an operation would
-block. The generator can be restarted sending the appropriate `Ready` state
-when the file descriptor is ready.
-
+These generators yield `Wait` objects whenever an operation would block. These
+generators assume the connection fileno will not change. In case of the
+connection function, where the fileno may change, the generators yield pairs
+(fileno, `Wait`).
+
+The generator can be restarted sending the appropriate `Ready` state when the
+file descriptor is ready. If a None value is sent, it means that the wait
+function timed out without any file descriptor becoming ready; in this case the
+generator should probably yield the same value again in order to wait more.
 """
 
 # Copyright (C) 2020 The Psycopg Team
@@ -119,7 +124,11 @@ def _send(pgconn: PGconn) -> PQGen[None]:
         if f == 0:
             break
 
-        ready = yield WAIT_RW
+        while True:
+            ready = yield WAIT_RW
+            if ready:
+                break
+
         if ready & READY_R:
             # This call may read notifies: they will be saved in the
             # PGconn buffer and passed to Python later, in `fetch()`.
@@ -168,12 +177,19 @@ def _fetch(pgconn: PGconn) -> PQGen[Optional[PGresult]]:
     Return a result from the database (whether success or error).
     """
     if pgconn.is_busy():
-        yield WAIT_R
+        while True:
+            ready = yield WAIT_R
+            if ready:
+                break
+
         while True:
             pgconn.consume_input()
             if not pgconn.is_busy():
                 break
-            yield WAIT_R
+            while True:
+                ready = yield WAIT_R
+                if ready:
+                    break
 
     _consume_notifies(pgconn)
 
@@ -191,7 +207,10 @@ def _pipeline_communicate(
     results = []
 
     while True:
-        ready = yield WAIT_RW
+        while True:
+            ready = yield WAIT_RW
+            if ready:
+                break
 
         if ready & READY_R:
             pgconn.consume_input()
@@ -263,7 +282,10 @@ def copy_from(pgconn: PGconn) -> PQGen[Union[memoryview, PGresult]]:
             break
 
         # would block
-        yield WAIT_R
+        while True:
+            ready = yield WAIT_R
+            if ready:
+                break
         pgconn.consume_input()
 
     if nbytes > 0:
@@ -291,17 +313,26 @@ def copy_to(pgconn: PGconn, buffer: Buffer) -> PQGen[None]:
     # into smaller ones. We prefer to do it there instead of here in order to
     # do it upstream the queue decoupling the writer task from the producer one.
     while pgconn.put_copy_data(buffer) == 0:
-        yield WAIT_W
+        while True:
+            ready = yield WAIT_W
+            if ready:
+                break
 
 
 def copy_end(pgconn: PGconn, error: Optional[bytes]) -> PQGen[PGresult]:
     # Retry enqueuing end copy message until successful
     while pgconn.put_copy_end(error) == 0:
-        yield WAIT_W
+        while True:
+            ready = yield WAIT_W
+            if ready:
+                break
 
     # Repeat until it the message is flushed to the server
     while True:
-        yield WAIT_W
+        while True:
+            ready = yield WAIT_W
+            if ready:
+                break
         f = pgconn.flush()
         if f == 0:
             break
index 9d4dd18141886d052228664f62ef38271b8b8499..1b0f391f2df81eea2b506634cca8d1f2a90cf3d3 100644 (file)
@@ -29,7 +29,10 @@ FILE_ptr = POINTER(FILE)
 
 if sys.platform == "linux":
     libcname = ctypes.util.find_library("c")
-    assert libcname
+    if not libcname:
+        # Likely this is a system using musl libc, see the following bug:
+        # https://github.com/python/cpython/issues/65821
+        libcname = "libc.so"
     libc = ctypes.cdll.LoadLibrary(libcname)
 
     fdopen = libc.fdopen
index 3a76d56c086bb72ce1b5d0902a2d1f23a6217968..13a077211204a396eb144eb11ad3a00bd05b5f86 100644 (file)
@@ -6,9 +6,9 @@ Protocol objects to represent objects exposed by different pq implementations.
 
 from typing import Any, Callable, List, Optional, Protocol, Sequence, Tuple
 from typing import Union, TYPE_CHECKING
-from typing_extensions import TypeAlias
 
 from ._enums import Format, Trace
+from .._compat import TypeAlias
 
 if TYPE_CHECKING:
     from .misc import PGnotify, ConninfoOption, PGresAttDesc
@@ -22,112 +22,83 @@ class PGconn(Protocol):
     notify_handler: Optional[Callable[["PGnotify"], None]]
 
     @classmethod
-    def connect(cls, conninfo: bytes) -> "PGconn":
-        ...
+    def connect(cls, conninfo: bytes) -> "PGconn": ...
 
     @classmethod
-    def connect_start(cls, conninfo: bytes) -> "PGconn":
-        ...
+    def connect_start(cls, conninfo: bytes) -> "PGconn": ...
 
-    def connect_poll(self) -> int:
-        ...
+    def connect_poll(self) -> int: ...
 
-    def finish(self) -> None:
-        ...
+    def finish(self) -> None: ...
 
     @property
-    def info(self) -> List["ConninfoOption"]:
-        ...
+    def info(self) -> List["ConninfoOption"]: ...
 
-    def reset(self) -> None:
-        ...
+    def reset(self) -> None: ...
 
-    def reset_start(self) -> None:
-        ...
+    def reset_start(self) -> None: ...
 
-    def reset_poll(self) -> int:
-        ...
+    def reset_poll(self) -> int: ...
 
     @classmethod
-    def ping(self, conninfo: bytes) -> int:
-        ...
+    def ping(self, conninfo: bytes) -> int: ...
 
     @property
-    def db(self) -> bytes:
-        ...
+    def db(self) -> bytes: ...
 
     @property
-    def user(self) -> bytes:
-        ...
+    def user(self) -> bytes: ...
 
     @property
-    def password(self) -> bytes:
-        ...
+    def password(self) -> bytes: ...
 
     @property
-    def host(self) -> bytes:
-        ...
+    def host(self) -> bytes: ...
 
     @property
-    def hostaddr(self) -> bytes:
-        ...
+    def hostaddr(self) -> bytes: ...
 
     @property
-    def port(self) -> bytes:
-        ...
+    def port(self) -> bytes: ...
 
     @property
-    def tty(self) -> bytes:
-        ...
+    def tty(self) -> bytes: ...
 
     @property
-    def options(self) -> bytes:
-        ...
+    def options(self) -> bytes: ...
 
     @property
-    def status(self) -> int:
-        ...
+    def status(self) -> int: ...
 
     @property
-    def transaction_status(self) -> int:
-        ...
+    def transaction_status(self) -> int: ...
 
-    def parameter_status(self, name: bytes) -> Optional[bytes]:
-        ...
+    def parameter_status(self, name: bytes) -> Optional[bytes]: ...
 
     @property
-    def error_message(self) -> bytes:
-        ...
+    def error_message(self) -> bytes: ...
 
     @property
-    def server_version(self) -> int:
-        ...
+    def server_version(self) -> int: ...
 
     @property
-    def socket(self) -> int:
-        ...
+    def socket(self) -> int: ...
 
     @property
-    def backend_pid(self) -> int:
-        ...
+    def backend_pid(self) -> int: ...
 
     @property
-    def needs_password(self) -> bool:
-        ...
+    def needs_password(self) -> bool: ...
 
     @property
-    def used_password(self) -> bool:
-        ...
+    def used_password(self) -> bool: ...
 
     @property
-    def ssl_in_use(self) -> bool:
-        ...
+    def ssl_in_use(self) -> bool: ...
 
-    def exec_(self, command: bytes) -> "PGresult":
-        ...
+    def exec_(self, command: bytes) -> "PGresult": ...
 
-    def send_query(self, command: bytes) -> None:
-        ...
+    def send_query(self, command: bytes) -> None: ...
 
     def exec_params(
         self,
@@ -136,8 +107,7 @@ class PGconn(Protocol):
         param_types: Optional[Sequence[int]] = None,
         param_formats: Optional[Sequence[int]] = None,
         result_format: int = Format.TEXT,
-    ) -> "PGresult":
-        ...
+    ) -> "PGresult": ...
 
     def send_query_params(
         self,
@@ -146,16 +116,14 @@ class PGconn(Protocol):
         param_types: Optional[Sequence[int]] = None,
         param_formats: Optional[Sequence[int]] = None,
         result_format: int = Format.TEXT,
-    ) -> None:
-        ...
+    ) -> None: ...
 
     def send_prepare(
         self,
         name: bytes,
         command: bytes,
         param_types: Optional[Sequence[int]] = None,
-    ) -> None:
-        ...
+    ) -> None: ...
 
     def send_query_prepared(
         self,
@@ -163,16 +131,14 @@ class PGconn(Protocol):
         param_values: Optional[Sequence[Optional[Buffer]]],
         param_formats: Optional[Sequence[int]] = None,
         result_format: int = Format.TEXT,
-    ) -> None:
-        ...
+    ) -> None: ...
 
     def prepare(
         self,
         name: bytes,
         command: bytes,
         param_types: Optional[Sequence[int]] = None,
-    ) -> "PGresult":
-        ...
+    ) -> "PGresult": ...
 
     def exec_prepared(
         self,
@@ -180,216 +146,153 @@ class PGconn(Protocol):
         param_values: Optional[Sequence[Buffer]],
         param_formats: Optional[Sequence[int]] = None,
         result_format: int = 0,
-    ) -> "PGresult":
-        ...
+    ) -> "PGresult": ...
 
-    def describe_prepared(self, name: bytes) -> "PGresult":
-        ...
+    def describe_prepared(self, name: bytes) -> "PGresult": ...
 
-    def send_describe_prepared(self, name: bytes) -> None:
-        ...
+    def send_describe_prepared(self, name: bytes) -> None: ...
 
-    def describe_portal(self, name: bytes) -> "PGresult":
-        ...
+    def describe_portal(self, name: bytes) -> "PGresult": ...
 
-    def send_describe_portal(self, name: bytes) -> None:
-        ...
+    def send_describe_portal(self, name: bytes) -> None: ...
 
-    def close_prepared(self, name: bytes) -> "PGresult":
-        ...
+    def close_prepared(self, name: bytes) -> "PGresult": ...
 
-    def send_close_prepared(self, name: bytes) -> None:
-        ...
+    def send_close_prepared(self, name: bytes) -> None: ...
 
-    def close_portal(self, name: bytes) -> "PGresult":
-        ...
+    def close_portal(self, name: bytes) -> "PGresult": ...
 
-    def send_close_portal(self, name: bytes) -> None:
-        ...
+    def send_close_portal(self, name: bytes) -> None: ...
 
-    def get_result(self) -> Optional["PGresult"]:
-        ...
+    def get_result(self) -> Optional["PGresult"]: ...
 
-    def consume_input(self) -> None:
-        ...
+    def consume_input(self) -> None: ...
 
-    def is_busy(self) -> int:
-        ...
+    def is_busy(self) -> int: ...
 
     @property
-    def nonblocking(self) -> int:
-        ...
+    def nonblocking(self) -> int: ...
 
     @nonblocking.setter
-    def nonblocking(self, arg: int) -> None:
-        ...
+    def nonblocking(self, arg: int) -> None: ...
 
-    def flush(self) -> int:
-        ...
+    def flush(self) -> int: ...
 
-    def set_single_row_mode(self) -> None:
-        ...
+    def set_single_row_mode(self) -> None: ...
 
-    def get_cancel(self) -> "PGcancel":
-        ...
+    def get_cancel(self) -> "PGcancel": ...
 
-    def notifies(self) -> Optional["PGnotify"]:
-        ...
+    def notifies(self) -> Optional["PGnotify"]: ...
 
-    def put_copy_data(self, buffer: Buffer) -> int:
-        ...
+    def put_copy_data(self, buffer: Buffer) -> int: ...
 
-    def put_copy_end(self, error: Optional[bytes] = None) -> int:
-        ...
+    def put_copy_end(self, error: Optional[bytes] = None) -> int: ...
 
-    def get_copy_data(self, async_: int) -> Tuple[int, memoryview]:
-        ...
+    def get_copy_data(self, async_: int) -> Tuple[int, memoryview]: ...
 
-    def trace(self, fileno: int) -> None:
-        ...
+    def trace(self, fileno: int) -> None: ...
 
-    def set_trace_flags(self, flags: Trace) -> None:
-        ...
+    def set_trace_flags(self, flags: Trace) -> None: ...
 
-    def untrace(self) -> None:
-        ...
+    def untrace(self) -> None: ...
 
     def encrypt_password(
         self, passwd: bytes, user: bytes, algorithm: Optional[bytes] = None
-    ) -> bytes:
-        ...
+    ) -> bytes: ...
 
-    def make_empty_result(self, exec_status: int) -> "PGresult":
-        ...
+    def make_empty_result(self, exec_status: int) -> "PGresult": ...
 
     @property
-    def pipeline_status(self) -> int:
-        ...
+    def pipeline_status(self) -> int: ...
 
-    def enter_pipeline_mode(self) -> None:
-        ...
+    def enter_pipeline_mode(self) -> None: ...
 
-    def exit_pipeline_mode(self) -> None:
-        ...
+    def exit_pipeline_mode(self) -> None: ...
 
-    def pipeline_sync(self) -> None:
-        ...
+    def pipeline_sync(self) -> None: ...
 
-    def send_flush_request(self) -> None:
-        ...
+    def send_flush_request(self) -> None: ...
 
 
 class PGresult(Protocol):
-    def clear(self) -> None:
-        ...
+    def clear(self) -> None: ...
 
     @property
-    def status(self) -> int:
-        ...
+    def status(self) -> int: ...
 
     @property
-    def error_message(self) -> bytes:
-        ...
+    def error_message(self) -> bytes: ...
 
-    def error_field(self, fieldcode: int) -> Optional[bytes]:
-        ...
+    def error_field(self, fieldcode: int) -> Optional[bytes]: ...
 
     @property
-    def ntuples(self) -> int:
-        ...
+    def ntuples(self) -> int: ...
 
     @property
-    def nfields(self) -> int:
-        ...
+    def nfields(self) -> int: ...
 
-    def fname(self, column_number: int) -> Optional[bytes]:
-        ...
+    def fname(self, column_number: int) -> Optional[bytes]: ...
 
-    def ftable(self, column_number: int) -> int:
-        ...
+    def ftable(self, column_number: int) -> int: ...
 
-    def ftablecol(self, column_number: int) -> int:
-        ...
+    def ftablecol(self, column_number: int) -> int: ...
 
-    def fformat(self, column_number: int) -> int:
-        ...
+    def fformat(self, column_number: int) -> int: ...
 
-    def ftype(self, column_number: int) -> int:
-        ...
+    def ftype(self, column_number: int) -> int: ...
 
-    def fmod(self, column_number: int) -> int:
-        ...
+    def fmod(self, column_number: int) -> int: ...
 
-    def fsize(self, column_number: int) -> int:
-        ...
+    def fsize(self, column_number: int) -> int: ...
 
     @property
-    def binary_tuples(self) -> int:
-        ...
+    def binary_tuples(self) -> int: ...
 
-    def get_value(self, row_number: int, column_number: int) -> Optional[bytes]:
-        ...
+    def get_value(self, row_number: int, column_number: int) -> Optional[bytes]: ...
 
     @property
-    def nparams(self) -> int:
-        ...
+    def nparams(self) -> int: ...
 
-    def param_type(self, param_number: int) -> int:
-        ...
+    def param_type(self, param_number: int) -> int: ...
 
     @property
-    def command_status(self) -> Optional[bytes]:
-        ...
+    def command_status(self) -> Optional[bytes]: ...
 
     @property
-    def command_tuples(self) -> Optional[int]:
-        ...
+    def command_tuples(self) -> Optional[int]: ...
 
     @property
-    def oid_value(self) -> int:
-        ...
+    def oid_value(self) -> int: ...
 
-    def set_attributes(self, descriptions: List["PGresAttDesc"]) -> None:
-        ...
+    def set_attributes(self, descriptions: List["PGresAttDesc"]) -> None: ...
 
 
 class PGcancel(Protocol):
-    def free(self) -> None:
-        ...
+    def free(self) -> None: ...
 
-    def cancel(self) -> None:
-        ...
+    def cancel(self) -> None: ...
 
 
 class Conninfo(Protocol):
     @classmethod
-    def get_defaults(cls) -> List["ConninfoOption"]:
-        ...
+    def get_defaults(cls) -> List["ConninfoOption"]: ...
 
     @classmethod
-    def parse(cls, conninfo: bytes) -> List["ConninfoOption"]:
-        ...
+    def parse(cls, conninfo: bytes) -> List["ConninfoOption"]: ...
 
     @classmethod
-    def _options_from_array(cls, opts: Sequence[Any]) -> List["ConninfoOption"]:
-        ...
+    def _options_from_array(cls, opts: Sequence[Any]) -> List["ConninfoOption"]: ...
 
 
 class Escaping(Protocol):
-    def __init__(self, conn: Optional[PGconn] = None):
-        ...
+    def __init__(self, conn: Optional[PGconn] = None): ...
 
-    def escape_literal(self, data: Buffer) -> bytes:
-        ...
+    def escape_literal(self, data: Buffer) -> bytes: ...
 
-    def escape_identifier(self, data: Buffer) -> bytes:
-        ...
+    def escape_identifier(self, data: Buffer) -> bytes: ...
 
-    def escape_string(self, data: Buffer) -> bytes:
-        ...
+    def escape_string(self, data: Buffer) -> bytes: ...
 
-    def escape_bytea(self, data: Buffer) -> bytes:
-        ...
+    def escape_bytea(self, data: Buffer) -> bytes: ...
 
-    def unescape_bytea(self, data: Buffer) -> bytes:
-        ...
+    def unescape_bytea(self, data: Buffer) -> bytes: ...
index 4c2f7781b5eba777b4df4c852c1c196a948a3e07..07e0dbcaf294c9361a3c4e5c245c7e26691ceee5 100644 (file)
@@ -8,11 +8,10 @@ import functools
 from typing import Any, Callable, Dict, List, Optional, NamedTuple, NoReturn
 from typing import TYPE_CHECKING, Protocol, Sequence, Tuple, Type
 from collections import namedtuple
-from typing_extensions import TypeAlias
 
 from . import pq
 from . import errors as e
-from ._compat import TypeVar
+from ._compat import TypeAlias, TypeVar
 from ._encodings import _as_python_identifier
 
 if TYPE_CHECKING:
@@ -44,8 +43,7 @@ class RowMaker(Protocol[Row]):
     Typically, `!RowMaker` functions are returned by `RowFactory`.
     """
 
-    def __call__(self, __values: Sequence[Any]) -> Row:
-        ...
+    def __call__(self, __values: Sequence[Any]) -> Row: ...
 
 
 class RowFactory(Protocol[Row]):
@@ -62,8 +60,7 @@ class RowFactory(Protocol[Row]):
     use the values to create a dictionary for each record.
     """
 
-    def __call__(self, __cursor: "Cursor[Any]") -> RowMaker[Row]:
-        ...
+    def __call__(self, __cursor: "Cursor[Any]") -> RowMaker[Row]: ...
 
 
 class AsyncRowFactory(Protocol[Row]):
@@ -71,8 +68,7 @@ class AsyncRowFactory(Protocol[Row]):
     Like `RowFactory`, taking an async cursor as argument.
     """
 
-    def __call__(self, __cursor: "AsyncCursor[Any]") -> RowMaker[Row]:
-        ...
+    def __call__(self, __cursor: "AsyncCursor[Any]") -> RowMaker[Row]: ...
 
 
 class BaseRowFactory(Protocol[Row]):
@@ -80,8 +76,7 @@ class BaseRowFactory(Protocol[Row]):
     Like `RowFactory`, taking either type of cursor as argument.
     """
 
-    def __call__(self, __cursor: "BaseCursor[Any, Any]") -> RowMaker[Row]:
-        ...
+    def __call__(self, __cursor: "BaseCursor[Any, Any]") -> RowMaker[Row]: ...
 
 
 TupleRow: TypeAlias = Tuple[Any, ...]
@@ -212,6 +207,28 @@ def kwargs_row(func: Callable[..., T]) -> BaseRowFactory[T]:
     return kwargs_row_
 
 
+def scalar_row(cursor: "BaseCursor[Any, Any]") -> "RowMaker[Any]":
+    """
+    Generate a row factory returning the first column
+    as a scalar value.
+    """
+    res = cursor.pgresult
+    if not res:
+        return no_result
+
+    nfields = _get_nfields(res)
+    if nfields is None:
+        return no_result
+
+    if nfields < 1:
+        raise e.ProgrammingError("at least one column expected")
+
+    def scalar_row_(values: Sequence[Any]) -> Any:
+        return values[0]
+
+    return scalar_row_
+
+
 def no_result(values: Sequence[Any]) -> NoReturn:
     """A `RowMaker` that always fail.
 
index 1c6e77aa10f96b33647029e9be2083edbf29c7a4..2f5f44739d96ac1be8d3720a239351885a4076f8 100644 (file)
@@ -222,8 +222,7 @@ class ServerCursor(ServerCursorMixin["Connection[Any]", Row], Cursor[Row]):
         *,
         scrollable: Optional[bool] = None,
         withhold: bool = False,
-    ):
-        ...
+    ): ...
 
     @overload
     def __init__(
@@ -234,8 +233,7 @@ class ServerCursor(ServerCursorMixin["Connection[Any]", Row], Cursor[Row]):
         row_factory: RowFactory[Row],
         scrollable: Optional[bool] = None,
         withhold: bool = False,
-    ):
-        ...
+    ): ...
 
     def __init__(
         self,
@@ -363,8 +361,7 @@ class AsyncServerCursor(
         *,
         scrollable: Optional[bool] = None,
         withhold: bool = False,
-    ):
-        ...
+    ): ...
 
     @overload
     def __init__(
@@ -375,8 +372,7 @@ class AsyncServerCursor(
         row_factory: AsyncRowFactory[Row],
         scrollable: Optional[bool] = None,
         withhold: bool = False,
-    ):
-        ...
+    ): ...
 
     def __init__(
         self,
index d793bf389245a6140664318ec1c056d817678e1f..a94f77f6efcf517703c56a25b6db96b26396c5bd 100644 (file)
@@ -55,7 +55,7 @@ class Composable(ABC):
         return f"{self.__class__.__name__}({self._obj!r})"
 
     @abstractmethod
-    def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+    def as_bytes(self, context: Optional[AdaptContext] = None) -> bytes:
         """
         Return the value of the object as bytes.
 
@@ -69,7 +69,7 @@ class Composable(ABC):
         """
         raise NotImplementedError
 
-    def as_string(self, context: Optional[AdaptContext]) -> str:
+    def as_string(self, context: Optional[AdaptContext] = None) -> str:
         """
         Return the value of the object as string.
 
@@ -130,7 +130,7 @@ class Composed(Composable):
         seq = [obj if isinstance(obj, Composable) else Literal(obj) for obj in seq]
         super().__init__(seq)
 
-    def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+    def as_bytes(self, context: Optional[AdaptContext] = None) -> bytes:
         return b"".join(obj.as_bytes(context) for obj in self._obj)
 
     def __iter__(self) -> Iterator[Composable]:
@@ -200,10 +200,10 @@ class SQL(Composable):
         if not isinstance(obj, str):
             raise TypeError(f"SQL values must be strings, got {obj!r} instead")
 
-    def as_string(self, context: Optional[AdaptContext]) -> str:
+    def as_string(self, context: Optional[AdaptContext] = None) -> str:
         return self._obj
 
-    def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+    def as_bytes(self, context: Optional[AdaptContext] = None) -> bytes:
         conn = context.connection if context else None
         enc = conn_encoding(conn)
         return self._obj.encode(enc)
@@ -362,15 +362,22 @@ class Identifier(Composable):
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}({', '.join(map(repr, self._obj))})"
 
-    def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+    def as_bytes(self, context: Optional[AdaptContext] = None) -> bytes:
         conn = context.connection if context else None
-        if not conn:
-            raise ValueError("a connection is necessary for Identifier")
-        esc = Escaping(conn.pgconn)
-        enc = conn_encoding(conn)
-        escs = [esc.escape_identifier(s.encode(enc)) for s in self._obj]
+        if conn:
+            esc = Escaping(conn.pgconn)
+            enc = conn_encoding(conn)
+            escs = [esc.escape_identifier(s.encode(enc)) for s in self._obj]
+        else:
+            escs = [self._escape_identifier(s.encode()) for s in self._obj]
         return b".".join(escs)
 
+    def _escape_identifier(self, s: bytes) -> bytes:
+        """
+        Approximation of PQescapeIdentifier taking no connection.
+        """
+        return b'"' + s.replace(b'"', b'""') + b'"'
+
 
 class Literal(Composable):
     """
@@ -393,7 +400,7 @@ class Literal(Composable):
 
     """
 
-    def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+    def as_bytes(self, context: Optional[AdaptContext] = None) -> bytes:
         tx = Transformer.from_context(context)
         return tx.as_literal(self._obj)
 
@@ -452,11 +459,11 @@ class Placeholder(Composable):
 
         return f"{self.__class__.__name__}({', '.join(parts)})"
 
-    def as_string(self, context: Optional[AdaptContext]) -> str:
+    def as_string(self, context: Optional[AdaptContext] = None) -> str:
         code = self._format.value
         return f"%({self._obj}){code}" if self._obj else f"%{code}"
 
-    def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+    def as_bytes(self, context: Optional[AdaptContext] = None) -> bytes:
         conn = context.connection if context else None
         enc = conn_encoding(conn)
         return self.as_string(context).encode(enc)
index e15c1129951cb4dbef3380ac15ddd9681015d7c3..6e20dd3cef764aab3fecc893c517ac0fabc7e230 100644 (file)
@@ -1,10 +1,10 @@
 """
 Adapters for the enum type.
 """
+
 from enum import Enum
 from typing import Any, Dict, Generic, Optional, Mapping, Sequence
 from typing import Tuple, Type, Union, cast, TYPE_CHECKING
-from typing_extensions import TypeAlias
 
 from .. import sql
 from .. import postgres
@@ -12,7 +12,7 @@ from .. import errors as e
 from ..pq import Format
 from ..abc import AdaptContext, Query
 from ..adapt import Buffer, Dumper, Loader
-from .._compat import cache, TypeVar
+from .._compat import cache, TypeAlias, TypeVar
 from .._encodings import conn_encoding
 from .._typeinfo import TypeInfo
 
index 851a0556fdcd7f39fff42589394a715178b99bb7..5bc261f5563c0df89a1c9d283cd273bcfc7d6726 100644 (file)
@@ -6,14 +6,13 @@ Dict to hstore adaptation
 
 import re
 from typing import Dict, List, Optional, Type
-from typing_extensions import TypeAlias
 
 from .. import errors as e
 from .. import postgres
 from ..abc import Buffer, AdaptContext
 from .._oids import TEXT_OID
 from ..adapt import PyFormat, RecursiveDumper, RecursiveLoader
-from .._compat import cache
+from .._compat import cache, TypeAlias
 from .._typeinfo import TypeInfo
 
 _re_escape = re.compile(r'(["\\])')
index d672f6be8d4fb7d1697ddc1402c0419b9d951475..51f61d1a79445a9a7f0d33a036238afe63c15ec5 100644 (file)
@@ -91,12 +91,10 @@ class Multirange(MutableSequence[Range[T]]):
         return f"{{{', '.join(map(str, self._ranges))}}}"
 
     @overload
-    def __getitem__(self, index: int) -> Range[T]:
-        ...
+    def __getitem__(self, index: int) -> Range[T]: ...
 
     @overload
-    def __getitem__(self, index: slice) -> "Multirange[T]":
-        ...
+    def __getitem__(self, index: slice) -> "Multirange[T]": ...
 
     def __getitem__(self, index: Union[int, slice]) -> "Union[Range[T],Multirange[T]]":
         if isinstance(index, int):
@@ -108,12 +106,10 @@ class Multirange(MutableSequence[Range[T]]):
         return len(self._ranges)
 
     @overload
-    def __setitem__(self, index: int, value: Range[T]) -> None:
-        ...
+    def __setitem__(self, index: int, value: Range[T]) -> None: ...
 
     @overload
-    def __setitem__(self, index: slice, value: Iterable[Range[T]]) -> None:
-        ...
+    def __setitem__(self, index: slice, value: Iterable[Range[T]]) -> None: ...
 
     def __setitem__(
         self,
index 983de9a039ad1bd41665bb6dff10025eac9d6f20..76522dcbbd3917a6d09d79e05d88554567e4d136 100644 (file)
@@ -5,12 +5,12 @@ Adapters for network types.
 # Copyright (C) 2020 The Psycopg Team
 
 from typing import Callable, Optional, Type, Union, TYPE_CHECKING
-from typing_extensions import TypeAlias
 
 from .. import _oids
 from ..pq import Format
 from ..abc import AdaptContext
 from ..adapt import Buffer, Dumper, Loader
+from .._compat import TypeAlias
 
 if TYPE_CHECKING:
     import ipaddress
index f394bdac7dc13a19b5ccdee8787b508fde74c6cb..1817740fd6bd504af9b70661c5f7430d899652a5 100644 (file)
@@ -379,8 +379,7 @@ class _MixedNumericDumper(Dumper, ABC):
                 _MixedNumericDumper.int_classes = int
 
     @abstractmethod
-    def dump(self, obj: Union[Decimal, int, "numpy.integer[Any]"]) -> Buffer:
-        ...
+    def dump(self, obj: Union[Decimal, int, "numpy.integer[Any]"]) -> Buffer: ...
 
 
 class NumericDumper(_MixedNumericDumper):
index d6db0d922e899f53c130d538069fd26e7697243e..4f307b6ef38693235b542cd9d8c84fb636de732a 100644 (file)
@@ -26,6 +26,7 @@ from ._cmodule import _psycopg
 WAIT_R = Wait.R
 WAIT_W = Wait.W
 WAIT_RW = Wait.RW
+READY_NONE = Ready.NONE
 READY_R = Ready.R
 READY_W = Ready.W
 READY_RW = Ready.RW
@@ -51,16 +52,17 @@ def wait_selector(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None)
     try:
         s = next(gen)
         with DefaultSelector() as sel:
+            sel.register(fileno, s)
             while True:
-                sel.register(fileno, s)
-                rlist = None
-                while not rlist:
-                    rlist = sel.select(timeout=timeout)
+                rlist = sel.select(timeout=timeout)
+                if not rlist:
+                    gen.send(READY_NONE)
+                    continue
+
                 sel.unregister(fileno)
-                # note: this line should require a cast, but mypy doesn't complain
-                ready: Ready = rlist[0][1]
-                assert s & ready
+                ready = rlist[0][1]
                 s = gen.send(ready)
+                sel.register(fileno, s)
 
     except StopIteration as ex:
         rv: RV = ex.args[0] if ex.args else None
@@ -92,7 +94,7 @@ def wait_conn(gen: PQGenConn[RV], timeout: Optional[float] = None) -> RV:
                 sel.unregister(fileno)
                 if not rlist:
                     raise e.ConnectionTimeout("connection timeout expired")
-                ready: Ready = rlist[0][1]  # type: ignore[assignment]
+                ready = rlist[0][1]
                 fileno, s = gen.send(ready)
 
     except StopIteration as ex:
@@ -110,7 +112,7 @@ async def wait_async(
         `Ready` values when it would block.
     :param fileno: the file descriptor to wait on.
     :param timeout: timeout (in seconds) to check for other interrupt, e.g.
-        to allow Ctrl-C. If zero or None, wait indefinitely.
+        to allow Ctrl-C. If zero, wait indefinitely.
     :return: whatever `!gen` returns on completion.
 
     Behave like in `wait()`, but exposing an `asyncio` interface.
@@ -119,12 +121,12 @@ async def wait_async(
     # Not sure this is the best implementation but it's a start.
     ev = Event()
     loop = get_event_loop()
-    ready: Ready
+    ready: int
     s: Wait
 
     def wakeup(state: Ready) -> None:
         nonlocal ready
-        ready |= state  # type: ignore[assignment]
+        ready |= state
         ev.set()
 
     try:
@@ -135,19 +137,19 @@ async def wait_async(
             if not reader and not writer:
                 raise e.InternalError(f"bad poll status: {s}")
             ev.clear()
-            ready = 0  # type: ignore[assignment]
+            ready = 0
             if reader:
                 loop.add_reader(fileno, wakeup, READY_R)
             if writer:
                 loop.add_writer(fileno, wakeup, READY_W)
             try:
-                if timeout is None:
-                    await ev.wait()
-                else:
+                if timeout is not None:
                     try:
                         await wait_for(ev.wait(), timeout)
                     except TimeoutError:
                         pass
+                else:
+                    await ev.wait()
             finally:
                 if reader:
                     loop.remove_reader(fileno)
@@ -155,6 +157,9 @@ async def wait_async(
                     loop.remove_writer(fileno)
             s = gen.send(ready)
 
+    except OSError as ex:
+        # Assume the connection was closed
+        raise e.OperationalError(str(ex))
     except StopIteration as ex:
         rv: RV = ex.args[0] if ex.args else None
         return rv
@@ -245,9 +250,10 @@ def wait_select(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) ->
             if wl:
                 ready |= READY_W
             if not ready:
+                gen.send(READY_NONE)
                 continue
-            # assert s & ready
-            s = gen.send(ready)  # type: ignore
+
+            s = gen.send(ready)
 
     except StopIteration as ex:
         rv: RV = ex.args[0] if ex.args else None
@@ -285,24 +291,22 @@ def wait_epoll(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) ->
         s = next(gen)
 
         if timeout is None or timeout < 0:
-            timeout = 0
-        else:
-            timeout = int(timeout * 1000.0)
+            timeout = 0.0
 
         with select.epoll() as epoll:
             evmask = _epoll_evmasks[s]
             epoll.register(fileno, evmask)
             while True:
-                fileevs = None
-                while not fileevs:
-                    fileevs = epoll.poll(timeout)
+                fileevs = epoll.poll(timeout)
+                if not fileevs:
+                    gen.send(READY_NONE)
+                    continue
                 ev = fileevs[0][1]
                 ready = 0
                 if ev & ~select.EPOLLOUT:
                     ready = READY_R
                 if ev & ~select.EPOLLIN:
                     ready |= READY_W
-                # assert s & ready
                 s = gen.send(ready)
                 evmask = _epoll_evmasks[s]
                 epoll.modify(fileno, evmask)
@@ -340,16 +344,17 @@ def wait_poll(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> R
         evmask = _poll_evmasks[s]
         poll.register(fileno, evmask)
         while True:
-            fileevs = None
-            while not fileevs:
-                fileevs = poll.poll(timeout)
+            fileevs = poll.poll(timeout)
+            if not fileevs:
+                gen.send(READY_NONE)
+                continue
+
             ev = fileevs[0][1]
             ready = 0
             if ev & ~select.POLLOUT:
                 ready = READY_R
             if ev & ~select.POLLIN:
                 ready |= READY_W
-            # assert s & ready
             s = gen.send(ready)
             evmask = _poll_evmasks[s]
             poll.modify(fileno, evmask)
index f734c40ec99a0f178ac44795df1bf0e318b36882..fbb544677844d2d6b7634ebd9fd033af2116596d 100644 (file)
@@ -74,7 +74,7 @@ test =
     pytest-randomly >= 3.5
 dev =
     ast-comments >= 1.1.2
-    black >= 23.1.0
+    black >= 24.1.0
     codespell >= 2.2
     dnspython >= 2.1
     flake8 >= 4.0
index 2ae629c2d4d3d195def647f96d7bce3cdbab83b6..40a061b1e954b4de28869853a004d04fcd5459b7 100644 (file)
@@ -1,3 +1,3 @@
 [flake8]
 max-line-length = 88
-ignore = W503, E203
+ignore = W503, E203, E704
index a51fce5e28d5f2117e1a161fa6c3c2a8663ede9c..70335cf8995731ea0d156baa8f9026a17ce117dd 100644 (file)
@@ -18,9 +18,11 @@ from psycopg._encodings import conninfo_encoding
 cdef object WAIT_W = Wait.W
 cdef object WAIT_R = Wait.R
 cdef object WAIT_RW = Wait.RW
+cdef object PY_READY_NONE = Ready.NONE
 cdef object PY_READY_R = Ready.R
 cdef object PY_READY_W = Ready.W
 cdef object PY_READY_RW = Ready.RW
+cdef int READY_NONE = Ready.NONE
 cdef int READY_R = Ready.R
 cdef int READY_W = Ready.W
 cdef int READY_RW = Ready.RW
@@ -96,15 +98,19 @@ def send(pq.PGconn pgconn) -> PQGen[None]:
     to retrieve the results available.
     """
     cdef libpq.PGconn *pgconn_ptr = pgconn._pgconn_ptr
-    cdef int status
+    cdef int ready
     cdef int cires
 
     while True:
         if pgconn.flush() == 0:
             break
 
-        status = yield WAIT_RW
-        if status & READY_R:
+        while True:
+            ready = yield WAIT_RW
+            if ready:
+                break
+
+        if ready & READY_R:
             with nogil:
                 # This call may read notifies which will be saved in the
                 # PGconn buffer and passed to Python later.
@@ -166,11 +172,16 @@ def fetch(pq.PGconn pgconn) -> PQGen[Optional[PGresult]]:
     cdef libpq.PGconn *pgconn_ptr = pgconn._pgconn_ptr
     cdef int cires, ibres
     cdef libpq.PGresult *pgres
+    cdef object ready
 
     with nogil:
         ibres = libpq.PQisBusy(pgconn_ptr)
     if ibres:
-        yield WAIT_R
+        while True:
+            ready = yield WAIT_R
+            if ready:
+                break
+
         while True:
             with nogil:
                 cires = libpq.PQconsumeInput(pgconn_ptr)
@@ -182,7 +193,10 @@ def fetch(pq.PGconn pgconn) -> PQGen[Optional[PGresult]]:
                     f"consuming input failed: {error_message(pgconn)}")
             if not ibres:
                 break
-            yield WAIT_R
+            while True:
+                ready = yield WAIT_R
+                if ready:
+                    break
 
     _consume_notifies(pgconn)
 
@@ -211,7 +225,10 @@ def pipeline_communicate(
     cdef pq.PGresult r
 
     while True:
-        ready = yield WAIT_RW
+        while True:
+            ready = yield WAIT_RW
+            if ready:
+                break
 
         if ready & READY_R:
             with nogil:
index 33c54c513b8d0e4ed5afbd4db7cb9b5e8c792dd8..3a6cc6e255eb0c4a41cfe38d5030f684ea8ef898 100644 (file)
@@ -51,7 +51,7 @@ static int
 wait_c_impl(int fileno, int wait, float timeout)
 {
     int select_rv;
-    int rv = 0;
+    int rv = -1;
 
 #if defined(HAVE_POLL) && !defined(HAVE_BROKEN_POLL)
 
@@ -83,11 +83,14 @@ retry_eintr:
         goto retry_eintr;
     }
 
-    if (select_rv < 0) { goto error; }
     if (PyErr_CheckSignals()) { goto finally; }
+    if (select_rv < 0) { goto finally; }  /* poll error */
 
-    if (input_fd.events & POLLIN) { rv |= SELECT_EV_READ; }
-    if (input_fd.events & POLLOUT) { rv |= SELECT_EV_WRITE; }
+    rv = 0;  /* success, maybe with timeout */
+    if (select_rv >= 0) {
+        if (input_fd.events & POLLIN) { rv |= SELECT_EV_READ; }
+        if (input_fd.events & POLLOUT) { rv |= SELECT_EV_WRITE; }
+    }
 
 #else
 
@@ -135,11 +138,14 @@ retry_eintr:
         goto retry_eintr;
     }
 
-    if (select_rv < 0) { goto error; }
     if (PyErr_CheckSignals()) { goto finally; }
+    if (select_rv < 0) { goto error; }  /* select error */
 
-    if (FD_ISSET(fileno, &ifds)) { rv |= SELECT_EV_READ; }
-    if (FD_ISSET(fileno, &ofds)) { rv |= SELECT_EV_WRITE; }
+    rv = 0;
+    if (select_rv > 0) {
+        if (FD_ISSET(fileno, &ifds)) { rv |= SELECT_EV_READ; }
+        if (FD_ISSET(fileno, &ofds)) { rv |= SELECT_EV_WRITE; }
+    }
 
 #endif  /* HAVE_POLL */
 
@@ -147,6 +153,8 @@ retry_eintr:
 
 error:
 
+    rv = -1;
+
 #ifdef MS_WINDOWS
     if (select_rv == SOCKET_ERROR) {
         PyErr_SetExcFromWindowsErr(PyExc_OSError, WSAGetLastError());
@@ -162,7 +170,7 @@ error:
 
 finally:
 
-    return -1;
+    return rv;
 
 }
     """
@@ -191,8 +199,8 @@ def wait_c(gen: PQGen[RV], int fileno, timeout = None) -> RV:
 
         while True:
             ready = wait_c_impl(fileno, wait, ctimeout)
-            if ready == 0:
-                continue
+            if ready == READY_NONE:
+                pyready = <PyObject *>PY_READY_NONE
             elif ready == READY_R:
                 pyready = <PyObject *>PY_READY_R
             elif ready == READY_RW:
index 0ec4179a20c9b7de9f68af2e65f3e2e91315f0df..4b0784bded5cce4661d1a40038d9a614ecde7b21 100644 (file)
@@ -4,6 +4,7 @@ Cython adapters for date/time types.
 
 # Copyright (C) 2021 The Psycopg Team
 
+from libc.stdint cimport int64_t
 from libc.string cimport memset, strchr
 from cpython cimport datetime as cdt
 from cpython.dict cimport PyDict_GetItem
@@ -391,7 +392,7 @@ cdef class DateLoader(CLoader):
         if length != 10:
             self._error_date(data, "unexpected length")
 
-        cdef int vals[3]
+        cdef int64_t vals[3]
         memset(vals, 0, sizeof(vals))
 
         cdef const char *ptr
@@ -437,7 +438,7 @@ cdef class TimeLoader(CLoader):
 
     cdef object cload(self, const char *data, size_t length):
 
-        cdef int vals[3]
+        cdef int64_t vals[3]
         memset(vals, 0, sizeof(vals))
         cdef const char *ptr
         cdef const char *end = data + length
@@ -494,7 +495,7 @@ cdef class TimetzLoader(CLoader):
 
     cdef object cload(self, const char *data, size_t length):
 
-        cdef int vals[3]
+        cdef int64_t vals[3]
         memset(vals, 0, sizeof(vals))
         cdef const char *ptr
         cdef const char *end = data + length
@@ -581,7 +582,7 @@ cdef class TimestampLoader(CLoader):
         if self._order == ORDER_PGDM or self._order == ORDER_PGMD:
             return self._cload_pg(data, end)
 
-        cdef int vals[6]
+        cdef int64_t vals[6]
         memset(vals, 0, sizeof(vals))
         cdef const char *ptr
 
@@ -611,7 +612,7 @@ cdef class TimestampLoader(CLoader):
             raise _get_timestamp_load_error(self._pgconn, data, ex) from None
 
     cdef object _cload_pg(self, const char *data, const char *end):
-        cdef int vals[4]
+        cdef int64_t vals[4]
         memset(vals, 0, sizeof(vals))
         cdef const char *ptr
 
@@ -721,7 +722,7 @@ cdef class TimestamptzLoader(_BaseTimestamptzLoader):
         if end[-1] == b'C':  # ends with BC
             raise _get_timestamp_load_error(self._pgconn, data) from None
 
-        cdef int vals[6]
+        cdef int64_t vals[6]
         memset(vals, 0, sizeof(vals))
 
         # Parse the first 6 groups of digits (date and time)
@@ -862,9 +863,10 @@ cdef class IntervalLoader(CLoader):
         if self._style == INTERVALSTYLE_OTHERS:
             return self._cload_notimpl(data, length)
 
-        cdef int days = 0, secs = 0, us = 0
+        cdef int days = 0, us = 0
+        cdef int64_t secs = 0
         cdef char sign
-        cdef int val
+        cdef int64_t val
         cdef const char *ptr = data
         cdef const char *sep
         cdef const char *end = ptr + length
@@ -908,7 +910,7 @@ cdef class IntervalLoader(CLoader):
                 break
 
         # Parse the time part. An eventual sign was already consumed in the loop
-        cdef int vals[3]
+        cdef int64_t vals[3]
         memset(vals, 0, sizeof(vals))
         if ptr != NULL:
             ptr = _parse_date_values(ptr, end, vals, ARRAYSIZE(vals))
@@ -918,6 +920,10 @@ cdef class IntervalLoader(CLoader):
 
             secs = vals[2] + 60 * (vals[1] + 60 * vals[0])
 
+            if secs > 86_400:
+                days += secs // 86_400
+                secs %= 86_400
+
             if ptr[0] == b'.':
                 ptr = _parse_micros(ptr + 1, &us)
 
@@ -966,11 +972,11 @@ cdef class IntervalBinaryLoader(CLoader):
         # Work only with positive values as the cdivision behaves differently
         # with negative values, and cdivision=False adds overhead.
         cdef int64_t aval = val if val >= 0 else -val
-        cdef int us, ussecs, usdays
+        cdef int64_t us, ussecs, usdays
 
-        # Group the micros in biggers stuff or timedelta_new might overflow
+        # Group the micros in bigger stuff or timedelta_new might overflow
         with cython.cdivision(True):
-            ussecs = <int>(aval // 1_000_000)
+            ussecs = <int64_t>(aval // 1_000_000)
             us = aval % 1_000_000
 
             usdays = ussecs // 86_400
@@ -988,7 +994,7 @@ cdef class IntervalBinaryLoader(CLoader):
 
 
 cdef const char *_parse_date_values(
-    const char *ptr, const char *end, int *vals, int nvals
+    const char *ptr, const char *end, int64_t *vals, int nvals
 ):
     """
     Parse *nvals* numeric values separated by non-numeric chars.
@@ -1046,7 +1052,7 @@ cdef int _parse_timezone_to_seconds(const char **bufptr, const char *end):
     cdef char sgn = ptr[0]
 
     # Parse at most three groups of digits
-    cdef int vals[3]
+    cdef int64_t vals[3]
     memset(vals, 0, sizeof(vals))
 
     ptr = _parse_date_values(ptr + 1, end, vals, ARRAYSIZE(vals))
index 2ae629c2d4d3d195def647f96d7bce3cdbab83b6..40a061b1e954b4de28869853a004d04fcd5459b7 100644 (file)
@@ -1,3 +1,3 @@
 [flake8]
 max-line-length = 88
-ignore = W503, E203
+ignore = W503, E203, E704
index 4e4fa20b0a42320e7a18e3441a1741f54d7d7f7f..d58548515bc2a649eb148b4f04fa5528800ce9ef 100644 (file)
@@ -17,9 +17,7 @@ import logging
 import threading
 from typing import Any, Callable, Coroutine, TYPE_CHECKING
 
-from typing_extensions import TypeAlias
-
-from ._compat import TypeVar
+from ._compat import TypeAlias, TypeVar
 
 logger = logging.getLogger("psycopg.pool")
 T = TypeVar("T")
index 5917ff31b71e236e45153cf818f7650414e96af2..3fc645cbeb5c0175008145f43bb899a5dd378d69 100644 (file)
@@ -14,6 +14,11 @@ if sys.version_info >= (3, 9):
 else:
     from typing import Counter, Deque
 
+if sys.version_info >= (3, 10):
+    from typing import TypeAlias
+else:
+    from typing_extensions import TypeAlias
+
 if sys.version_info >= (3, 11):
     from typing import Self
 else:
@@ -28,6 +33,7 @@ __all__ = [
     "Counter",
     "Deque",
     "Self",
+    "TypeAlias",
     "TypeVar",
 ]
 
index 6cc85a2c569e6ffd7c9abc637daf647da344cd15..07209a64c2854b7e968a75cef78752dfecfee58b 100644 (file)
@@ -8,9 +8,7 @@ from __future__ import annotations
 
 from typing import Any, Awaitable, Callable, Union, TYPE_CHECKING
 
-from typing_extensions import TypeAlias
-
-from ._compat import TypeVar
+from ._compat import TypeAlias, TypeVar
 
 if TYPE_CHECKING:
     from .pool import ConnectionPool
index a5408d08e5cf2a663f493bf1c32b353e786d0ce3..0be5d614820e927a59de563db0d86716fc2cc279 100644 (file)
@@ -26,6 +26,7 @@ logger = logging.getLogger("psycopg.pool")
 
 
 class NullConnectionPool(_BaseNullConnectionPool, ConnectionPool[CT]):
+
     def __init__(
         self,
         conninfo: str = "",
@@ -47,6 +48,7 @@ class NullConnectionPool(_BaseNullConnectionPool, ConnectionPool[CT]):
         reconnect_failed: Optional[ConnectFailedCB] = None,
         num_workers: int = 3,
     ):  # Note: min_size default value changed to 0.
+
         super().__init__(
             conninfo,
             open=open,
index 7ab999b1eaa6c1b764c1ee7fa0f0a9179c3fde5e..a021e14a0ba94d089877a3417fc6e34b932461ed 100644 (file)
@@ -913,8 +913,7 @@ class MaintenanceTask(ABC):
         pool.run_task(self)
 
     @abstractmethod
-    def _run(self, pool: ConnectionPool[Any]) -> None:
-        ...
+    def _run(self, pool: ConnectionPool[Any]) -> None: ...
 
 
 class StopWorker(MaintenanceTask):
@@ -925,6 +924,7 @@ class StopWorker(MaintenanceTask):
 
 
 class AddConnection(MaintenanceTask):
+
     def __init__(
         self,
         pool: ConnectionPool[Any],
index fd8f63782bb6206a1c2171943a8703623bfa9cc8..d0770dd778681675e697e0ebff41a1b350764453 100644 (file)
@@ -962,8 +962,7 @@ class MaintenanceTask(ABC):
         pool.run_task(self)
 
     @abstractmethod
-    async def _run(self, pool: AsyncConnectionPool[Any]) -> None:
-        ...
+    async def _run(self, pool: AsyncConnectionPool[Any]) -> None: ...
 
 
 class StopWorker(MaintenanceTask):
index 58cadc36df1f499fcfadef0ff36dcf38f5750071..40954b99605dc42388525e83f329cd50fa3ca59f 100644 (file)
@@ -27,6 +27,7 @@ logger = logging.getLogger(__name__)
 
 
 class Scheduler:
+
     def __init__(self) -> None:
         self._queue: List[Task] = []
         self._lock = Lock()
index 98b03426e0c30c36e5c08c2b855d4e33e002f5e6..cc1273f78eb90df9b1df3f6ba2829153e4c359fa 100644 (file)
@@ -73,9 +73,9 @@ def pytest_sessionstart(session):
 
 asyncio_options: Dict[str, Any] = {}
 if sys.platform == "win32":
-    asyncio_options[
-        "loop_factory"
-    ] = asyncio.WindowsSelectorEventLoopPolicy().new_event_loop
+    asyncio_options["loop_factory"] = (
+        asyncio.WindowsSelectorEventLoopPolicy().new_event_loop
+    )
 
 
 @pytest.fixture(
index a676a993f593fb40fe94032858a1799659834caa..5c106971cdb2212607d92b87c162ac1ec43d7ad4 100644 (file)
@@ -17,7 +17,7 @@ pytest-cov == 3.0.0
 pytest-randomly == 3.5.0
 
 # From the 'dev' extra
-black == 23.1.0
+black == 24.1.0
 dnspython == 2.1.0
 flake8 == 4.0.0
 types-setuptools == 57.4.0
index c873a4e66b63b385e138000a59609207e07a852b..ea98800f525f5de4821b83be53b1fab5e9c43c2b 100644 (file)
@@ -13,6 +13,8 @@
     -- Ian Bicking
 '''
 
+from __future__ import annotations
+
 __rcs_id__  = '$Id: dbapi20.py,v 1.11 2005/01/02 02:41:01 zenzen Exp $'
 __version__ = '$Revision: 1.12 $'[11:-2]
 __author__ = 'Stuart Bishop <stuart@stuartbishop.net>'
@@ -101,7 +103,7 @@ class DatabaseAPI20Test(unittest.TestCase):
     # method is to be found
     driver: Any = None
     connect_args = () # List of arguments to pass to connect
-    connect_kw_args: Dict[str, Any] = {} # Keyword arguments for connect
+    connect_kw_args: Dict[Any, Any] = {} # Keyword arguments for connect
     table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables
 
     ddl1 = 'create table %sbooze (name varchar(20))' % table_prefix
index 9abeda79e9b5c13d46a4a4dfe2e3c03c4cd2e98a..37ee7ac3252d117aac3844c4318483207b40cd88 100644 (file)
@@ -119,7 +119,7 @@ def dsn_env(dsn):
             continue
         args[opt.keyword.decode()] = os.environ[opt.envvar.decode()]
 
-    return make_conninfo(**args)
+    return make_conninfo("", **args)
 
 
 @pytest.fixture(scope="session")
index 1d566b5e5dc178aefc104415e0e8932b90f27779..6a74877866887c3f5cbb167172a57072ddedc22c 100644 (file)
@@ -58,7 +58,8 @@ class Proxy:
         cdict = conninfo.conninfo_to_dict(server_dsn)
 
         # Get server params
-        host = cdict.get("host") or os.environ.get("PGHOST")
+        host = cdict.get("host") or os.environ.get("PGHOST", "")
+        assert isinstance(host, str)
         self.server_host = host if host and not host.startswith("/") else "localhost"
         self.server_port = cdict.get("port") or os.environ.get("PGPORT", "5432")
 
@@ -70,7 +71,7 @@ class Proxy:
         cdict["host"] = self.client_host
         cdict["port"] = self.client_port
         cdict["sslmode"] = "disable"  # not supported by the proxy
-        self.client_dsn = conninfo.make_conninfo(**cdict)
+        self.client_dsn = conninfo.make_conninfo("", **cdict)
 
         # The running proxy process
         self.proc = None
index 80e0c626d178d2904ce89ffcad4bb9b042dc2ca1..bc054fdb7230ce02dc7a291ede40e983211a801a 100644 (file)
@@ -24,7 +24,6 @@ def global_adapters():
 
 
 @pytest.fixture
-@pytest.mark.crdb_skip("2-phase commit")
 def tpc(svcconn):
     tpc = Tpc(svcconn)
     tpc.check_tpc()
index e2aeb7cc39164dd5f239adca2766d902816d07fb..ff8fafd885d79f5350a9eca136c58f354b0b6c62 100644 (file)
@@ -42,10 +42,11 @@ def test_bad_size(dsn, min_size, max_size):
 
 
 class MyRow(Dict[str, Any]):
-    ...
+    pass
 
 
 def test_generic_connection_type(dsn):
+
     def configure(conn: psycopg.Connection[Any]) -> None:
         set_autocommit(conn, True)
 
@@ -78,10 +79,12 @@ def test_generic_connection_type(dsn):
 
 
 def test_non_generic_connection_type(dsn):
+
     def configure(conn: psycopg.Connection[Any]) -> None:
         set_autocommit(conn, True)
 
     class MyConnection(psycopg.Connection[MyRow]):
+
         def __init__(self, *args: Any, **kwargs: Any):
             kwargs["row_factory"] = class_row(MyRow)
             super().__init__(*args, **kwargs)
@@ -638,6 +641,7 @@ def test_uniform_use(dsn):
 @pytest.mark.slow
 @pytest.mark.timing
 def test_resize(dsn):
+
     def sampler():
         sleep(0.05)  # ensure sampling happens after shrink check
         while True:
index 0bc84f8bc6428cc41f13590b121ef8c44ae7e7e2..6a699e40b489f815c31464f12f13921cbb438dc5 100644 (file)
@@ -42,7 +42,7 @@ async def test_bad_size(dsn, min_size, max_size):
 
 
 class MyRow(Dict[str, Any]):
-    ...
+    pass
 
 
 async def test_generic_connection_type(dsn):
index ddf78a693638723d0a229380e92de76af7fd60d7..a8815da200508d40d23697b46912d750f8df27eb 100644 (file)
@@ -35,6 +35,7 @@ def test_defaults(pool_cls, dsn):
 
 
 def test_connection_class(pool_cls, dsn):
+
     class MyConn(psycopg.Connection[Any]):
         pass
 
@@ -158,6 +159,7 @@ def test_configure_broken(pool_cls, dsn, caplog):
 @pytest.mark.timing
 @pytest.mark.crdb_skip("backend pid")
 def test_queue(pool_cls, dsn):
+
     def worker(n):
         t0 = time()
         with p.connection() as conn:
@@ -182,6 +184,7 @@ def test_queue(pool_cls, dsn):
 
 @pytest.mark.slow
 def test_queue_size(pool_cls, dsn):
+
     def worker(t, ev=None):
         try:
             with p.connection():
@@ -217,6 +220,7 @@ def test_queue_size(pool_cls, dsn):
 @pytest.mark.timing
 @pytest.mark.crdb_skip("backend pid")
 def test_queue_timeout(pool_cls, dsn):
+
     def worker(n):
         t0 = time()
         try:
@@ -246,6 +250,7 @@ def test_queue_timeout(pool_cls, dsn):
 @pytest.mark.slow
 @pytest.mark.timing
 def test_dead_client(pool_cls, dsn):
+
     def worker(i, timeout):
         try:
             with p.connection(timeout=timeout) as conn:
@@ -273,6 +278,7 @@ def test_dead_client(pool_cls, dsn):
 @pytest.mark.timing
 @pytest.mark.crdb_skip("backend pid")
 def test_queue_timeout_override(pool_cls, dsn):
+
     def worker(n):
         t0 = time()
         timeout = 0.25 if n == 3 else None
@@ -382,6 +388,7 @@ def test_close_connection_on_pool_close(pool_cls, dsn):
 
 
 def test_closed_queue(pool_cls, dsn):
+
     def w1():
         with p.connection() as conn:
             e1.set()  # Tell w0 that w1 got a connection
@@ -493,6 +500,7 @@ def test_jitter(pool_cls):
 @pytest.mark.slow
 @pytest.mark.timing
 def test_stats_measures(pool_cls, dsn):
+
     def worker(n):
         with p.connection() as conn:
             conn.execute("select pg_sleep(0.2)")
@@ -532,6 +540,7 @@ def test_stats_measures(pool_cls, dsn):
 @pytest.mark.slow
 @pytest.mark.timing
 def test_stats_usage(pool_cls, dsn):
+
     def worker(n):
         try:
             with p.connection(timeout=0.3) as conn:
@@ -613,6 +622,7 @@ def test_check_init(pool_cls, dsn):
 
 @pytest.mark.slow
 def test_check_timeout(pool_cls, dsn):
+
     def check(conn):
         raise Exception()
 
index fbe698df68658a6cf77e65e2fe0cce8dae3f12ff..c54014572cda00ba6f33c4ef88fc05e9120855e4 100644 (file)
@@ -40,10 +40,11 @@ def test_bad_size(dsn, min_size, max_size):
 
 
 class MyRow(Dict[str, Any]):
-    ...
+    pass
 
 
 def test_generic_connection_type(dsn):
+
     def configure(conn: psycopg.Connection[Any]) -> None:
         set_autocommit(conn, True)
 
@@ -76,10 +77,12 @@ def test_generic_connection_type(dsn):
 
 
 def test_non_generic_connection_type(dsn):
+
     def configure(conn: psycopg.Connection[Any]) -> None:
         set_autocommit(conn, True)
 
     class MyConnection(psycopg.Connection[MyRow]):
+
         def __init__(self, *args: Any, **kwargs: Any):
             kwargs["row_factory"] = class_row(MyRow)
             super().__init__(*args, **kwargs)
index 09c0e21509050d9bf71c3a6b7022d1bc75e8960f..b610045ccb8bfd1357c35b725936a56562f14ceb 100644 (file)
@@ -40,7 +40,7 @@ async def test_bad_size(dsn, min_size, max_size):
 
 
 class MyRow(Dict[str, Any]):
-    ...
+    pass
 
 
 async def test_generic_connection_type(dsn):
index a49f11685b069c33d9f7958e3513308c312d75e0..0fb65ad90ea633c4b969d2a16478916e1dee29ac 100644 (file)
@@ -1,6 +1,7 @@
 """
 A quick and rough performance comparison of text vs. binary Decimal adaptation
 """
+
 from random import randrange
 from decimal import Decimal
 import psycopg
index ec952293a18d79209ab7cce649935f9d596201fc..74cc04d1f8635f6b3a1ab6b9eccafe00757edb7f 100644 (file)
@@ -7,6 +7,7 @@ We do not fetch results explicitly (using cursor.fetch*()), this is
 handled by execute() calls when pgconn socket is read-ready, which
 happens when the output buffer is full.
 """
+
 import argparse
 import asyncio
 import logging
index 150d7747793f2b82d4355ae9f28c450de92c38a6..a14fb93e6569bbc5a53dc819e897be2a539d287d 100644 (file)
@@ -6,7 +6,7 @@ import threading
 import subprocess as sp
 from asyncio import create_task
 from asyncio.queues import Queue
-from typing import List, Tuple
+from typing import List
 
 import pytest
 
@@ -58,50 +58,6 @@ async def test_concurrent_execution(aconn_cls, dsn):
     assert time.time() - t0 < 0.8, "something broken in concurrency"
 
 
-@pytest.mark.slow
-@pytest.mark.timing
-@pytest.mark.crdb_skip("notify")
-async def test_notifies(aconn_cls, aconn, dsn):
-    nconn = await aconn_cls.connect(dsn, autocommit=True)
-    npid = nconn.pgconn.backend_pid
-
-    async def notifier():
-        cur = nconn.cursor()
-        await asyncio.sleep(0.25)
-        await cur.execute("notify foo, '1'")
-        await asyncio.sleep(0.25)
-        await cur.execute("notify foo, '2'")
-        await nconn.close()
-
-    async def receiver():
-        await aconn.set_autocommit(True)
-        cur = aconn.cursor()
-        await cur.execute("listen foo")
-        gen = aconn.notifies()
-        async for n in gen:
-            ns.append((n, time.time()))
-            if len(ns) >= 2:
-                await gen.aclose()
-
-    ns: List[Tuple[psycopg.Notify, float]] = []
-    t0 = time.time()
-    workers = [notifier(), receiver()]
-    await asyncio.gather(*workers)
-    assert len(ns) == 2
-
-    n, t1 = ns[0]
-    assert n.pid == npid
-    assert n.channel == "foo"
-    assert n.payload == "1"
-    assert t1 - t0 == pytest.approx(0.25, abs=0.05)
-
-    n, t1 = ns[1]
-    assert n.pid == npid
-    assert n.channel == "foo"
-    assert n.payload == "2"
-    assert t1 - t0 == pytest.approx(0.5, abs=0.05)
-
-
 async def canceller(aconn, errors):
     try:
         await asyncio.sleep(0.5)
index 8456dba45aa8237af43990f564b212b0734f5b6e..f17f8d2a06f85f3097dad4e685e740b52b9ba732 100644 (file)
@@ -9,7 +9,7 @@ import weakref
 from typing import Any, List
 
 import psycopg
-from psycopg import Notify, pq, errors as e
+from psycopg import pq, errors as e
 from psycopg.rows import tuple_row
 from psycopg.conninfo import conninfo_to_dict, timeout_from_conninfo
 
@@ -34,6 +34,7 @@ def test_connect_bad(conn_cls):
 
 
 def test_connect_str_subclass(conn_cls, dsn):
+
     class MyString(str):
         pass
 
@@ -467,6 +468,7 @@ def test_connect_args(
     ],
 )
 def test_connect_badargs(conn_cls, monkeypatch, pgconn, args, kwargs, exctype):
+
     def fake_connect(conninfo):
         return pgconn
         yield
@@ -526,47 +528,6 @@ def test_notice_handlers(conn, caplog):
         conn.remove_notice_handler(cb1)
 
 
-@pytest.mark.crdb_skip("notify")
-def test_notify_handlers(conn):
-    nots1 = []
-    nots2 = []
-
-    def cb1(n):
-        nots1.append(n)
-
-    conn.add_notify_handler(cb1)
-    conn.add_notify_handler(lambda n: nots2.append(n))
-
-    conn.set_autocommit(True)
-    cur = conn.cursor()
-    cur.execute("listen foo")
-    cur.execute("notify foo, 'n1'")
-
-    assert len(nots1) == 1
-    n = nots1[0]
-    assert n.channel == "foo"
-    assert n.payload == "n1"
-    assert n.pid == conn.pgconn.backend_pid
-
-    assert len(nots2) == 1
-    assert nots2[0] == nots1[0]
-
-    conn.remove_notify_handler(cb1)
-    cur.execute("notify foo, 'n2'")
-
-    assert len(nots1) == 1
-    assert len(nots2) == 2
-    n = nots2[1]
-    assert isinstance(n, Notify)
-    assert n.channel == "foo"
-    assert n.payload == "n2"
-    assert n.pid == conn.pgconn.backend_pid
-    assert hash(n)
-
-    with pytest.raises(ValueError):
-        conn.remove_notify_handler(cb1)
-
-
 def test_execute(conn):
     cur = conn.execute("select %s, %s", [10, 20])
     assert cur.fetchone() == (10, 20)
@@ -642,6 +603,7 @@ def test_cursor_factory(conn):
 
 
 def test_cursor_factory_connect(conn_cls, dsn):
+
     class MyCursor(psycopg.Cursor[psycopg.rows.Row]):
         pass
 
index 2e950aff4f2016fd32e138e4aaa1595bc8576346..d7aa7ca8b41898e927a5e1c9dbe15fc30420d75c 100644 (file)
@@ -6,7 +6,7 @@ import weakref
 from typing import Any, List
 
 import psycopg
-from psycopg import Notify, pq, errors as e
+from psycopg import pq, errors as e
 from psycopg.rows import tuple_row
 from psycopg.conninfo import conninfo_to_dict, timeout_from_conninfo
 
@@ -526,47 +526,6 @@ async def test_notice_handlers(aconn, caplog):
         aconn.remove_notice_handler(cb1)
 
 
-@pytest.mark.crdb_skip("notify")
-async def test_notify_handlers(aconn):
-    nots1 = []
-    nots2 = []
-
-    def cb1(n):
-        nots1.append(n)
-
-    aconn.add_notify_handler(cb1)
-    aconn.add_notify_handler(lambda n: nots2.append(n))
-
-    await aconn.set_autocommit(True)
-    cur = aconn.cursor()
-    await cur.execute("listen foo")
-    await cur.execute("notify foo, 'n1'")
-
-    assert len(nots1) == 1
-    n = nots1[0]
-    assert n.channel == "foo"
-    assert n.payload == "n1"
-    assert n.pid == aconn.pgconn.backend_pid
-
-    assert len(nots2) == 1
-    assert nots2[0] == nots1[0]
-
-    aconn.remove_notify_handler(cb1)
-    await cur.execute("notify foo, 'n2'")
-
-    assert len(nots1) == 1
-    assert len(nots2) == 2
-    n = nots2[1]
-    assert isinstance(n, Notify)
-    assert n.channel == "foo"
-    assert n.payload == "n2"
-    assert n.pid == aconn.pgconn.backend_pid
-    assert hash(n)
-
-    with pytest.raises(ValueError):
-        aconn.remove_notify_handler(cb1)
-
-
 async def test_execute(aconn):
     cur = await aconn.execute("select %s, %s", [10, 20])
     assert await cur.fetchone() == (10, 20)
@@ -637,7 +596,7 @@ async def test_cursor_factory(aconn):
     async with aconn.cursor() as cur:
         assert isinstance(cur, MyCursor)
 
-    async with (await aconn.execute("select 1")) as cur:
+    async with await aconn.execute("select 1") as cur:
         assert isinstance(cur, MyCursor)
 
 
index c2855760ac88ec7f1603aa9b91ba6255909a2fa0..0f4ba1b118baafc3c125641f4d5410e19939db60 100644 (file)
@@ -165,14 +165,14 @@ def test_conninfo_random_multi_host():
 
 def test_conninfo_random_multi_ips(fake_resolve):
     args = {"host": "alot.com"}
-    hostaddrs = [att["hostaddr"] for att in conninfo_attempts(args)]
+    hostaddrs = [str(att["hostaddr"]) for att in conninfo_attempts(args)]
     assert len(hostaddrs) == 20
     assert hostaddrs == sorted(hostaddrs)
 
     args["load_balance_hosts"] = "disable"
-    hostaddrs = [att["hostaddr"] for att in conninfo_attempts(args)]
+    hostaddrs = [str(att["hostaddr"]) for att in conninfo_attempts(args)]
     assert hostaddrs == sorted(hostaddrs)
 
     args["load_balance_hosts"] = "random"
-    hostaddrs = [att["hostaddr"] for att in conninfo_attempts(args)]
+    hostaddrs = [str(att["hostaddr"]) for att in conninfo_attempts(args)]
     assert hostaddrs != sorted(hostaddrs)
index bf6da880f4d7a2ad8b0642dd182d9f38caffa0b8..aada9f1e00f77b9c96ad19c685f17d6b41c8d544 100644 (file)
@@ -172,14 +172,14 @@ async def test_conninfo_random_multi_host():
 
 async def test_conninfo_random_multi_ips(fake_resolve):
     args = {"host": "alot.com"}
-    hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)]
+    hostaddrs = [str(att["hostaddr"]) for att in await conninfo_attempts_async(args)]
     assert len(hostaddrs) == 20
     assert hostaddrs == sorted(hostaddrs)
 
     args["load_balance_hosts"] = "disable"
-    hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)]
+    hostaddrs = [str(att["hostaddr"]) for att in await conninfo_attempts_async(args)]
     assert hostaddrs == sorted(hostaddrs)
 
     args["load_balance_hosts"] = "random"
-    hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)]
+    hostaddrs = [str(att["hostaddr"]) for att in await conninfo_attempts_async(args)]
     assert hostaddrs != sorted(hostaddrs)
index fda854e60bea1adeee1eea27bb3e7b6f0888f5ef..55f9e3b77499bf44555ce30caae87ae41ec1c3d7 100644 (file)
@@ -306,6 +306,7 @@ def test_subclass_adapter(conn, format):
         BaseDumper = StrBinaryDumper  # type: ignore
 
     class MyStrDumper(BaseDumper):
+
         def dump(self, obj):
             return super().dump(obj) * 2
 
@@ -641,6 +642,7 @@ def test_worker_life(conn, format, buffer):
 
 
 def test_worker_error_propagated(conn, monkeypatch):
+
     def copy_to_broken(pgconn, buffer):
         raise ZeroDivisionError
         yield
@@ -803,6 +805,7 @@ def test_copy_table_across(conn_cls, dsn, faker, mode):
 
 
 class DataGenerator:
+
     def __init__(self, conn, nrecs, srec, offset=0, block_size=8192):
         self.conn = conn
         self.nrecs = nrecs
index 159e67cb1e92de995418ffe24bb6622c4af0c42f..d5b0d15136e78229a0e3beca406ceb8a1d15cb23 100644 (file)
@@ -576,6 +576,7 @@ def test_row_factory_none(conn):
 
 
 def test_bad_row_factory(conn):
+
     def broken_factory(cur):
         1 / 0
 
@@ -584,6 +585,7 @@ def test_bad_row_factory(conn):
         cur.execute("select 1")
 
     def broken_maker(cur):
+
         def make_row(seq):
             1 / 0
 
index ecb8da987182ca7d626f3cb8c0a9c4608f3fc3e6..2df55e3e08eede59d72df371e28a4e3cd741ffa2 100644 (file)
@@ -25,7 +25,7 @@ def test_connect_operationalerror_pgconn(generators, dsn, monkeypatch):
         except KeyError:
             info = conninfo_to_dict(dsn)
             del info["password"]  # should not raise per check above.
-            dsn = make_conninfo(**info)
+            dsn = make_conninfo("", **info)
 
         gen = generators.connect(dsn)
         with pytest.raises(
diff --git a/tests/test_notify.py b/tests/test_notify.py
new file mode 100644 (file)
index 0000000..c67b331
--- /dev/null
@@ -0,0 +1,195 @@
+# WARNING: this file is auto-generated by 'async_to_sync.py'
+# from the original file 'test_notify_async.py'
+# DO NOT CHANGE! Change the original file instead.
+from __future__ import annotations
+
+from time import time
+
+import pytest
+from psycopg import Notify
+
+from .acompat import sleep, gather, spawn
+
+pytestmark = pytest.mark.crdb_skip("notify")
+
+
+def test_notify_handlers(conn):
+    nots1 = []
+    nots2 = []
+
+    def cb1(n):
+        nots1.append(n)
+
+    conn.add_notify_handler(cb1)
+    conn.add_notify_handler(lambda n: nots2.append(n))
+
+    conn.set_autocommit(True)
+    conn.execute("listen foo")
+    conn.execute("notify foo, 'n1'")
+
+    assert len(nots1) == 1
+    n = nots1[0]
+    assert n.channel == "foo"
+    assert n.payload == "n1"
+    assert n.pid == conn.pgconn.backend_pid
+
+    assert len(nots2) == 1
+    assert nots2[0] == nots1[0]
+
+    conn.remove_notify_handler(cb1)
+    conn.execute("notify foo, 'n2'")
+
+    assert len(nots1) == 1
+    assert len(nots2) == 2
+    n = nots2[1]
+    assert isinstance(n, Notify)
+    assert n.channel == "foo"
+    assert n.payload == "n2"
+    assert n.pid == conn.pgconn.backend_pid
+    assert hash(n)
+
+    with pytest.raises(ValueError):
+        conn.remove_notify_handler(cb1)
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_notify(conn_cls, conn, dsn):
+    npid = None
+
+    def notifier():
+        with conn_cls.connect(dsn, autocommit=True) as nconn:
+            nonlocal npid
+            npid = nconn.pgconn.backend_pid
+
+            sleep(0.25)
+            nconn.execute("notify foo, '1'")
+            sleep(0.25)
+            nconn.execute("notify foo, '2'")
+
+    def receiver():
+        conn.set_autocommit(True)
+        cur = conn.cursor()
+        cur.execute("listen foo")
+        gen = conn.notifies()
+        for n in gen:
+            ns.append((n, time()))
+            if len(ns) >= 2:
+                gen.close()
+
+    ns: list[tuple[Notify, float]] = []
+    t0 = time()
+    workers = [spawn(notifier), spawn(receiver)]
+    gather(*workers)
+    assert len(ns) == 2
+
+    n, t1 = ns[0]
+    assert n.pid == npid
+    assert n.channel == "foo"
+    assert n.payload == "1"
+    assert t1 - t0 == pytest.approx(0.25, abs=0.05)
+
+    n, t1 = ns[1]
+    assert n.pid == npid
+    assert n.channel == "foo"
+    assert n.payload == "2"
+    assert t1 - t0 == pytest.approx(0.5, abs=0.05)
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_no_notify_timeout(conn):
+    conn.set_autocommit(True)
+    t0 = time()
+    for n in conn.notifies(timeout=0.5):
+        assert False
+    dt = time() - t0
+    assert 0.5 <= dt < 0.75
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_notify_timeout(conn_cls, conn, dsn):
+    conn.set_autocommit(True)
+    conn.execute("listen foo")
+
+    def notifier():
+        with conn_cls.connect(dsn, autocommit=True) as nconn:
+            sleep(0.25)
+            nconn.execute("notify foo, '1'")
+
+    worker = spawn(notifier)
+    try:
+        times = [time()]
+        for n in conn.notifies(timeout=0.5):
+            times.append(time())
+        times.append(time())
+    finally:
+        gather(worker)
+
+    assert len(times) == 3
+    assert times[1] - times[0] == pytest.approx(0.25, 0.1)
+    assert times[2] - times[1] == pytest.approx(0.25, 0.1)
+
+
+@pytest.mark.slow
+def test_notify_timeout_0(conn_cls, conn, dsn):
+    conn.set_autocommit(True)
+    conn.execute("listen foo")
+
+    ns = list(conn.notifies(timeout=0))
+    assert not ns
+
+    with conn_cls.connect(dsn, autocommit=True) as nconn:
+        nconn.execute("notify foo, '1'")
+        sleep(0.1)
+
+    ns = list(conn.notifies(timeout=0))
+    assert len(ns) == 1
+
+
+@pytest.mark.slow
+def test_stop_after(conn_cls, conn, dsn):
+    conn.set_autocommit(True)
+    conn.execute("listen foo")
+
+    def notifier():
+        with conn_cls.connect(dsn, autocommit=True) as nconn:
+            nconn.execute("notify foo, '1'")
+            sleep(0.1)
+            nconn.execute("notify foo, '2'")
+            sleep(0.1)
+            nconn.execute("notify foo, '3'")
+
+    worker = spawn(notifier)
+    try:
+        ns = list(conn.notifies(timeout=1.0, stop_after=2))
+        assert len(ns) == 2
+        assert ns[0].payload == "1"
+        assert ns[1].payload == "2"
+    finally:
+        gather(worker)
+
+    ns = list(conn.notifies(timeout=0.0))
+    assert len(ns) == 1
+    assert ns[0].payload == "3"
+
+
+def test_stop_after_batch(conn_cls, conn, dsn):
+    conn.set_autocommit(True)
+    conn.execute("listen foo")
+
+    def notifier():
+        with conn_cls.connect(dsn, autocommit=True) as nconn:
+            with nconn.transaction():
+                nconn.execute("notify foo, '1'")
+                nconn.execute("notify foo, '2'")
+
+    worker = spawn(notifier)
+    try:
+        ns = list(conn.notifies(timeout=1.0, stop_after=1))
+        assert len(ns) == 2
+        assert ns[0].payload == "1"
+        assert ns[1].payload == "2"
+    finally:
+        gather(worker)
diff --git a/tests/test_notify_async.py b/tests/test_notify_async.py
new file mode 100644 (file)
index 0000000..f4f0901
--- /dev/null
@@ -0,0 +1,192 @@
+from __future__ import annotations
+
+from time import time
+
+import pytest
+from psycopg import Notify
+
+from .acompat import alist, asleep, gather, spawn
+
+pytestmark = pytest.mark.crdb_skip("notify")
+
+
+async def test_notify_handlers(aconn):
+    nots1 = []
+    nots2 = []
+
+    def cb1(n):
+        nots1.append(n)
+
+    aconn.add_notify_handler(cb1)
+    aconn.add_notify_handler(lambda n: nots2.append(n))
+
+    await aconn.set_autocommit(True)
+    await aconn.execute("listen foo")
+    await aconn.execute("notify foo, 'n1'")
+
+    assert len(nots1) == 1
+    n = nots1[0]
+    assert n.channel == "foo"
+    assert n.payload == "n1"
+    assert n.pid == aconn.pgconn.backend_pid
+
+    assert len(nots2) == 1
+    assert nots2[0] == nots1[0]
+
+    aconn.remove_notify_handler(cb1)
+    await aconn.execute("notify foo, 'n2'")
+
+    assert len(nots1) == 1
+    assert len(nots2) == 2
+    n = nots2[1]
+    assert isinstance(n, Notify)
+    assert n.channel == "foo"
+    assert n.payload == "n2"
+    assert n.pid == aconn.pgconn.backend_pid
+    assert hash(n)
+
+    with pytest.raises(ValueError):
+        aconn.remove_notify_handler(cb1)
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_notify(aconn_cls, aconn, dsn):
+    npid = None
+
+    async def notifier():
+        async with await aconn_cls.connect(dsn, autocommit=True) as nconn:
+            nonlocal npid
+            npid = nconn.pgconn.backend_pid
+
+            await asleep(0.25)
+            await nconn.execute("notify foo, '1'")
+            await asleep(0.25)
+            await nconn.execute("notify foo, '2'")
+
+    async def receiver():
+        await aconn.set_autocommit(True)
+        cur = aconn.cursor()
+        await cur.execute("listen foo")
+        gen = aconn.notifies()
+        async for n in gen:
+            ns.append((n, time()))
+            if len(ns) >= 2:
+                await gen.aclose()
+
+    ns: list[tuple[Notify, float]] = []
+    t0 = time()
+    workers = [spawn(notifier), spawn(receiver)]
+    await gather(*workers)
+    assert len(ns) == 2
+
+    n, t1 = ns[0]
+    assert n.pid == npid
+    assert n.channel == "foo"
+    assert n.payload == "1"
+    assert t1 - t0 == pytest.approx(0.25, abs=0.05)
+
+    n, t1 = ns[1]
+    assert n.pid == npid
+    assert n.channel == "foo"
+    assert n.payload == "2"
+    assert t1 - t0 == pytest.approx(0.5, abs=0.05)
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_no_notify_timeout(aconn):
+    await aconn.set_autocommit(True)
+    t0 = time()
+    async for n in aconn.notifies(timeout=0.5):
+        assert False
+    dt = time() - t0
+    assert 0.5 <= dt < 0.75
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_notify_timeout(aconn_cls, aconn, dsn):
+    await aconn.set_autocommit(True)
+    await aconn.execute("listen foo")
+
+    async def notifier():
+        async with await aconn_cls.connect(dsn, autocommit=True) as nconn:
+            await asleep(0.25)
+            await nconn.execute("notify foo, '1'")
+
+    worker = spawn(notifier)
+    try:
+        times = [time()]
+        async for n in aconn.notifies(timeout=0.5):
+            times.append(time())
+        times.append(time())
+    finally:
+        await gather(worker)
+
+    assert len(times) == 3
+    assert times[1] - times[0] == pytest.approx(0.25, 0.1)
+    assert times[2] - times[1] == pytest.approx(0.25, 0.1)
+
+
+@pytest.mark.slow
+async def test_notify_timeout_0(aconn_cls, aconn, dsn):
+    await aconn.set_autocommit(True)
+    await aconn.execute("listen foo")
+
+    ns = await alist(aconn.notifies(timeout=0))
+    assert not ns
+
+    async with await aconn_cls.connect(dsn, autocommit=True) as nconn:
+        await nconn.execute("notify foo, '1'")
+        await asleep(0.1)
+
+    ns = await alist(aconn.notifies(timeout=0))
+    assert len(ns) == 1
+
+
+@pytest.mark.slow
+async def test_stop_after(aconn_cls, aconn, dsn):
+    await aconn.set_autocommit(True)
+    await aconn.execute("listen foo")
+
+    async def notifier():
+        async with await aconn_cls.connect(dsn, autocommit=True) as nconn:
+            await nconn.execute("notify foo, '1'")
+            await asleep(0.1)
+            await nconn.execute("notify foo, '2'")
+            await asleep(0.1)
+            await nconn.execute("notify foo, '3'")
+
+    worker = spawn(notifier)
+    try:
+        ns = await alist(aconn.notifies(timeout=1.0, stop_after=2))
+        assert len(ns) == 2
+        assert ns[0].payload == "1"
+        assert ns[1].payload == "2"
+    finally:
+        await gather(worker)
+
+    ns = await alist(aconn.notifies(timeout=0.0))
+    assert len(ns) == 1
+    assert ns[0].payload == "3"
+
+
+async def test_stop_after_batch(aconn_cls, aconn, dsn):
+    await aconn.set_autocommit(True)
+    await aconn.execute("listen foo")
+
+    async def notifier():
+        async with await aconn_cls.connect(dsn, autocommit=True) as nconn:
+            async with nconn.transaction():
+                await nconn.execute("notify foo, '1'")
+                await nconn.execute("notify foo, '2'")
+
+    worker = spawn(notifier)
+    try:
+        ns = await alist(aconn.notifies(timeout=1.0, stop_after=1))
+        assert len(ns) == 2
+        assert ns[0].payload == "1"
+        assert ns[1].payload == "2"
+    finally:
+        await gather(worker)
index a89344974207a7177d4bfd13262f9d1bb7c8b806..00e37017e63d7457030bc5c022bddf848a6eccb8 100644 (file)
@@ -18,7 +18,7 @@ def with_dsn(request, session_dsn):
 class PsycopgTests(dbapi20.DatabaseAPI20Test):
     driver = psycopg
     # connect_args = () # set by the fixture
-    connect_kw_args: Dict[str, Any] = {}
+    connect_kw_args: Dict[Any, Any] = {}
 
     def test_nextset(self):
         # tested elsewhere
index 5165b80074b0f8f1b658e840a2ebeba5e4a48e60..93240b5eba92c6c8ab830befbf226a9cf3dd6661 100644 (file)
@@ -102,6 +102,16 @@ def test_kwargs_row(conn):
     assert p.age == 42
 
 
+def test_scalar_row(conn):
+    cur = conn.cursor(row_factory=rows.scalar_row)
+    cur.execute("select 1")
+    assert cur.fetchone() == 1
+    cur.execute("select 1, 2")
+    assert cur.fetchone() == 1
+    with pytest.raises(psycopg.ProgrammingError):
+        cur.execute("select")
+
+
 @pytest.mark.parametrize(
     "factory",
     "tuple_row dict_row namedtuple_row class_row args_row kwargs_row".split(),
index b1ec8d85df3d0841966c0b17875184e3c8f1404d..b5f1b37cad08a3008521caf7d4614519cf40682e 100644 (file)
@@ -267,35 +267,51 @@ class TestIdentifier:
         assert sql.Identifier("foo") != "foo"
         assert sql.Identifier("foo") != sql.SQL("foo")
 
-    @pytest.mark.parametrize(
-        "args, want",
-        [
-            (("foo",), '"foo"'),
-            (("foo", "bar"), '"foo"."bar"'),
-            (("fo'o", 'ba"r'), '"fo\'o"."ba""r"'),
-        ],
-    )
+    _as_string_params = [
+        (("foo",), '"foo"'),
+        (("foo", "bar"), '"foo"."bar"'),
+        (("fo'o", 'ba"r'), '"fo\'o"."ba""r"'),
+    ]
+
+    @pytest.mark.parametrize("args, want", _as_string_params)
     def test_as_string(self, conn, args, want):
         assert sql.Identifier(*args).as_string(conn) == want
 
-    @pytest.mark.parametrize(
-        "args, want, enc",
-        [
-            crdb_encoding(("foo",), '"foo"', "ascii"),
-            crdb_encoding(("foo", "bar"), '"foo"."bar"', "ascii"),
-            crdb_encoding(("fo'o", 'ba"r'), '"fo\'o"."ba""r"', "ascii"),
-            (("foo", eur), f'"foo"."{eur}"', "utf8"),
-            crdb_encoding(("foo", eur), f'"foo"."{eur}"', "latin9"),
-        ],
-    )
+    @pytest.mark.parametrize("args, want", _as_string_params)
+    def test_as_string_no_conn(self, args, want):
+        assert sql.Identifier(*args).as_string(None) == want
+        assert sql.Identifier(*args).as_string() == want
+
+    _as_bytes_params = [
+        crdb_encoding(("foo",), '"foo"', "ascii"),
+        crdb_encoding(("foo", "bar"), '"foo"."bar"', "ascii"),
+        crdb_encoding(("fo'o", 'ba"r'), '"fo\'o"."ba""r"', "ascii"),
+        (("foo", eur), f'"foo"."{eur}"', "utf8"),
+        crdb_encoding(("foo", eur), f'"foo"."{eur}"', "latin9"),
+    ]
+
+    @pytest.mark.parametrize("args, want, enc", _as_bytes_params)
     def test_as_bytes(self, conn, args, want, enc):
         want = want.encode(enc)
         conn.execute(f"set client_encoding to {py2pgenc(enc).decode()}")
         assert sql.Identifier(*args).as_bytes(conn) == want
 
+    @pytest.mark.parametrize("args, want, enc", _as_bytes_params)
+    def test_as_bytes_no_conn(self, conn, args, want, enc):
+        want = want.encode()
+        assert sql.Identifier(*args).as_bytes(None) == want
+        assert sql.Identifier(*args).as_bytes() == want
+
     def test_join(self):
         assert not hasattr(sql.Identifier("foo"), "join")
 
+    def test_escape_no_conn(self, conn):
+        conn.execute("set client_encoding to 'utf8'")
+        for c in range(1, 128):
+            s = chr(c)
+            want = sql.Identifier(s).as_bytes(conn)
+            assert want == sql.Identifier(s).as_bytes(None)
+
 
 class TestLiteral:
     def test_class(self):
@@ -312,17 +328,40 @@ class TestLiteral:
         assert repr(sql.Literal("foo")) == "Literal('foo')"
         assert str(sql.Literal("foo")) == "Literal('foo')"
 
-    def test_as_string(self, conn):
-        assert sql.Literal(None).as_string(conn) == "NULL"
-        assert no_e(sql.Literal("foo").as_string(conn)) == "'foo'"
-        assert sql.Literal(42).as_string(conn) == "42"
-        assert sql.Literal(dt.date(2017, 1, 1)).as_string(conn) == "'2017-01-01'::date"
-
-    def test_as_bytes(self, conn):
-        assert sql.Literal(None).as_bytes(conn) == b"NULL"
-        assert no_e(sql.Literal("foo").as_bytes(conn)) == b"'foo'"
-        assert sql.Literal(42).as_bytes(conn) == b"42"
-        assert sql.Literal(dt.date(2017, 1, 1)).as_bytes(conn) == b"'2017-01-01'::date"
+    _params = [
+        (None, "NULL"),
+        ("foo", "'foo'"),
+        (42, "42"),
+        (dt.date(2017, 1, 1), "'2017-01-01'::date"),
+    ]
+
+    @pytest.mark.parametrize("obj, want", _params)
+    def test_as_string(self, conn, obj, want):
+        got = sql.Literal(obj).as_string(conn)
+        if isinstance(obj, str):
+            got = no_e(got)
+        assert got == want
+
+    @pytest.mark.parametrize("obj, want", _params)
+    def test_as_bytes(self, conn, obj, want):
+        got = sql.Literal(obj).as_bytes(conn)
+        if isinstance(obj, str):
+            got = no_e(got)
+        assert got == want.encode()
+
+    @pytest.mark.parametrize("obj, want", _params)
+    def test_as_string_no_conn(self, obj, want):
+        got = sql.Literal(obj).as_string()
+        if isinstance(obj, str):
+            got = no_e(got)
+        assert got == want
+
+    @pytest.mark.parametrize("obj, want", _params)
+    def test_as_bytes_no_conn(self, obj, want):
+        got = sql.Literal(obj).as_bytes()
+        if isinstance(obj, str):
+            got = no_e(got)
+        assert got == want.encode()
 
     @pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
     def test_as_bytes_encoding(self, conn, encoding):
@@ -459,6 +498,7 @@ class TestSQL:
 
     def test_as_string(self, conn):
         assert sql.SQL("foo").as_string(conn) == "foo"
+        assert sql.SQL("foo").as_string() == "foo"
 
     @pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
     def test_as_bytes(self, conn, encoding):
@@ -467,6 +507,10 @@ class TestSQL:
 
         assert sql.SQL(eur).as_bytes(conn) == eur.encode(encoding)
 
+    def test_no_conn(self):
+        assert sql.SQL(eur).as_string() == eur
+        assert sql.SQL(eur).as_bytes() == eur.encode()
+
 
 class TestComposed:
     def test_class(self):
@@ -526,10 +570,12 @@ class TestComposed:
     def test_as_string(self, conn):
         obj = sql.Composed([sql.SQL("foo"), sql.SQL("bar")])
         assert obj.as_string(conn) == "foobar"
+        assert obj.as_string() == "foobar"
 
     def test_as_bytes(self, conn):
         obj = sql.Composed([sql.SQL("foo"), sql.SQL("bar")])
         assert obj.as_bytes(conn) == b"foobar"
+        assert obj.as_bytes() == b"foobar"
 
     @pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
     def test_as_bytes_encoding(self, conn, encoding):
@@ -570,17 +616,21 @@ class TestPlaceholder:
     def test_as_string(self, conn, format):
         ph = sql.Placeholder(format=format)
         assert ph.as_string(conn) == f"%{format.value}"
+        assert ph.as_string() == f"%{format.value}"
 
         ph = sql.Placeholder(name="foo", format=format)
         assert ph.as_string(conn) == f"%(foo){format.value}"
+        assert ph.as_string() == f"%(foo){format.value}"
 
     @pytest.mark.parametrize("format", PyFormat)
     def test_as_bytes(self, conn, format):
         ph = sql.Placeholder(format=format)
-        assert ph.as_bytes(conn) == f"%{format.value}".encode("ascii")
+        assert ph.as_bytes(conn) == f"%{format.value}".encode()
+        assert ph.as_bytes() == f"%{format.value}".encode()
 
         ph = sql.Placeholder(name="foo", format=format)
-        assert ph.as_bytes(conn) == f"%(foo){format.value}".encode("ascii")
+        assert ph.as_bytes(conn) == f"%(foo){format.value}".encode()
+        assert ph.as_bytes() == f"%(foo){format.value}".encode()
 
 
 class TestValues:
@@ -595,7 +645,7 @@ class TestValues:
 
 def no_e(s):
     """Drop an eventual E from E'' quotes"""
-    if isinstance(s, memoryview):
+    if isinstance(s, (memoryview, bytearray)):
         s = bytes(s)
 
     if isinstance(s, str):
index 41023ccc87f92af303bba9048172827d917bc5ef..864f8a7868e77b3616cd13460a19138b87d1ec69 100644 (file)
@@ -22,6 +22,7 @@ def test_tpc_disabled(conn, pipeline):
 
 
 class TestTPC:
+
     def test_tpc_commit(self, conn, tpc):
         xid = conn.xid(1, "gtrid", "bqual")
         assert conn.info.transaction_status == TransactionStatus.IDLE
index 6a9ad88f376155388db44433fc66736302e6d96b..c4d8915e8a801209fcc416203b77fa20e7df4ac1 100644 (file)
@@ -1,6 +1,7 @@
+import sys
+import time
 import select  # noqa: used in pytest.mark.skipif
 import socket
-import sys
 
 import pytest
 
@@ -26,6 +27,7 @@ waitfns = [
     pytest.param("wait_c", marks=pytest.mark.skipif("not psycopg._cmodule._psycopg")),
 ]
 
+events = ["R", "W", "RW"]
 timeouts = [pytest.param({}, id="blank")]
 timeouts += [pytest.param({"timeout": x}, id=str(x)) for x in [None, 0, 0.2, 10]]
 
@@ -44,9 +46,11 @@ def test_wait_conn_bad(dsn):
 
 
 @pytest.mark.parametrize("waitfn", waitfns)
-@pytest.mark.parametrize("wait, ready", zip(waiting.Wait, waiting.Ready))
+@pytest.mark.parametrize("event", events)
 @skip_if_not_linux
-def test_wait_ready(waitfn, wait, ready):
+def test_wait_ready(waitfn, event):
+    wait = getattr(waiting.Wait, event)
+    ready = getattr(waiting.Ready, event)
     waitfn = getattr(waiting, waitfn)
 
     def gen():
@@ -80,6 +84,34 @@ def test_wait_bad(pgconn, waitfn):
         waitfn(gen, pgconn.socket)
 
 
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.parametrize("waitfn", waitfns)
+def test_wait_timeout(pgconn, waitfn):
+    waitfn = getattr(waiting, waitfn)
+
+    pgconn.send_query(b"select pg_sleep(0.5)")
+    gen = generators.execute(pgconn)
+
+    ts = [time.time()]
+
+    def gen_wrapper():
+        try:
+            for x in gen:
+                res = yield x
+                ts.append(time.time())
+                gen.send(res)
+        except StopIteration as ex:
+            return ex.value
+
+    (res,) = waitfn(gen_wrapper(), pgconn.socket, timeout=0.1)
+    assert res.status == ExecStatus.TUPLES_OK
+    ds = [t1 - t0 for t0, t1 in zip(ts[:-1], ts[1:])]
+    assert len(ds) >= 5
+    for d in ds[:5]:
+        assert d == pytest.approx(0.1, 0.05)
+
+
 @pytest.mark.slow
 @pytest.mark.skipif(
     "sys.platform == 'win32'", reason="win32 works ok, but FDs are mysterious"
@@ -130,9 +162,12 @@ async def test_wait_conn_async_bad(dsn):
 
 
 @pytest.mark.anyio
-@pytest.mark.parametrize("wait, ready", zip(waiting.Wait, waiting.Ready))
+@pytest.mark.parametrize("event", events)
 @skip_if_not_linux
-async def test_wait_ready_async(wait, ready):
+async def test_wait_ready_async(event):
+    wait = getattr(waiting.Wait, event)
+    ready = getattr(waiting.Ready, event)
+
     def gen():
         r = yield wait
         return r
index b582a96d90fd0a0804063d765c3202eab5b5fdb6..2f283885fb20893536662e73624e42e710d3c8cf 100644 (file)
@@ -713,6 +713,7 @@ class TestInterval:
             ("-90d", "-3 month"),
             ("186d", "6 mons 6 days"),
             ("736d", "2 years 6 days"),
+            ("83063d,81640s,447000m", "1993534:40:40.447"),
         ],
     )
     @pytest.mark.parametrize("fmt_out", pq.Format)
index fffc32f31ba0ac8ac25d36527f3502294156d486..8e5db7668c78efc28e5389ada3e8fe05201237f2 100644 (file)
@@ -402,9 +402,11 @@ def test_dump_numeric_binary(conn, expr):
 @pytest.mark.parametrize(
     "fmt_in",
     [
-        f
-        if f != PyFormat.BINARY
-        else pytest.param(f, marks=pytest.mark.crdb_skip("binary decimal"))
+        (
+            f
+            if f != PyFormat.BINARY
+            else pytest.param(f, marks=pytest.mark.crdb_skip("binary decimal"))
+        )
         for f in PyFormat
     ],
 )
index d264c78bd20e5339eaa83516ff9f205f5ff1cb97..c4c053a1ab78b91de4f0f3ac8b6d50439a55692d 100755 (executable)
@@ -48,6 +48,7 @@ ALL_INPUTS = """
     tests/test_cursor_common_async.py
     tests/test_cursor_raw_async.py
     tests/test_cursor_server_async.py
+    tests/test_notify_async.py
     tests/test_pipeline_async.py
     tests/test_prepared_async.py
     tests/test_tpc_async.py
index 22f04ec5307de7cb3bd5cc13b1785ac377e10244..7e303bb5120d2eeee388eb73722f0b929ec63ce3 100755 (executable)
@@ -19,11 +19,11 @@ import argparse
 import subprocess as sp
 from typing import List
 from pathlib import Path
-from typing_extensions import TypeAlias
 
 import psycopg
 from psycopg.rows import TupleRow
 from psycopg.crdb import CrdbConnection
+from psycopg._compat import TypeAlias
 
 Connection: TypeAlias = psycopg.Connection[TupleRow]