]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
chore: align to master
authorNick Pope <nick.pope@infogrid.io>
Tue, 28 Nov 2023 12:46:59 +0000 (12:46 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 6 Jan 2024 12:59:59 +0000 (13:59 +0100)
The pool branch is diverging unnecessarily from master because so far I
have only cherry picked the changes useful for the pool and compatible
with 3.2.x, but this is creating unnecessary differences and making
cherry-picking harder and harder.

From now on we should use the policy of cherry-picking everything unless
non compatible (not a new feature or a breaking change), similarly to
what we do on the maint-3.1 branch.

61 files changed:
.github/workflows/3rd-party-tests.yml
.github/workflows/lint.yml
.github/workflows/tests.yml
docs/advanced/async.rst
docs/advanced/cursors.rst
docs/api/cursors.rst
docs/api/dns.rst
docs/basic/install.rst
docs/news.rst
psycopg/psycopg/_connection_base.py
psycopg/psycopg/_dns.py
psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py
psycopg/psycopg/conninfo.py
psycopg/psycopg/cursor.py
psycopg/psycopg/cursor_async.py
psycopg/psycopg/pq/pq_ctypes.py
psycopg/psycopg/server_cursor.py
psycopg/psycopg/waiting.py
psycopg/setup.cfg
psycopg_c/setup.cfg
psycopg_pool/setup.cfg
pyproject.toml
tests/_test_connection.py
tests/conftest.py
tests/crdb/test_copy.py
tests/crdb/test_copy_async.py
tests/fix_db.py
tests/fix_gc.py [new file with mode: 0644]
tests/fix_pq.py
tests/fix_proxy.py
tests/pool/test_pool.py
tests/pool/test_pool_async.py
tests/pool/test_pool_async_noasyncio.py
tests/pool/test_pool_common.py
tests/pool/test_pool_common_async.py
tests/pq/test_pgconn.py
tests/test_connection.py
tests/test_connection_async.py
tests/test_conninfo.py
tests/test_copy.py
tests/test_copy_async.py
tests/test_cursor.py
tests/test_cursor_async.py
tests/test_cursor_client.py
tests/test_cursor_client_async.py
tests/test_cursor_common.py
tests/test_cursor_common_async.py
tests/test_cursor_raw.py
tests/test_cursor_raw_async.py
tests/test_cursor_server.py
tests/test_cursor_server_async.py
tests/test_dns.py
tests/test_errors.py
tests/test_gevent.py [new file with mode: 0644]
tests/test_module.py
tests/test_psycopg_dbapi20.py
tests/types/test_array.py
tests/utils.py
tools/async_to_sync.py
tools/build/build_macos_arm64.sh

index f085250a573f816747e9fd2af73802fdd518ebb3..89f948d0ea63d41c383cbbbe2e35cdc2121d2aa5 100644 (file)
@@ -25,7 +25,7 @@ jobs:
       fail-fast: false
       matrix:
         python-version:
-          - "3.11"
+          - "3.12"
           - "3.8"
         sqlalchemy_label:
           # what version of sqlalchemy to download is defined in the "include" section below,
@@ -39,12 +39,11 @@ jobs:
           - sqlalchemy_label: git_main
             pip_sqlalchemy: git+https://github.com/sqlalchemy/sqlalchemy.git#egg=SQLAlchemy
           - sqlalchemy_label: release
-            # TODO: remove pre once v2 is stable
-            pip_sqlalchemy: --pre sqlalchemy>=2a
+            pip_sqlalchemy: sqlalchemy>=2
 
     env:
       PSYCOPG_IMPL: ${{ matrix.impl }}
-      DEPS: ./psycopg pytest pytest-xdist
+      DEPS: ./psycopg pytest pytest-xdist greenlet
 
     services:
       postgresql:
@@ -111,7 +110,7 @@ jobs:
         env:
           URL: postgresql+psycopg://postgres:password@127.0.0.1/test
         working-directory: sa_home/sa
-        run: pytest -n 2 -q --dburi $URL --backend-only --dropfirst --color=yes
+        run: pytest -n 2 -q --dburi $URL --backend-only --dropfirst --color=yes --dbdriver psycopg_async
 
   django:
     # linux should be enough to test if everything works.
@@ -130,6 +129,8 @@ jobs:
         include:
           - django_label: git_main
             pip_django: git+https://github.com/django/django.git#egg=Django
+            # Need pylibmc wheel package to test with Python 3.12.
+            # https://github.com/lericson/pylibmc/issues/288
             python-version: "3.11"
           # TODO: Needs updating with new LTS releases, is this something we want?
           #       Also needs consideration against which python we wanna test.
index 9e294b8690e3eca56339bdf364fe794cf50ad034..b615fb28f2cc1e283d4568365592e42cd664a0f3 100644 (file)
@@ -25,7 +25,7 @@ jobs:
           python-version: "3.11"
 
       - name: install packages to tests
-        run: pip install ./psycopg[dev,test] codespell
+        run: pip install ./psycopg[dev,test]
 
       - name: Run black
         run: black --check --diff .
index ebd36b6cf9733066c057df06400f62b3ff4998ee..f97e64ba8844472855666eec96f01d0e4504a489 100644 (file)
@@ -38,6 +38,8 @@ jobs:
           - {impl: c, python: "3.11", postgres: "postgres:12", libpq: oldest}
           - {impl: c, python: "3.12", postgres: "postgres:11", libpq: newest}
 
+          - {impl: python, python: "3.8", ext: gevent, postgres: "postgres:16"}
+          - {impl: c, python: "3.12", ext: gevent, postgres: "postgres:14"}
           - {impl: python, python: "3.9", ext: dns, postgres: "postgres:14"}
           - {impl: python, python: "3.9", ext: postgis, postgres: "postgis/postgis"}
           - {impl: python, python: "3.10", ext: numpy, postgres: "postgres:14"}
@@ -46,6 +48,10 @@ jobs:
           # Test with minimum dependencies versions
           - {impl: c, python: "3.8", ext: min, postgres: "postgres:15"}
 
+          # Test with PyPy.
+          - {impl: python, python: "pypy3.9", postgres: "postgres:13"}
+          - {impl: python, python: "pypy3.10", postgres: "postgres:14"}
+
     env:
       PSYCOPG_IMPL: ${{ matrix.impl }}
       DEPS: ./psycopg[test] ./psycopg_pool
@@ -76,6 +82,12 @@ jobs:
         run: |
           echo "DEPS=$DEPS ./psycopg_c" >> $GITHUB_ENV
 
+      - name: Include gevent to the packages to install
+        if: ${{ matrix.ext == 'gevent' }}
+        run: |
+          echo "DEPS=$DEPS gevent" >> $GITHUB_ENV
+          echo "MARKERS=$MARKERS gevent" >> $GITHUB_ENV
+
       - name: Include dnspython to the packages to install
         if: ${{ matrix.ext == 'dns' }}
         run: |
@@ -93,10 +105,15 @@ jobs:
           echo "DEPS=$DEPS numpy" >> $GITHUB_ENV
           echo "MARKERS=$MARKERS numpy" >> $GITHUB_ENV
 
+      - name: Exclude certain tests from pypy
+        if: ${{ startsWith(matrix.python, 'pypy') }}
+        run: |
+          echo "NOT_MARKERS=$NOT_MARKERS timing" >> $GITHUB_ENV
+
       - name: Configure to use the oldest dependencies
         if: ${{ matrix.ext == 'min' }}
         run: |
-          echo "DEPS=$DEPS dnspython shapely numpy" >> $GITHUB_ENV
+          echo "DEPS=$DEPS dnspython shapely numpy gevent" >> $GITHUB_ENV
           echo "PIP_CONSTRAINT=${{ github.workspace }}/tests/constraints.txt" \
             >> $GITHUB_ENV
 
index dfb9fabc59750c92510c406a200c67b697ddc98e..bf75260711d02d2004e4975d4e77c15655b8f215 100644 (file)
@@ -1,14 +1,83 @@
 .. currentmodule:: psycopg
 
+
+.. index:: threads
+
+.. _concurrency:
+
+Concurrent operations
+=====================
+
+Psycopg allows to write *concurrent* code, executing more than one operation
+at time.
+
+- `Connection` objects *are thread-safe*: more than one thread at time can use
+  the same connection. Different thread can use the same connection by
+  creating different cursors.
+
+- `Cursor` objects *are not thread-safe*, and are not designed to be used by
+  several threads at the same time. However, cursors are lightweight objects:
+  different threads can create each one its own cursor to use independently
+  from other threads.
+
+.. note::
+
+    All the cursors that share the same connection *will also share the same
+    transaction*. This means that, if a thread starts a transaction, every
+    cursor on the same connection will execute their queries in the same
+    transaction and, if one thread causes a database server error, all the
+    other cursors will be in error state until transaction rollback.
+
+    It also means that every cursor will see changes made in the same session
+    by other cursors, even if the transaction is still uncommitted. This
+    effect might be desirable or not, and is something to consider when
+    deciding whether to share a connection or not.
+
+.. hint::
+
+    Should you use many cursors or many connections?
+
+    Query execution and results retrieval on a connection is serialized: only
+    one cursor at time will be able to run a query on the same connection (the
+    `!Connection` object will coordinate different cursors' access). If your
+    program runs a mix of database and non-database operations in several
+    threads, then these threads might be able to share the same connection.
+    However, if you expect to execute massively parallel operations on the
+    database, it might be useful to use more than one connection at time,
+    rather than many cursors on the same connection (or a mix of both).
+
+    Using several connections, however, has an impact on the server's
+    performance and usually the number of connections that a server can handle
+    is limited by grumpy sysadmins with long beards and a strict control on
+    the `max_connections`__ server setting.
+
+    If you want to use more than one connection at time, but still avoid to
+    create too many connections and starve the server, you might want to use a
+    :ref:`connection pool <connection-pools>`.
+
+    .. __: https://www.postgresql.org/docs/current/runtime-config-connection.html#GUC-MAX-CONNECTIONS
+
+.. warning::
+
+    *Connections are not process-safe* and cannot be shared across processes,
+    for instance using the facilities of the `multiprocessing` module.
+
+    If you are using Psycopg in a forking framework (for instance in a web
+    server that implements concurrency using multiprocessing), you should make
+    sure that the database connections are created after the worker process is
+    forked. Failing to do so you will probably find the connection in broken
+    state.
+
+
 .. index:: asyncio
 
 .. _async:
 
 Asynchronous operations
-=======================
+-----------------------
 
-Psycopg `~Connection` and `~Cursor` have counterparts `~AsyncConnection` and
-`~AsyncCursor` supporting an `asyncio` interface.
+Psycopg `Connection` and `Cursor` have counterparts `AsyncConnection` and
+`AsyncCursor` supporting an `asyncio` interface.
 
 The design of the asynchronous objects is pretty much the same of the sync
 ones: in order to use them you will only have to scatter the `!await` keyword
@@ -28,6 +97,12 @@ here and there.
             async for record in acur:
                 print(record)
 
+An `!AsyncConnection` can be used by several `asyncio.Task` at the same time.
+However, as with threads, all the `AsyncCursor` on the same connection will
+share the same session and will have their access to the connection
+serialized.
+
+
 .. versionchanged:: 3.1
 
     `AsyncConnection.connect()` performs DNS name resolution in a non-blocking
@@ -141,6 +216,31 @@ manually cancel connections.  This should no longer be necessary.
 .. __: https://docs.python.org/3/library/asyncio-task.html#task-cancellation
 
 
+.. index:: gevent
+
+.. _gevent:
+
+Gevent support
+--------------
+
+Psycopg 3 supports `gevent <https://www.gevent.org/>`__ out of the box. If the
+`select` module is found patched by functions such as
+`gevent.monkey.patch_select()`__ or `patch_all()`__, psycopg will behave in a
+collaborative way.
+
+Unlike with `!psycopg2`, using the `!psycogreen` module is not required.
+
+.. __: http://www.gevent.org/api/gevent.monkey.html#gevent.monkey.patch_select
+.. __: http://www.gevent.org/api/gevent.monkey.html#gevent.monkey.patch_all
+
+.. warning::
+
+    gevent support was initially accidental, and was accidentally broken in
+    psycopg 3.1.4.
+
+    gevent is officially supported only starting from psycopg 3.1.14.
+
+
 .. index::
     pair: Asynchronous; Notifications
     pair: LISTEN; SQL command
index 7ec2e60b134794ac3f418b6b323eff6c38ba4adc..8ece1e7e8fc571e9e7aa1aad34982ba182f432d8 100644 (file)
@@ -8,9 +8,45 @@
 Cursor types
 ============
 
-Psycopg can manage kinds of "cursors" which differ in where the state of a
-query being processed is stored: :ref:`client-side-cursors` and
-:ref:`server-side-cursors`.
+Cursors are objects used to send commands to a PostgreSQL connection and to
+manage the results returned by it. They are normally created by the
+connection's `~Connection.cursor()` method.
+
+Psycopg can manage different kinds of "cursors", the objects used to send
+queries and retrieve results from the server. They differ from each other in
+aspects such as:
+
+- Are the parameters bound on the client or on the server?
+  :ref:`server-side-binding` can offer better performance (for instance
+  allowing to use prepared statements) and reduced memory footprint, but may
+  require stricter query definition and certain queries that work in
+  `!psycopg2` might need to be adapted.
+
+- Is the query result stored on the client or on the server? Server-side
+  cursors allow partial retrieval of large datasets, but they might offer less
+  performance in everyday usage.
+
+- Are queries manipulated by Python (to handle placeholders in ``%s`` and
+  ``%(name)s`` Python-style) or sent as they are to the PostgreSQL server
+  (which only supports ``$1``, ``$2`` parameters)?
+
+Psycopg exposes the following classes to implement the different strategies.
+All the classes are exposed by the main `!psycopg` package. Every class has
+also an `!Async`\ -prefixed counterparts, designed to be used in conjunction
+with `AsyncConnection` in `asyncio` programs.
+
+================= =========== =========== ==================== ==================================
+Class             Binding     Storage     Placeholders         See also
+================= =========== =========== ==================== ==================================
+`Cursor`          server-side client-side ``%s``, ``%(name)s`` :ref:`client-side-cursors`
+`ClientCursor`    cient-side  client-side ``%s``, ``%(name)s`` :ref:`client-side-binding-cursors`
+`ServerCursor`    server-side server-side ``%s``, ``%(name)s`` :ref:`server-side-cursors`
+`RawCursor`       server-side client-side ``$1``               :ref:`raw-query-cursors`
+================= =========== =========== ==================== ==================================
+
+If not specified by a `~Connection.cursor_factory`, `~Connection.cursor()`
+will usually produce `Cursor` objects.
+
 
 .. index::
     double: Cursor; Client-side
@@ -194,8 +230,8 @@ directly call the fetch methods, skipping the `~ServerCursor.execute()` call:
 
 .. _raw-query-cursors:
 
-Raw Query Cursors
-------------------
+Raw query cursors
+-----------------
 
 .. versionadded:: 3.2
 
index b6f51f17be31f3c7e8902e071c654f75f1c4fa5d..9abb3f989c3c19aa447f7b2cea58902f53ae2248 100644 (file)
@@ -11,11 +11,12 @@ Using the `!name` parameter on `!cursor()` will create a `ServerCursor` or
 `AsyncServerCursor`, which can be used to retrieve partial results from a
 database.
 
-A `Connection` can create several cursors, but only one at time can perform
-operations, so they are not the best way to achieve parallelism (you may want
-to operate with several connections instead). All the cursors on the same
-connection have a view of the same session, so they can see each other's
-uncommitted data.
+Other cursor classes can be created by directly instantiating them, or can be
+set as `Connection.cursor_factory` to require them on `!cursor()` call.
+
+This page describe the details of the `!Cursor` class interface. Please refer
+to :ref:`cursor-types` for general information about the different types of
+cursors available in Psycopg.
 
 
 The `!Cursor` class
index e80f4d5943698470080194a4a87eff73466e0cd7..b109c2716a8c27421edace9eb5806120864b0830 100644 (file)
@@ -92,12 +92,6 @@ server before performing a connection.
     .. warning::
         This is an experimental method.
 
-    .. versionchanged:: 3.1
-        Unlike the sync counterpart, perform non-blocking address
-        resolution and populate the ``hostaddr`` connection parameter,
-        unless the user has provided one themselves. See
-        `resolve_hostaddr_async()` for details.
-
 
 .. function:: resolve_hostaddr_async(params)
     :async:
index f09fbc637be48f32e9708a3439ed9f01afefe695..13719d404713adcc796d36f7a13a1ea6103e2c32 100644 (file)
@@ -6,7 +6,7 @@ Installation
 In short, if you use a :ref:`supported system<supported-systems>`::
 
     pip install --upgrade pip           # upgrade pip to at least 20.3
-    pip install "psycopg[binary]"
+    pip install "psycopg[binary]"       # remove [binary] for PyPy
 
 and you should be :ref:`ready to start <module-usage>`. Read further for
 alternative ways to install.
@@ -27,6 +27,10 @@ The Psycopg version documented here has *official and tested* support for:
   - Python 3.6 supported before Psycopg 3.1
   - Python 3.7 supported before Psycopg 3.2
 
+- PyPy: from version 3.9 to 3.10
+
+  - **Note:** Only the pure Python version is supported.
+
 - PostgreSQL: from version 10 to 16
 - OS: Linux, macOS, Windows
 
@@ -76,6 +80,10 @@ installation <local-installation>` or a :ref:`pure Python installation
     For further information about the differences between the packages see
     :ref:`pq-impl`.
 
+.. warning::
+
+   The binary installation is not supported by PyPy.
+
 
 .. _local-installation:
 
@@ -103,6 +111,10 @@ If your build prerequisites are in place you can run::
 
     pip install "psycopg[c]"
 
+.. warning::
+
+   The local installation is not supported by PyPy.
+
 
 .. _pure-python-installation:
 
index b4e4fd97c76e5dce604dc56bd9ec31f828d05975..c7540dfca2cf984719ddf8be10c9012fcc941500 100644 (file)
@@ -30,19 +30,57 @@ Psycopg 3.2 (unreleased)
 .. __: https://numpy.org/doc/stable/reference/arrays.scalars.html#built-in-scalar-types
 
 
-Psycopg 3.1.13 (unreleased)
+Psycopg 3.1.17 (unreleased)
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
+- Use `typing.Self` as a more correct return value annotation of context
+  managers and other self-returning methods (see :ticket:`708`).
+
+
+Current release
+---------------
+
+Psycopg 3.1.16
+^^^^^^^^^^^^^^
+
+- Fix empty ports handling in async multiple connection attempts
+  (:ticket:`#703`).
+
+
+Psycopg 3.1.15
+^^^^^^^^^^^^^^
+
+- Fix use of ``service`` in connection string (regression in 3.1.13,
+  :ticket:`#694`).
+- Fix async connection to hosts resolving to multiple IP addresses (regression
+  in 3.1.13, :ticket:`#695`).
+- Respect the :envvar:`PGCONNECT_TIMEOUT` environment variable to determine
+  the connection timeout.
+
+
+Psycopg 3.1.14
+^^^^^^^^^^^^^^
+
+- Fix :ref:`interaction with gevent <gevent>` (:ticket:`#527`).
+- Add support for PyPy (:ticket:`#686`).
+
+.. _gevent: https://www.gevent.org/
+
+
+Psycopg 3.1.13
+^^^^^^^^^^^^^^
+
 - Raise `DataError` instead of whatever internal failure trying to dump a
   `~datetime.time` object with with a `!tzinfo` specified as
   `~zoneinfo.ZoneInfo` (ambiguous offset, see :ticket:`#652`).
 - Handle gracefully EINTR on signals instead of raising `InterruptedError`,
   consistently with :pep:`475` guideline (:ticket:`#667`).
+- Fix support for connection strings with multiple hosts/ports and for the
+  ``load_balance_hosts`` connection parameter (:ticket:`#674`).
+- Fix memory leak receiving notifications in Python implementation
+  (:ticket:`#679`).
 
 
-Current release
----------------
-
 Psycopg 3.1.12
 ^^^^^^^^^^^^^^
 
index dc1c9140b550056f4715ef4f1cd98e326f1c3b85..003495cfc6af4cca04535d73086f01091acf67dd 100644 (file)
@@ -420,13 +420,10 @@ class BaseConnection(Generic[Row]):
     # should have a lock and hold it before calling and consuming them.
 
     @classmethod
-    def _connect_gen(
-        cls, conninfo: str = "", *, autocommit: bool = False
-    ) -> PQGenConn[Self]:
+    def _connect_gen(cls, conninfo: str = "") -> PQGenConn[Self]:
         """Generator to connect to the database and create a new instance."""
         pgconn = yield from generators.connect(conninfo)
         conn = cls(pgconn)
-        conn._autocommit = bool(autocommit)
         return conn
 
     def _exec_command(
index 1e146ba216c7bad4dac15f8138af5f3ddbdb7306..a9619b56d09396a4418a0b9404f7714ae977e5dc 100644 (file)
@@ -23,7 +23,7 @@ except ImportError:
     )
 
 from . import errors as e
-from .conninfo import resolve_hostaddr_async as resolve_hostaddr_async_
+from . import conninfo
 
 if TYPE_CHECKING:
     from dns.rdtypes.IN.SRV import SRV
@@ -48,7 +48,30 @@ async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]:
         "from psycopg 3.1, resolve_hostaddr_async() is not needed anymore",
         DeprecationWarning,
     )
-    return await resolve_hostaddr_async_(params)
+    hosts: list[str] = []
+    hostaddrs: list[str] = []
+    ports: list[str] = []
+
+    for attempt in await conninfo.conninfo_attempts_async(params):
+        if attempt.get("host") is not None:
+            hosts.append(attempt["host"])
+        if attempt.get("hostaddr") is not None:
+            hostaddrs.append(attempt["hostaddr"])
+        if attempt.get("port") is not None:
+            ports.append(str(attempt["port"]))
+
+    out = params.copy()
+    shosts = ",".join(hosts)
+    if shosts:
+        out["host"] = shosts
+    shostaddrs = ",".join(hostaddrs)
+    if shostaddrs:
+        out["hostaddr"] = shostaddrs
+    sports = ",".join(ports)
+    if ports:
+        out["port"] = sports
+
+    return out
 
 
 def resolve_srv(params: Dict[str, Any]) -> Dict[str, Any]:
index fe5a2a02a9bb10a45258e4a94d5c3c930282f2d9..dc02ce3815c359a6759710a174e7a145db596a20 100644 (file)
@@ -11,7 +11,7 @@ from __future__ import annotations
 
 import logging
 from types import TracebackType
-from typing import Any, Generator, Iterator, Dict, List, Optional
+from typing import Any, Generator, Iterator, List, Optional
 from typing import Type, Union, cast, overload, TYPE_CHECKING
 from contextlib import contextmanager
 
@@ -24,7 +24,8 @@ from .rows import Row, RowFactory, tuple_row, args_row
 from .adapt import AdaptersMap
 from ._enums import IsolationLevel
 from ._compat import Self
-from .conninfo import make_conninfo, conninfo_to_dict
+from .conninfo import ConnDict, make_conninfo, conninfo_to_dict
+from .conninfo import conninfo_attempts, timeout_from_conninfo
 from ._pipeline import Pipeline
 from ._encodings import pgconn_encoding
 from .generators import notifies
@@ -89,16 +90,30 @@ class Connection(BaseConnection[Row]):
         """
 
         params = cls._get_connection_params(conninfo, **kwargs)
-        conninfo = make_conninfo(**params)
-
-        try:
-            rv = cls._wait_conn(
-                cls._connect_gen(conninfo, autocommit=autocommit),
-                timeout=params["connect_timeout"],
-            )
-        except e._NO_TRACEBACK as ex:
-            raise ex.with_traceback(None)
-
+        timeout = timeout_from_conninfo(params)
+        rv = None
+        attempts = conninfo_attempts(params)
+        for attempt in attempts:
+            try:
+                conninfo = make_conninfo(**attempt)
+                rv = cls._wait_conn(cls._connect_gen(conninfo), timeout=timeout)
+                break
+            except e._NO_TRACEBACK as ex:
+                if len(attempts) > 1:
+                    logger.debug(
+                        "connection attempt failed: host: %r port: %r, hostaddr %r: %s",
+                        attempt.get("host"),
+                        attempt.get("port"),
+                        attempt.get("hostaddr"),
+                        str(ex),
+                    )
+                last_ex = ex
+
+        if not rv:
+            assert last_ex
+            raise last_ex.with_traceback(None)
+
+        rv._autocommit = bool(autocommit)
         if row_factory:
             rv.row_factory = row_factory
         if cursor_factory:
@@ -135,23 +150,9 @@ class Connection(BaseConnection[Row]):
             self.close()
 
     @classmethod
-    def _get_connection_params(cls, conninfo: str, **kwargs: Any) -> Dict[str, Any]:
-        """Manipulate connection parameters before connecting.
-
-        :param conninfo: Connection string as received by `~Connection.connect()`.
-        :param kwargs: Overriding connection arguments as received by `!connect()`.
-        :return: Connection arguments merged and eventually modified, in a
-            format similar to `~conninfo.conninfo_to_dict()`.
-        """
-        params = conninfo_to_dict(conninfo, **kwargs)
-
-        # Make sure there is an usable connect_timeout
-        if "connect_timeout" in params:
-            params["connect_timeout"] = int(params["connect_timeout"])
-        else:
-            params["connect_timeout"] = None
-
-        return params
+    def _get_connection_params(cls, conninfo: str, **kwargs: Any) -> ConnDict:
+        """Manipulate connection parameters before connecting."""
+        return conninfo_to_dict(conninfo, **kwargs)
 
     def close(self) -> None:
         """Close the database connection."""
index b08375442290857dd3233faf96fd12d003adb252..46269fda4f6fc43bc2b0684a3bbf6937c01fd87b 100644 (file)
@@ -8,7 +8,7 @@ from __future__ import annotations
 
 import logging
 from types import TracebackType
-from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional
+from typing import Any, AsyncGenerator, AsyncIterator, List, Optional
 from typing import Type, Union, cast, overload, TYPE_CHECKING
 from contextlib import asynccontextmanager
 
@@ -21,7 +21,8 @@ from .rows import Row, AsyncRowFactory, tuple_row, args_row
 from .adapt import AdaptersMap
 from ._enums import IsolationLevel
 from ._compat import Self
-from .conninfo import make_conninfo, conninfo_to_dict
+from .conninfo import ConnDict, make_conninfo, conninfo_to_dict
+from .conninfo import conninfo_attempts_async, timeout_from_conninfo
 from ._pipeline import AsyncPipeline
 from ._encodings import pgconn_encoding
 from .generators import notifies
@@ -34,7 +35,6 @@ if True:  # ASYNC
     import sys
     import asyncio
     from asyncio import Lock
-    from .conninfo import resolve_hostaddr_async
 else:
     from threading import Lock
 
@@ -105,16 +105,30 @@ class AsyncConnection(BaseConnection[Row]):
                     )
 
         params = await cls._get_connection_params(conninfo, **kwargs)
-        conninfo = make_conninfo(**params)
+        timeout = timeout_from_conninfo(params)
+        rv = None
+        attempts = await conninfo_attempts_async(params)
+        for attempt in attempts:
+            try:
+                conninfo = make_conninfo(**attempt)
+                rv = await cls._wait_conn(cls._connect_gen(conninfo), timeout=timeout)
+                break
+            except e._NO_TRACEBACK as ex:
+                if len(attempts) > 1:
+                    logger.debug(
+                        "connection attempt failed: host: %r port: %r, hostaddr %r: %s",
+                        attempt.get("host"),
+                        attempt.get("port"),
+                        attempt.get("hostaddr"),
+                        str(ex),
+                    )
+                last_ex = ex
 
-        try:
-            rv = await cls._wait_conn(
-                cls._connect_gen(conninfo, autocommit=autocommit),
-                timeout=params["connect_timeout"],
-            )
-        except e._NO_TRACEBACK as ex:
-            raise ex.with_traceback(None)
+        if not rv:
+            assert last_ex
+            raise last_ex.with_traceback(None)
 
+        rv._autocommit = bool(autocommit)
         if row_factory:
             rv.row_factory = row_factory
         if cursor_factory:
@@ -151,29 +165,9 @@ class AsyncConnection(BaseConnection[Row]):
             await self.close()
 
     @classmethod
-    async def _get_connection_params(
-        cls, conninfo: str, **kwargs: Any
-    ) -> Dict[str, Any]:
-        """Manipulate connection parameters before connecting.
-
-        :param conninfo: Connection string as received by `~Connection.connect()`.
-        :param kwargs: Overriding connection arguments as received by `!connect()`.
-        :return: Connection arguments merged and eventually modified, in a
-            format similar to `~conninfo.conninfo_to_dict()`.
-        """
-        params = conninfo_to_dict(conninfo, **kwargs)
-
-        # Make sure there is an usable connect_timeout
-        if "connect_timeout" in params:
-            params["connect_timeout"] = int(params["connect_timeout"])
-        else:
-            params["connect_timeout"] = None
-
-        if True:  # ASYNC
-            # Resolve host addresses in non-blocking way
-            params = await resolve_hostaddr_async(params)
-
-        return params
+    async def _get_connection_params(cls, conninfo: str, **kwargs: Any) -> ConnDict:
+        """Manipulate connection parameters before connecting."""
+        return conninfo_to_dict(conninfo, **kwargs)
 
     async def close(self) -> None:
         """Close the database connection."""
index 38d1f7dabc43d16a43e18b990156ca6b22f60c4c..c57a6d9c095ff4d70fcd0fb528b36280788d306e 100644 (file)
@@ -4,21 +4,36 @@ Functions to manipulate conninfo strings
 
 # Copyright (C) 2020 The Psycopg Team
 
+from __future__ import annotations
+
 import os
 import re
 import socket
 import asyncio
-from typing import Any, Dict, List, Optional
+import logging
+from typing import Any
+from random import shuffle
 from pathlib import Path
 from datetime import tzinfo
 from functools import lru_cache
 from ipaddress import ip_address
+from dataclasses import dataclass
+from typing_extensions import TypeAlias
 
 from . import pq
 from . import errors as e
 from ._tz import get_tzinfo
 from ._encodings import pgconn_encoding
 
+ConnDict: TypeAlias = "dict[str, Any]"
+
+# Default timeout for connection a attempt.
+# Arbitrary timeout, what applied by the libpq on my computer.
+# Your mileage won't vary.
+_DEFAULT_CONNECT_TIMEOUT = 130
+
+logger = logging.getLogger("psycopg")
+
 
 def make_conninfo(conninfo: str = "", **kwargs: Any) -> str:
     """
@@ -61,7 +76,7 @@ def make_conninfo(conninfo: str = "", **kwargs: Any) -> str:
     return conninfo
 
 
-def conninfo_to_dict(conninfo: str = "", **kwargs: Any) -> Dict[str, Any]:
+def conninfo_to_dict(conninfo: str = "", **kwargs: Any) -> ConnDict:
     """
     Convert the `!conninfo` string into a dictionary of parameters.
 
@@ -84,7 +99,7 @@ def conninfo_to_dict(conninfo: str = "", **kwargs: Any) -> Dict[str, Any]:
     return rv
 
 
-def _parse_conninfo(conninfo: str) -> List[pq.ConninfoOption]:
+def _parse_conninfo(conninfo: str) -> list[pq.ConninfoOption]:
     """
     Verify that `!conninfo` is a valid connection string.
 
@@ -167,7 +182,7 @@ class ConnectionInfo:
         """
         return self._get_pgconn_attr("options")
 
-    def get_parameters(self) -> Dict[str, str]:
+    def get_parameters(self) -> dict[str, str]:
         """Return the connection parameters values.
 
         Return all the parameters set to a non-default value, which might come
@@ -228,7 +243,7 @@ class ConnectionInfo:
         """
         return pq.PipelineStatus(self.pgconn.pipeline_status)
 
-    def parameter_status(self, param_name: str) -> Optional[str]:
+    def parameter_status(self, param_name: str) -> str | None:
         """
         Return a parameter setting of the connection.
 
@@ -275,97 +290,228 @@ class ConnectionInfo:
         return value.decode(self.encoding)
 
 
-async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]:
+def conninfo_attempts(params: ConnDict) -> list[ConnDict]:
+    """Split a set of connection params on the single attempts to perform.
+
+    A connection param can perform more than one attempt more than one ``host``
+    is provided.
+
+    Because the libpq async function doesn't honour the timeout, we need to
+    reimplement the repeated attempts.
     """
-    Perform async DNS lookup of the hosts and return a new params dict.
+    # TODO: we should actually resolve the hosts ourselves.
+    # If an host resolves to more than one ip, the libpq will make more than
+    # one attempt and wouldn't get to try the following ones, as before
+    # fixing #674.
+    attempts = _split_attempts(params)
+    if _get_param(params, "load_balance_hosts") == "random":
+        shuffle(attempts)
+    return attempts
 
-    :param params: The input parameters, for instance as returned by
-        `~psycopg.conninfo.conninfo_to_dict()`.
+
+async def conninfo_attempts_async(params: ConnDict) -> list[ConnDict]:
+    """Split a set of connection params on the single attempts to perform.
+
+    A connection param can perform more than one attempt more than one ``host``
+    is provided.
+
+    Also perform async resolution of the hostname into hostaddr in order to
+    avoid blocking. Because a host can resolve to more than one address, this
+    can lead to yield more attempts too. Raise `OperationalError` if no host
+    could be resolved.
+
+    Because the libpq async function doesn't honour the timeout, we need to
+    reimplement the repeated attempts.
+    """
+    last_exc = None
+    attempts = []
+    for attempt in _split_attempts(params):
+        try:
+            attempts.extend(await _resolve_hostnames(attempt))
+        except OSError as ex:
+            logger.debug("failed to resolve host %r: %s", attempt.get("host"), str(ex))
+            last_exc = ex
+
+    if not attempts:
+        assert last_exc
+        # We couldn't resolve anything
+        raise e.OperationalError(str(last_exc))
+
+    if _get_param(params, "load_balance_hosts") == "random":
+        shuffle(attempts)
+
+    return attempts
+
+
+def _split_attempts(params: ConnDict) -> list[ConnDict]:
+    """
+    Split connection parameters with a sequence of hosts into separate attempts.
+    """
+
+    def split_val(key: str) -> list[str]:
+        val = _get_param(params, key)
+        return val.split(",") if val else []
+
+    hosts = split_val("host")
+    hostaddrs = split_val("hostaddr")
+    ports = split_val("port")
+
+    if hosts and hostaddrs and len(hosts) != len(hostaddrs):
+        raise e.OperationalError(
+            f"could not match {len(hosts)} host names"
+            f" with {len(hostaddrs)} hostaddr values"
+        )
+
+    nhosts = max(len(hosts), len(hostaddrs))
+
+    if 1 < len(ports) != nhosts:
+        raise e.OperationalError(
+            f"could not match {len(ports)} port numbers to {len(hosts)} hosts"
+        )
+
+    # A single attempt to make. Don't mangle the conninfo string.
+    if nhosts <= 1:
+        return [params]
+
+    if len(ports) == 1:
+        ports *= nhosts
+
+    # Now all lists are either empty or have the same length
+    rv = []
+    for i in range(nhosts):
+        attempt = params.copy()
+        if hosts:
+            attempt["host"] = hosts[i]
+        if hostaddrs:
+            attempt["hostaddr"] = hostaddrs[i]
+        if ports:
+            attempt["port"] = ports[i]
+        rv.append(attempt)
+
+    return rv
+
+
+async def _resolve_hostnames(params: ConnDict) -> list[ConnDict]:
+    """
+    Perform async DNS lookup of the hosts and return a new params dict.
 
     If a ``host`` param is present but not ``hostname``, resolve the host
-    addresses dynamically.
+    addresses asynchronously.
 
-    The function may change the input ``host``, ``hostname``, ``port`` to allow
-    connecting without further DNS lookups, eventually removing hosts that are
-    not resolved, keeping the lists of hosts and ports consistent.
+    :param params: The input parameters, for instance as returned by
+        `~psycopg.conninfo.conninfo_to_dict()`. The function expects at most
+        a single entry for host, hostaddr because it is designed to further
+        process the input of _split_attempts().
 
-    Raise `~psycopg.OperationalError` if connection is not possible (e.g. no
-    host resolve, inconsistent lists length).
+    :return: A list of attempts to make (to include the case of a hostname
+        resolving to more than one IP).
     """
-    hostaddr_arg = params.get("hostaddr", os.environ.get("PGHOSTADDR", ""))
-    if hostaddr_arg:
+    host = _get_param(params, "host")
+    if not host or host.startswith("/") or host[1:2] == ":":
+        # Local path, or no host to resolve
+        return [params]
+
+    hostaddr = _get_param(params, "hostaddr")
+    if hostaddr:
         # Already resolved
-        return params
-
-    host_arg: str = params.get("host", os.environ.get("PGHOST", ""))
-    if not host_arg:
-        # Nothing to resolve
-        return params
-
-    hosts_in = host_arg.split(",")
-    port_arg: str = str(params.get("port", os.environ.get("PGPORT", "")))
-    ports_in = port_arg.split(",") if port_arg else []
-    default_port = "5432"
-
-    if len(ports_in) == 1:
-        # If only one port is specified, the libpq will apply it to all
-        # the hosts, so don't mangle it.
-        default_port = ports_in.pop()
-
-    elif len(ports_in) > 1:
-        if len(ports_in) != len(hosts_in):
-            # ProgrammingError would have been more appropriate, but this is
-            # what the raise if the libpq fails connect in the same case.
-            raise e.OperationalError(
-                f"cannot match {len(hosts_in)} hosts with {len(ports_in)} port numbers"
-            )
-        ports_out = []
+        return [params]
+
+    if is_ip_address(host):
+        # If the host is already an ip address don't try to resolve it
+        return [{**params, "hostaddr": host}]
 
-    hosts_out = []
-    hostaddr_out = []
     loop = asyncio.get_running_loop()
-    for i, host in enumerate(hosts_in):
-        if not host or host.startswith("/") or host[1:2] == ":":
-            # Local path
-            hosts_out.append(host)
-            hostaddr_out.append("")
-            if ports_in:
-                ports_out.append(ports_in[i])
-            continue
 
-        # If the host is already an ip address don't try to resolve it
-        if is_ip_address(host):
-            hosts_out.append(host)
-            hostaddr_out.append(host)
-            if ports_in:
-                ports_out.append(ports_in[i])
-            continue
+    port = _get_param(params, "port")
+    if not port:
+        port_def = _get_param_def("port")
+        port = port_def and port_def.compiled or "5432"
+
+    ans = await loop.getaddrinfo(
+        host, int(port), proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM
+    )
+    return [{**params, "hostaddr": item[4][0]} for item in ans]
+
+
+def timeout_from_conninfo(params: ConnDict) -> int:
+    """
+    Return the timeout in seconds from the connection parameters.
+    """
+    # Follow the libpq convention:
+    #
+    # - 0 or less means no timeout (but we will use a default to simulate
+    #   the socket timeout)
+    # - at least 2 seconds.
+    #
+    # See connectDBComplete in fe-connect.c
+    value: str | int | None = _get_param(params, "connect_timeout")
+    if value is None:
+        value = _DEFAULT_CONNECT_TIMEOUT
+    try:
+        timeout = int(value)
+    except ValueError:
+        raise e.ProgrammingError(f"bad value for connect_timeout: {value!r}")
+
+    if timeout <= 0:
+        # The sync connect function will stop on the default socket timeout
+        # Because in async connection mode we need to enforce the timeout
+        # ourselves, we need a finite value.
+        timeout = _DEFAULT_CONNECT_TIMEOUT
+    elif timeout < 2:
+        # Enforce a 2s min
+        timeout = 2
+
+    return timeout
 
-        try:
-            port = ports_in[i] if ports_in else default_port
-            ans = await loop.getaddrinfo(
-                host, port, proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM
-            )
-        except OSError as ex:
-            last_exc = ex
-        else:
-            for item in ans:
-                hosts_out.append(host)
-                hostaddr_out.append(item[4][0])
-                if ports_in:
-                    ports_out.append(ports_in[i])
-
-    # Throw an exception if no host could be resolved
-    if not hosts_out:
-        raise e.OperationalError(str(last_exc))
 
-    out = params.copy()
-    out["host"] = ",".join(hosts_out)
-    out["hostaddr"] = ",".join(hostaddr_out)
-    if ports_in:
-        out["port"] = ",".join(ports_out)
+def _get_param(params: ConnDict, name: str) -> str | None:
+    """
+    Return a value from a connection string.
+
+    The value may be also specified in a PG* env var.
+    """
+    if name in params:
+        return str(params[name])
+
+    # TODO: check if in service
+
+    paramdef = _get_param_def(name)
+    if not paramdef:
+        return None
+
+    env = os.environ.get(paramdef.envvar)
+    if env is not None:
+        return env
+
+    return None
+
+
+@dataclass
+class ParamDef:
+    """
+    Information about defaults and env vars for connection params
+    """
+
+    keyword: str
+    envvar: str
+    compiled: str | None
+
+
+def _get_param_def(keyword: str, _cache: dict[str, ParamDef] = {}) -> ParamDef | None:
+    """
+    Return the ParamDef of a connection string parameter.
+    """
+    if not _cache:
+        defs = pq.Conninfo.get_defaults()
+        for d in defs:
+            cd = ParamDef(
+                keyword=d.keyword.decode(),
+                envvar=d.envvar.decode() if d.envvar else "",
+                compiled=d.compiled.decode() if d.compiled is not None else None,
+            )
+            _cache[cd.keyword] = cd
 
-    return out
+    return _cache.get(keyword)
 
 
 @lru_cache()
index 98500a64ba1167567694d4e3b8524b134ef91c17..10741c95fd2eec7c4dd8f408c8a7c236986b8d98 100644 (file)
@@ -34,13 +34,11 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
     __slots__ = ()
 
     @overload
-    def __init__(self: Cursor[Row], connection: Connection[Row]):
+    def __init__(self, connection: Connection[Row]):
         ...
 
     @overload
-    def __init__(
-        self: Cursor[Row], connection: Connection[Any], *, row_factory: RowFactory[Row]
-    ):
+    def __init__(self, connection: Connection[Any], *, row_factory: RowFactory[Row]):
         ...
 
     def __init__(
index 6c6d3f814855a82bab22307238c1e47fa3cd9f57..603560155c2800b6cda9206a38da7df9ab33d044 100644 (file)
@@ -31,15 +31,12 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
     __slots__ = ()
 
     @overload
-    def __init__(self: AsyncCursor[Row], connection: AsyncConnection[Row]):
+    def __init__(self, connection: AsyncConnection[Row]):
         ...
 
     @overload
     def __init__(
-        self: AsyncCursor[Row],
-        connection: AsyncConnection[Any],
-        *,
-        row_factory: AsyncRowFactory[Row],
+        self, connection: AsyncConnection[Any], *, row_factory: AsyncRowFactory[Row]
     ):
         ...
 
index 07c2d95cca50d57c4823aad837441faa34b3b977..f04a803679fcabbd6ba8395b0c1a4c974533a640 100644 (file)
@@ -47,7 +47,10 @@ def version() -> int:
 
 @impl.PQnoticeReceiver  # type: ignore
 def notice_receiver(arg: c_void_p, result_ptr: impl.PGresult_struct) -> None:
-    pgconn = cast(arg, POINTER(py_object)).contents.value()
+    pgconn = cast(arg, POINTER(py_object)).contents.value
+    if callable(pgconn):  # Not a weak reference on PyPy.
+        pgconn = pgconn()
+
     if not (pgconn and pgconn.notice_handler):
         return
 
@@ -604,8 +607,9 @@ class PGconn:
         ptr = impl.PQnotifies(self._pgconn_ptr)
         if ptr:
             c = ptr.contents
-            return PGnotify(c.relname, c.be_pid, c.extra)
+            rv = PGnotify(c.relname, c.be_pid, c.extra)
             impl.PQfreemem(ptr)
+            return rv
         else:
             return None
 
index 7039d2950105cb7d706ee22c2be55e2c79d07622..1c6e77aa10f96b33647029e9be2083edbf29c7a4 100644 (file)
@@ -216,7 +216,7 @@ class ServerCursor(ServerCursorMixin["Connection[Any]", Row], Cursor[Row]):
 
     @overload
     def __init__(
-        self: "ServerCursor[Row]",
+        self,
         connection: "Connection[Row]",
         name: str,
         *,
@@ -227,7 +227,7 @@ class ServerCursor(ServerCursorMixin["Connection[Any]", Row], Cursor[Row]):
 
     @overload
     def __init__(
-        self: "ServerCursor[Row]",
+        self,
         connection: "Connection[Any]",
         name: str,
         *,
@@ -357,7 +357,7 @@ class AsyncServerCursor(
 
     @overload
     def __init__(
-        self: "AsyncServerCursor[Row]",
+        self,
         connection: "AsyncConnection[Row]",
         name: str,
         *,
@@ -368,7 +368,7 @@ class AsyncServerCursor(
 
     @overload
     def __init__(
-        self: "AsyncServerCursor[Row]",
+        self,
         connection: "AsyncConnection[Any]",
         name: str,
         *,
index deee10cba3bbd922678f72b80239580c1e0600e6..d6db0d922e899f53c130d538069fd26e7697243e 100644 (file)
@@ -12,6 +12,7 @@ These functions are designed to consume the generators returned by the
 import os
 import sys
 import select
+import logging
 import selectors
 from typing import Optional
 from asyncio import get_event_loop, wait_for, Event, TimeoutError
@@ -29,6 +30,8 @@ READY_R = Ready.R
 READY_W = Ready.W
 READY_RW = Ready.RW
 
+logger = logging.getLogger(__name__)
+
 
 def wait_selector(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> RV:
     """
@@ -356,6 +359,27 @@ def wait_poll(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> R
         return rv
 
 
+def _is_select_patched() -> bool:
+    """
+    Detect if some greenlet library has patched the select library.
+
+    If this is the case, avoid to use the wait_c function as it doesn't behave
+    in a collaborative way.
+
+    Currently supported: gevent.
+    """
+    # If not imported, don't import it.
+    m = sys.modules.get("gevent.monkey")
+    if m:
+        try:
+            if m.is_module_patched("select"):
+                return True
+        except Exception as ex:
+            logger.warning("failed to detect gevent monkey-patching: %s", ex)
+
+    return False
+
+
 if _psycopg:
     wait_c = _psycopg.wait_c
 
@@ -380,7 +404,7 @@ if "PSYCOPG_WAIT_FUNC" in os.environ:
 # On Windows, for the moment, avoid using wait_c, because it was reported to
 # use excessive CPU (see #645).
 # TODO: investigate why.
-elif _psycopg and sys.platform != "win32":
+elif _psycopg and sys.platform != "win32" and not _is_select_patched():
     wait = wait_c
 
 elif selectors.DefaultSelector is getattr(selectors, "SelectSelector", None):
index f335c589655d0e107029339cafb67aab20053445..f734c40ec99a0f178ac44795df1bf0e318b36882 100644 (file)
@@ -38,6 +38,8 @@ classifiers =
     Programming Language :: Python :: 3.10
     Programming Language :: Python :: 3.11
     Programming Language :: Python :: 3.12
+    Programming Language :: Python :: Implementation :: CPython
+    Programming Language :: Python :: Implementation :: PyPy
     Topic :: Database
     Topic :: Database :: Front-Ends
     Topic :: Software Development
@@ -58,9 +60,9 @@ install_requires =
 
 [options.extras_require]
 c =
-    psycopg-c == 3.2.0.dev1
+    psycopg-c == 3.2.0.dev1; implementation_name != "pypy"
 binary =
-    psycopg-binary == 3.2.0.dev1
+    psycopg-binary == 3.2.0.dev1; implementation_name != "pypy"
 pool =
     psycopg-pool
 test =
@@ -71,9 +73,9 @@ test =
     pytest-cov >= 3.0
     pytest-randomly >= 3.5
 dev =
-    # Version pinned to work around https://github.com/t3rn0/ast-comments/issues/21
-    ast-comments==1.1.0
+    ast-comments >= 1.1.2
     black >= 23.1.0
+    codespell >= 2.2
     dnspython >= 2.1
     flake8 >= 4.0
     mypy >= 1.6
index 730596f4d8ee91b5dba47551ddab222ef1523fae..87bc8a39a054e625b7d7fe2fd2f98dc33b30b222 100644 (file)
@@ -22,12 +22,14 @@ classifiers =
     Operating System :: MacOS :: MacOS X
     Operating System :: Microsoft :: Windows
     Operating System :: POSIX
+    Programming Language :: Cython
     Programming Language :: Python :: 3
     Programming Language :: Python :: 3.8
     Programming Language :: Python :: 3.9
     Programming Language :: Python :: 3.10
     Programming Language :: Python :: 3.11
     Programming Language :: Python :: 3.12
+    Programming Language :: Python :: Implementation :: CPython
     Topic :: Database
     Topic :: Database :: Front-Ends
     Topic :: Software Development
index 0d4f476cb18e7ad5a5faa412867cb753c60c5951..5889b1738aab9c43cd182b60ebbbf3c2ff788e7a 100644 (file)
@@ -32,6 +32,8 @@ classifiers =
     Programming Language :: Python :: 3.10
     Programming Language :: Python :: 3.11
     Programming Language :: Python :: 3.12
+    Programming Language :: Python :: Implementation :: CPython
+    Programming Language :: Python :: Implementation :: PyPy
     Topic :: Database
     Topic :: Database :: Front-Ends
     Topic :: Software Development
index 7735c18b54497a9dd3c217c53d0f71a1152037f9..f33ac7ff81b50cdb9b285ea349195b656e7e2748 100644 (file)
@@ -3,6 +3,7 @@ requires = ["setuptools>=49.2.0", "wheel>=0.37"]
 build-backend = "setuptools.build_meta"
 
 [tool.pytest.ini_options]
+addopts = "-ra"
 filterwarnings = [
     "error",
 ]
index 296a7f7f4b95a58246d30fca4da4b8610e68c09f..56d3859266f333432b4b985ef5b15cf8a4ede422 100644 (file)
@@ -7,6 +7,15 @@ from dataclasses import dataclass
 
 import pytest
 import psycopg
+from psycopg.conninfo import conninfo_to_dict
+
+try:
+    from psycopg.conninfo import _DEFAULT_CONNECT_TIMEOUT as DEFAULT_TIMEOUT
+except ImportError:
+    # Allow tests to import (not necessarily to pass all) if the psycopg module
+    # imported is not the one expected (e.g. running psycopg pool tests on the
+    # master branch with psycopg 3.1.x imported).
+    DEFAULT_TIMEOUT = 130
 
 
 @pytest.fixture
@@ -75,17 +84,17 @@ conninfo_params_timeout = [
     (
         "",
         {"dbname": "mydb", "connect_timeout": None},
-        ({"dbname": "mydb"}, None),
+        ({"dbname": "mydb"}, DEFAULT_TIMEOUT),
     ),
     (
         "",
         {"dbname": "mydb", "connect_timeout": 1},
-        ({"dbname": "mydb", "connect_timeout": "1"}, 1),
+        ({"dbname": "mydb", "connect_timeout": 1}, 2),
     ),
     (
         "dbname=postgres",
         {},
-        ({"dbname": "postgres"}, None),
+        ({"dbname": "postgres"}, DEFAULT_TIMEOUT),
     ),
     (
         "dbname=postgres connect_timeout=2",
@@ -95,6 +104,21 @@ conninfo_params_timeout = [
     (
         "postgresql:///postgres?connect_timeout=2",
         {"connect_timeout": 10},
-        ({"dbname": "postgres", "connect_timeout": "10"}, 10),
+        ({"dbname": "postgres", "connect_timeout": 10}, 10),
     ),
 ]
+
+
+def drop_default_args_from_conninfo(conninfo):
+    if isinstance(conninfo, str):
+        params = conninfo_to_dict(conninfo)
+    else:
+        params = conninfo.copy()
+
+    def removeif(key, value):
+        if params.get(key) == value:
+            params.pop(key)
+
+    removeif("connect_timeout", str(DEFAULT_TIMEOUT))
+
+    return params
index 05d79f9990db44352e75afb8ca0053e922532cef..6ad37840eaaf390b1f29241d368d07b06d1670bd 100644 (file)
@@ -1,6 +1,6 @@
-import sys
 import asyncio
 import selectors
+import sys
 from typing import Any, Dict, List
 
 import pytest
@@ -13,6 +13,7 @@ pytest_plugins = (
     "tests.fix_proxy",
     "tests.fix_psycopg",
     "tests.fix_crdb",
+    "tests.fix_gc",
     "tests.pool.fix_pool",
 )
 
@@ -25,6 +26,7 @@ def pytest_configure(config):
         # catch the exception for my life.
         "subprocess: the test import psycopg after subprocess",
         "timing: the test is timing based and can fail on cheese hardware",
+        "gevent: the test requires the gevent module to be installed",
         "dns: the test requires dnspython to run",
         "postgis: the test requires the PostGIS extension to run",
         "numpy: the test requires numpy module to be installed",
index 2bf714f1c3389ebb4525b52a56300ab8b87e8ac1..4100c33f605a6478d93111c9c822c23e43187849 100644 (file)
@@ -7,7 +7,7 @@ from psycopg.pq import Format
 from psycopg.adapt import PyFormat
 from psycopg.types.numeric import Int4
 
-from ..utils import eur, gc_collect, gc_count
+from ..utils import eur
 from .._test_copy import sample_text, sample_binary  # noqa
 from .._test_copy import ensure_table, sample_records
 from .._test_copy import sample_tabledef as sample_tabledef_pg
@@ -191,7 +191,7 @@ from copy_in group by 1, 2, 3
     [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)],
 )
 @pytest.mark.crdb_skip("copy array")
-def test_copy_from_leaks(conn_cls, dsn, faker, fmt, set_types):
+def test_copy_from_leaks(conn_cls, dsn, faker, fmt, set_types, gc):
     faker.format = PyFormat.from_pq(fmt)
     faker.choose_schema(ncols=20)
     faker.make_records(20)
@@ -219,12 +219,12 @@ def test_copy_from_leaks(conn_cls, dsn, faker, fmt, set_types):
                 for got, want in zip(recs, faker.records):
                     faker.assert_record(got, want)
 
-    gc_collect()
+    gc.collect()
     n = []
     for i in range(3):
         work()
-        gc_collect()
-        n.append(gc_count())
+        gc.collect()
+        n.append(gc.count())
 
     assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
 
index a994d9071fad16cdf7bc9325d97ad27783416256..17a37a95fa1517fdb3875df81a997818c0826568 100644 (file)
@@ -7,7 +7,7 @@ from psycopg import sql, errors as e
 from psycopg.adapt import PyFormat
 from psycopg.types.numeric import Int4
 
-from ..utils import eur, gc_collect, gc_count
+from ..utils import eur
 from .._test_copy import sample_text, sample_binary  # noqa
 from .._test_copy import ensure_table_async, sample_records
 from .test_copy import sample_tabledef, copyopt
@@ -196,7 +196,7 @@ from copy_in group by 1, 2, 3
     [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)],
 )
 @pytest.mark.crdb_skip("copy array")
-async def test_copy_from_leaks(aconn_cls, dsn, faker, fmt, set_types):
+async def test_copy_from_leaks(aconn_cls, dsn, faker, fmt, set_types, gc):
     faker.format = PyFormat.from_pq(fmt)
     faker.choose_schema(ncols=20)
     faker.make_records(20)
@@ -224,11 +224,11 @@ async def test_copy_from_leaks(aconn_cls, dsn, faker, fmt, set_types):
                 for got, want in zip(recs, faker.records):
                     faker.assert_record(got, want)
 
-    gc_collect()
+    gc.collect()
     n = []
     for i in range(3):
         await work()
-        gc_collect()
-        n.append(gc_count())
+        gc.collect()
+        n.append(gc.count())
 
     assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
index 890e4ed5a167c050cb0b3a864253d2a6f235fc01..9abeda79e9b5c13d46a4a4dfe2e3c03c4cd2e98a 100644 (file)
@@ -9,6 +9,7 @@ from typing import Optional
 import psycopg
 from psycopg import pq
 from psycopg import sql
+from psycopg.conninfo import conninfo_to_dict, make_conninfo
 from psycopg._compat import cache
 from psycopg.pq._debug import PGconnDebug
 
@@ -104,6 +105,23 @@ def dsn(session_dsn, request):
     return session_dsn
 
 
+@pytest.fixture
+def dsn_env(dsn):
+    """Return a dsn including the connection parameters set in PG* env vars.
+
+    Provide a working conninfo even in tests that modify the env vars.
+    """
+    args = conninfo_to_dict(dsn)
+    for opt in pq.Conninfo.get_defaults():
+        if not (opt.envvar and opt.envvar.decode() in os.environ):
+            continue
+        if opt.keyword.decode() in args:
+            continue
+        args[opt.keyword.decode()] = os.environ[opt.envvar.decode()]
+
+    return make_conninfo(**args)
+
+
 @pytest.fixture(scope="session")
 def tracefile(request):
     """Open and yield a file for libpq client/server communication traces if
diff --git a/tests/fix_gc.py b/tests/fix_gc.py
new file mode 100644 (file)
index 0000000..ead6c6b
--- /dev/null
@@ -0,0 +1,85 @@
+import gc
+import sys
+from typing import Tuple
+
+import pytest
+
+
+def pytest_collection_modifyitems(items):
+    for item in items:
+        if "gc" in item.fixturenames:
+            item.add_marker(pytest.mark.refcount)
+
+
+def pytest_configure(config):
+    config.addinivalue_line(
+        "markers",
+        "refcount: the test checks ref counts which is sometimes flaky",
+    )
+
+
+NO_COUNT_TYPES: Tuple[type, ...] = ()
+
+if sys.version_info[:2] == (3, 10):
+    # On my laptop there are occasional creations of a single one of these objects
+    # with empty content, which might be some Decimal caching.
+    # Keeping the guard as strict as possible, to be extended if other types
+    # or versions are necessary.
+    try:
+        from _contextvars import Context  # type: ignore
+    except ImportError:
+        pass
+    else:
+        NO_COUNT_TYPES += (Context,)
+
+
+class GCFixture:
+    __slots__ = ()
+
+    @staticmethod
+    def collect() -> None:
+        """
+        gc.collect(), but more insisting.
+        """
+        for i in range(3):
+            gc.collect()
+
+    @staticmethod
+    def count() -> int:
+        """
+        len(gc.get_objects()), with subtleties.
+        """
+
+        if not NO_COUNT_TYPES:
+            return len(gc.get_objects())
+
+        # Note: not using a list comprehension because it pollutes the objects list.
+        rv = 0
+        for obj in gc.get_objects():
+            if isinstance(obj, NO_COUNT_TYPES):
+                continue
+            rv += 1
+
+        return rv
+
+
+@pytest.fixture(name="gc")
+def fixture_gc():
+    """
+    Provides a consistent way to run garbage collection and count references.
+
+    **Note:** This will skip tests on PyPy.
+    """
+    if sys.implementation.name == "pypy":
+        pytest.skip(reason="depends on refcount semantics")
+    return GCFixture()
+
+
+@pytest.fixture
+def gc_collect():
+    """
+    Provides a consistent way to run garbage collection.
+
+    **Note:** This will *not* skip tests on PyPy.
+    """
+    return GCFixture.collect
index 6811a26c32a3337bf0d1364684be9c06c82ee3a9..917dfc9193e9c558778034979587cf175a188b7f 100644 (file)
@@ -53,7 +53,7 @@ def libpq():
         # Not available when testing the binary package
         libname = find_libpq_full_path()
         assert libname, "libpq libname not found"
-        return ctypes.pydll.LoadLibrary(libname)
+        return ctypes.cdll.LoadLibrary(libname)
     except Exception as e:
         if pq.__impl__ == "binary":
             pytest.skip(f"can't load libpq for testing: {e}")
index e50f5ec05f28b460cc7c9c349f055db763a64ad3..1d566b5e5dc178aefc104415e0e8932b90f27779 100644 (file)
@@ -60,7 +60,7 @@ class Proxy:
         # Get server params
         host = cdict.get("host") or os.environ.get("PGHOST")
         self.server_host = host if host and not host.startswith("/") else "localhost"
-        self.server_port = cdict.get("port", "5432")
+        self.server_port = cdict.get("port") or os.environ.get("PGPORT", "5432")
 
         # Get client params
         self.client_host = "localhost"
index 59b0659c91dad66b5a53416fa7b6dd1c28cd3224..e2aeb7cc39164dd5f239adca2766d902816d07fb 100644 (file)
@@ -358,7 +358,7 @@ def test_fail_rollback_close(dsn, caplog, monkeypatch):
     assert "BAD" in caplog.records[2].message
 
 
-def test_del_no_warning(dsn, recwarn):
+def test_del_no_warning(dsn, recwarn, gc_collect):
     p = pool.ConnectionPool(dsn, min_size=2, open=False)
     p.open()
     with p.connection() as conn:
@@ -367,6 +367,7 @@ def test_del_no_warning(dsn, recwarn):
     p.wait()
     ref = weakref.ref(p)
     del p
+    gc_collect()
     assert not ref()
     assert not recwarn, [str(w.message) for w in recwarn.list]
 
index ca9b29cef45674749ed950b9a4899008d33f3d7e..0bc84f8bc6428cc41f13590b121ef8c44ae7e7e2 100644 (file)
@@ -362,7 +362,7 @@ async def test_fail_rollback_close(dsn, caplog, monkeypatch):
     assert "BAD" in caplog.records[2].message
 
 
-async def test_del_no_warning(dsn, recwarn):
+async def test_del_no_warning(dsn, recwarn, gc_collect):
     p = pool.AsyncConnectionPool(dsn, min_size=2, open=False)
     await p.open()
     async with p.connection() as conn:
@@ -371,6 +371,7 @@ async def test_del_no_warning(dsn, recwarn):
     await p.wait()
     ref = weakref.ref(p)
     del p
+    gc_collect()
     assert not ref()
     assert not recwarn, [str(w.message) for w in recwarn.list]
 
index 3b96d9a34614c7ed1bc05021d5394111fb9dd8a6..b2ad2ffa2c81b9c9b4572af762a7151a2d124916 100644 (file)
@@ -6,8 +6,6 @@ import sys
 
 import pytest
 
-from ..utils import gc_collect
-
 try:
     import psycopg_pool as pool
 except ImportError:
@@ -63,7 +61,7 @@ def test_cant_create_open_outside_loop(dsn):
 
 
 @pytest.fixture
-def asyncio_run(recwarn):
+def asyncio_run(recwarn, gc_collect):
     """Fixture reuturning asyncio.run, but managing resources at exit.
 
     In certain runs, fd objects are leaked and the error will only be caught
index b37026715d482635234eacc6f1c572e563ba1886..ddf78a693638723d0a229380e92de76af7fd60d7 100644 (file)
@@ -347,11 +347,12 @@ def test_putconn_wrong_pool(pool_cls, dsn):
 
 @skip_async
 @pytest.mark.slow
-def test_del_stops_threads(pool_cls, dsn):
+def test_del_stops_threads(pool_cls, dsn, gc):
     p = pool_cls(dsn)
     assert p._sched_runner is not None
     ts = [p._sched_runner] + p._workers
     del p
+    gc.collect()
     sleep(0.1)
     for t in ts:
         assert not is_alive(t), t
index 567ad5b706a2c3401ffaf286cc3716530a7691bc..c49d7d6a1435220fc2f6178ca7ba97d3960a127b 100644 (file)
@@ -364,11 +364,12 @@ async def test_putconn_wrong_pool(pool_cls, dsn):
 
 @skip_async
 @pytest.mark.slow
-async def test_del_stops_threads(pool_cls, dsn):
+async def test_del_stops_threads(pool_cls, dsn, gc):
     p = pool_cls(dsn)
     assert p._sched_runner is not None
     ts = [p._sched_runner] + p._workers
     del p
+    gc.collect()
     await asleep(0.1)
     for t in ts:
         assert not is_alive(t), t
index 05661511ad0b14f16561f6d1bea07549fff0d83f..ff18379a44ed7cabe19ff44506241b796102baea 100644 (file)
@@ -11,8 +11,6 @@ import psycopg
 from psycopg import pq
 import psycopg.generators
 
-from ..utils import gc_collect
-
 
 def test_connectdb(dsn):
     conn = pq.PGconn.connect(dsn.encode())
@@ -82,7 +80,7 @@ def test_finish(pgconn):
 
 
 @pytest.mark.slow
-def test_weakref(dsn):
+def test_weakref(dsn, gc_collect):
     conn = pq.PGconn.connect(dsn.encode())
     w = weakref.ref(conn)
     conn.finish()
index 1ead1ba9de39c6032ad5e5f1dca911aace36fc06..9cfa7459f53084a7f6ff49501808364c7ab0ecfc 100644 (file)
@@ -11,9 +11,8 @@ from typing import Any, List
 import psycopg
 from psycopg import Notify, pq, errors as e
 from psycopg.rows import tuple_row
-from psycopg.conninfo import conninfo_to_dict, make_conninfo
+from psycopg.conninfo import conninfo_to_dict, timeout_from_conninfo
 
-from .utils import gc_collect
 from .acompat import is_async, skip_sync, skip_async
 from ._test_cursor import my_row_factory
 from ._test_connection import tx_params, tx_params_isolation, tx_values_map
@@ -49,9 +48,41 @@ def test_connect_str_subclass(conn_cls, dsn):
 def test_connect_timeout(conn_cls, deaf_port):
     t0 = time.time()
     with pytest.raises(psycopg.OperationalError, match="timeout expired"):
-        conn_cls.connect(host="localhost", port=deaf_port, connect_timeout=1)
+        conn_cls.connect(host="localhost", port=deaf_port, connect_timeout=2)
     elapsed = time.time() - t0
-    assert elapsed == pytest.approx(1.0, abs=0.05)
+    assert elapsed == pytest.approx(2.0, abs=0.05)
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_multi_hosts(conn_cls, proxy, dsn, deaf_port, monkeypatch):
+    args = conninfo_to_dict(dsn)
+    args["host"] = f"{proxy.client_host},{proxy.server_host}"
+    args["port"] = f"{deaf_port},{proxy.server_port}"
+    args.pop("hostaddr", None)
+    monkeypatch.setattr(psycopg.conninfo, "_DEFAULT_CONNECT_TIMEOUT", 2)
+    t0 = time.time()
+    with conn_cls.connect(**args) as conn:
+        elapsed = time.time() - t0
+        assert 2.0 < elapsed < 2.5
+        assert conn.info.port == int(proxy.server_port)
+        assert conn.info.host == proxy.server_host
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_multi_hosts_timeout(conn_cls, proxy, dsn, deaf_port):
+    args = conninfo_to_dict(dsn)
+    args["host"] = f"{proxy.client_host},{proxy.server_host}"
+    args["port"] = f"{deaf_port},{proxy.server_port}"
+    args.pop("hostaddr", None)
+    args["connect_timeout"] = "2"
+    t0 = time.time()
+    with conn_cls.connect(**args) as conn:
+        elapsed = time.time() - t0
+        assert 2.0 < elapsed < 2.5
+        assert conn.info.port == int(proxy.server_port)
+        assert conn.info.host == proxy.server_host
 
 
 def test_close(conn):
@@ -99,13 +130,14 @@ def test_cursor_closed(conn):
 # compiled with Cython-3.0.0b3, not before.
 
 
+@pytest.mark.slow
 @pytest.mark.xfail(
     pq.__impl__ in ("c", "binary")
     and sys.version_info[:2] == (3, 12)
     and (not is_async(__name__)),
     reason="Something with Exceptions, C, Python 3.12",
 )
-def test_connection_warn_close(conn_cls, dsn, recwarn):
+def test_connection_warn_close(conn_cls, dsn, recwarn, gc_collect):
     conn = conn_cls.connect(dsn)
     conn.close()
     del conn
@@ -113,11 +145,13 @@ def test_connection_warn_close(conn_cls, dsn, recwarn):
 
     conn = conn_cls.connect(dsn)
     del conn
+    gc_collect()
     assert "IDLE" in str(recwarn.pop(ResourceWarning).message)
 
     conn = conn_cls.connect(dsn)
     conn.execute("select 1")
     del conn
+    gc_collect()
     assert "INTRANS" in str(recwarn.pop(ResourceWarning).message)
 
     conn = conn_cls.connect(dsn)
@@ -210,7 +244,7 @@ def test_context_active_rollback_no_clobber(conn_cls, dsn, caplog):
 
 
 @pytest.mark.slow
-def test_weakref(conn_cls, dsn):
+def test_weakref(conn_cls, dsn, gc_collect):
     conn = conn_cls.connect(dsn)
     w = weakref.ref(conn)
     conn.close()
@@ -399,26 +433,26 @@ def test_autocommit_unknown(conn):
         (("dbname=foo user=bar",), {}, "dbname=foo user=bar"),
         (("dbname=foo",), {"user": "baz"}, "dbname=foo user=baz"),
         (
-            ("dbname=foo port=5432",),
+            ("dbname=foo port=5433",),
             {"dbname": "qux", "user": "joe"},
-            "dbname=qux user=joe port=5432",
+            "dbname=qux user=joe port=5433",
         ),
         (("dbname=foo",), {"user": None}, "dbname=foo"),
     ],
 )
 def test_connect_args(conn_cls, monkeypatch, setpgenv, pgconn, args, kwargs, want):
-    the_conninfo: str
+    got_conninfo: str
 
     def fake_connect(conninfo):
-        nonlocal the_conninfo
-        the_conninfo = conninfo
+        nonlocal got_conninfo
+        got_conninfo = conninfo
         return pgconn
         yield
 
     setpgenv({})
     monkeypatch.setattr(psycopg.generators, "connect", fake_connect)
     conn = conn_cls.connect(*args, **kwargs)
-    assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want)
+    assert conninfo_to_dict(got_conninfo) == conninfo_to_dict(want)
     conn.close()
 
 
@@ -789,9 +823,8 @@ def test_set_transaction_param_strange_property(conn):
 def test_get_connection_params(conn_cls, dsn, kwargs, exp, setpgenv):
     setpgenv({})
     params = conn_cls._get_connection_params(dsn, **kwargs)
-    conninfo = make_conninfo(**params)
-    assert conninfo_to_dict(conninfo) == exp[0]
-    assert params["connect_timeout"] == exp[1]
+    assert params == exp[0]
+    assert timeout_from_conninfo(params) == exp[1]
 
 
 def test_connect_context_adapters(conn_cls, dsn):
index 36a6c9f9feb354985f9f0857bebf3a7ca6355389..754e57b386e7bd31c3714417774c5b8895153f31 100644 (file)
@@ -8,9 +8,8 @@ from typing import Any, List
 import psycopg
 from psycopg import Notify, pq, errors as e
 from psycopg.rows import tuple_row
-from psycopg.conninfo import conninfo_to_dict, make_conninfo
+from psycopg.conninfo import conninfo_to_dict, timeout_from_conninfo
 
-from .utils import gc_collect
 from .acompat import is_async, skip_sync, skip_async
 from ._test_cursor import my_row_factory
 from ._test_connection import tx_params, tx_params_isolation, tx_values_map
@@ -46,9 +45,41 @@ async def test_connect_str_subclass(aconn_cls, dsn):
 async def test_connect_timeout(aconn_cls, deaf_port):
     t0 = time.time()
     with pytest.raises(psycopg.OperationalError, match="timeout expired"):
-        await aconn_cls.connect(host="localhost", port=deaf_port, connect_timeout=1)
+        await aconn_cls.connect(host="localhost", port=deaf_port, connect_timeout=2)
     elapsed = time.time() - t0
-    assert elapsed == pytest.approx(1.0, abs=0.05)
+    assert elapsed == pytest.approx(2.0, abs=0.05)
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_multi_hosts(aconn_cls, proxy, dsn, deaf_port, monkeypatch):
+    args = conninfo_to_dict(dsn)
+    args["host"] = f"{proxy.client_host},{proxy.server_host}"
+    args["port"] = f"{deaf_port},{proxy.server_port}"
+    args.pop("hostaddr", None)
+    monkeypatch.setattr(psycopg.conninfo, "_DEFAULT_CONNECT_TIMEOUT", 2)
+    t0 = time.time()
+    async with await aconn_cls.connect(**args) as conn:
+        elapsed = time.time() - t0
+        assert 2.0 < elapsed < 2.5
+        assert conn.info.port == int(proxy.server_port)
+        assert conn.info.host == proxy.server_host
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_multi_hosts_timeout(aconn_cls, proxy, dsn, deaf_port):
+    args = conninfo_to_dict(dsn)
+    args["host"] = f"{proxy.client_host},{proxy.server_host}"
+    args["port"] = f"{deaf_port},{proxy.server_port}"
+    args.pop("hostaddr", None)
+    args["connect_timeout"] = "2"
+    t0 = time.time()
+    async with await aconn_cls.connect(**args) as conn:
+        elapsed = time.time() - t0
+        assert 2.0 < elapsed < 2.5
+        assert conn.info.port == int(proxy.server_port)
+        assert conn.info.host == proxy.server_host
 
 
 async def test_close(aconn):
@@ -96,13 +127,14 @@ async def test_cursor_closed(aconn):
 
 # TODO: the INERROR started failing in the C implementation in Python 3.12a7
 # compiled with Cython-3.0.0b3, not before.
+@pytest.mark.slow
 @pytest.mark.xfail(
     pq.__impl__ in ("c", "binary")
     and sys.version_info[:2] == (3, 12)
     and not is_async(__name__),
     reason="Something with Exceptions, C, Python 3.12",
 )
-async def test_connection_warn_close(aconn_cls, dsn, recwarn):
+async def test_connection_warn_close(aconn_cls, dsn, recwarn, gc_collect):
     conn = await aconn_cls.connect(dsn)
     await conn.close()
     del conn
@@ -110,11 +142,13 @@ async def test_connection_warn_close(aconn_cls, dsn, recwarn):
 
     conn = await aconn_cls.connect(dsn)
     del conn
+    gc_collect()
     assert "IDLE" in str(recwarn.pop(ResourceWarning).message)
 
     conn = await aconn_cls.connect(dsn)
     await conn.execute("select 1")
     del conn
+    gc_collect()
     assert "INTRANS" in str(recwarn.pop(ResourceWarning).message)
 
     conn = await aconn_cls.connect(dsn)
@@ -208,7 +242,7 @@ async def test_context_active_rollback_no_clobber(aconn_cls, dsn, caplog):
 
 
 @pytest.mark.slow
-async def test_weakref(aconn_cls, dsn):
+async def test_weakref(aconn_cls, dsn, gc_collect):
     conn = await aconn_cls.connect(dsn)
     w = weakref.ref(conn)
     await conn.close()
@@ -397,9 +431,9 @@ async def test_autocommit_unknown(aconn):
         (("dbname=foo user=bar",), {}, "dbname=foo user=bar"),
         (("dbname=foo",), {"user": "baz"}, "dbname=foo user=baz"),
         (
-            ("dbname=foo port=5432",),
+            ("dbname=foo port=5433",),
             {"dbname": "qux", "user": "joe"},
-            "dbname=qux user=joe port=5432",
+            "dbname=qux user=joe port=5433",
         ),
         (("dbname=foo",), {"user": None}, "dbname=foo"),
     ],
@@ -407,18 +441,18 @@ async def test_autocommit_unknown(aconn):
 async def test_connect_args(
     aconn_cls, monkeypatch, setpgenv, pgconn, args, kwargs, want
 ):
-    the_conninfo: str
+    got_conninfo: str
 
     def fake_connect(conninfo):
-        nonlocal the_conninfo
-        the_conninfo = conninfo
+        nonlocal got_conninfo
+        got_conninfo = conninfo
         return pgconn
         yield
 
     setpgenv({})
     monkeypatch.setattr(psycopg.generators, "connect", fake_connect)
     conn = await aconn_cls.connect(*args, **kwargs)
-    assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want)
+    assert conninfo_to_dict(got_conninfo) == conninfo_to_dict(want)
     await conn.close()
 
 
@@ -797,9 +831,8 @@ def test_set_transaction_param_strange_property(conn):
 async def test_get_connection_params(aconn_cls, dsn, kwargs, exp, setpgenv):
     setpgenv({})
     params = await aconn_cls._get_connection_params(dsn, **kwargs)
-    conninfo = make_conninfo(**params)
-    assert conninfo_to_dict(conninfo) == exp[0]
-    assert params["connect_timeout"] == exp[1]
+    assert params == exp[0]
+    assert timeout_from_conninfo(params) == exp[1]
 
 
 async def test_connect_context_adapters(aconn_cls, dsn):
index 56a944ff17b754160fd3eb383f2ce53c29144a53..d9888d3bdd44fc11139aa1d0cac224b44efb71c9 100644 (file)
@@ -7,7 +7,8 @@ import pytest
 import psycopg
 from psycopg import ProgrammingError
 from psycopg.conninfo import make_conninfo, conninfo_to_dict, ConnectionInfo
-from psycopg.conninfo import resolve_hostaddr_async
+from psycopg.conninfo import conninfo_attempts, conninfo_attempts_async
+from psycopg.conninfo import timeout_from_conninfo, _DEFAULT_CONNECT_TIMEOUT
 from psycopg._encodings import pg2pyenc
 
 from .fix_crdb import crdb_encoding
@@ -319,43 +320,108 @@ class TestConnectionInfo:
 @pytest.mark.parametrize(
     "conninfo, want, env",
     [
-        ("", "", None),
-        ("host='' user=bar", "host='' user=bar", None),
+        ("", [""], None),
+        ("service=foo", ["service=foo"], None),
+        ("host='' user=bar", ["host='' user=bar"], None),
         (
             "host=127.0.0.1 user=bar",
-            "host=127.0.0.1 user=bar hostaddr=127.0.0.1",
+            ["host=127.0.0.1 user=bar"],
             None,
         ),
         (
             "host=1.1.1.1,2.2.2.2 user=bar",
-            "host=1.1.1.1,2.2.2.2 user=bar hostaddr=1.1.1.1,2.2.2.2",
+            ["host=1.1.1.1 user=bar", "host=2.2.2.2 user=bar"],
             None,
         ),
         (
             "host=1.1.1.1,2.2.2.2 port=5432",
-            "host=1.1.1.1,2.2.2.2 port=5432 hostaddr=1.1.1.1,2.2.2.2",
+            ["host=1.1.1.1 port=5432", "host=2.2.2.2 port=5432"],
+            None,
+        ),
+        (
+            "host=1.1.1.1,1.1.1.1 port=5432,",
+            ["host=1.1.1.1 port=5432", "host=1.1.1.1 port=''"],
+            None,
+        ),
+        (
+            "host=foo.com port=5432",
+            ["host=foo.com port=5432"],
+            {"PGHOSTADDR": "1.2.3.4"},
+        ),
+    ],
+)
+@pytest.mark.anyio
+def test_conninfo_attempts(setpgenv, conninfo, want, env):
+    setpgenv(env)
+    params = conninfo_to_dict(conninfo)
+    attempts = conninfo_attempts(params)
+    want = list(map(conninfo_to_dict, want))
+    assert want == attempts
+
+
+@pytest.mark.parametrize(
+    "conninfo, want, env",
+    [
+        ("", [""], None),
+        ("host='' user=bar", ["host='' user=bar"], None),
+        (
+            "host=127.0.0.1 user=bar port=''",
+            ["host=127.0.0.1 user=bar port='' hostaddr=127.0.0.1"],
+            None,
+        ),
+        (
+            "host=127.0.0.1 user=bar",
+            ["host=127.0.0.1 user=bar hostaddr=127.0.0.1"],
+            None,
+        ),
+        (
+            "host=1.1.1.1,2.2.2.2 user=bar",
+            [
+                "host=1.1.1.1 user=bar hostaddr=1.1.1.1",
+                "host=2.2.2.2 user=bar hostaddr=2.2.2.2",
+            ],
+            None,
+        ),
+        (
+            "host=1.1.1.1,2.2.2.2 port=5432",
+            [
+                "host=1.1.1.1 port=5432 hostaddr=1.1.1.1",
+                "host=2.2.2.2 port=5432 hostaddr=2.2.2.2",
+            ],
+            None,
+        ),
+        (
+            "host=1.1.1.1,2.2.2.2 port=5432,",
+            [
+                "host=1.1.1.1 port=5432 hostaddr=1.1.1.1",
+                "host=2.2.2.2 port='' hostaddr=2.2.2.2",
+            ],
             None,
         ),
         (
             "port=5432",
-            "host=1.1.1.1,2.2.2.2 port=5432 hostaddr=1.1.1.1,2.2.2.2",
+            [
+                "host=1.1.1.1 port=5432 hostaddr=1.1.1.1",
+                "host=2.2.2.2 port=5432 hostaddr=2.2.2.2",
+            ],
             {"PGHOST": "1.1.1.1,2.2.2.2"},
         ),
         (
             "host=foo.com port=5432",
-            "host=foo.com port=5432",
+            ["host=foo.com port=5432"],
             {"PGHOSTADDR": "1.2.3.4"},
         ),
     ],
 )
 @pytest.mark.anyio
-async def test_resolve_hostaddr_async_no_resolve(
+async def test_conninfo_attempts_async_no_resolve(
     setpgenv, conninfo, want, env, fail_resolve
 ):
     setpgenv(env)
     params = conninfo_to_dict(conninfo)
-    params = await resolve_hostaddr_async(params)
-    assert conninfo_to_dict(want) == params
+    attempts = await conninfo_attempts_async(params)
+    want = list(map(conninfo_to_dict, want))
+    assert want == attempts
 
 
 @pytest.mark.parametrize(
@@ -363,46 +429,66 @@ async def test_resolve_hostaddr_async_no_resolve(
     [
         (
             "host=foo.com,qux.com",
-            "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2",
+            ["host=foo.com hostaddr=1.1.1.1", "host=qux.com hostaddr=2.2.2.2"],
             None,
         ),
         (
             "host=foo.com,qux.com port=5433",
-            "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5433",
+            [
+                "host=foo.com hostaddr=1.1.1.1 port=5433",
+                "host=qux.com hostaddr=2.2.2.2 port=5433",
+            ],
             None,
         ),
         (
             "host=foo.com,qux.com port=5432,5433",
-            "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5432,5433",
+            [
+                "host=foo.com hostaddr=1.1.1.1 port=5432",
+                "host=qux.com hostaddr=2.2.2.2 port=5433",
+            ],
+            None,
+        ),
+        (
+            "host=foo.com,foo.com port=5432,",
+            [
+                "host=foo.com hostaddr=1.1.1.1 port=5432",
+                "host=foo.com hostaddr=1.1.1.1 port=''",
+            ],
             None,
         ),
         (
             "host=foo.com,nosuchhost.com",
-            "host=foo.com hostaddr=1.1.1.1",
+            ["host=foo.com hostaddr=1.1.1.1"],
             None,
         ),
         (
             "host=foo.com, port=5432,5433",
-            "host=foo.com, hostaddr=1.1.1.1, port=5432,5433",
+            ["host=foo.com hostaddr=1.1.1.1 port=5432", "host='' port=5433"],
             None,
         ),
         (
             "host=nosuchhost.com,foo.com",
-            "host=foo.com hostaddr=1.1.1.1",
+            ["host=foo.com hostaddr=1.1.1.1"],
             None,
         ),
         (
             "host=foo.com,qux.com",
-            "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2",
+            ["host=foo.com hostaddr=1.1.1.1", "host=qux.com hostaddr=2.2.2.2"],
             {},
         ),
+        (
+            "host=dup.com",
+            ["host=dup.com hostaddr=3.3.3.3", "host=dup.com hostaddr=3.3.3.4"],
+            None,
+        ),
     ],
 )
 @pytest.mark.anyio
-async def test_resolve_hostaddr_async(conninfo, want, env, fake_resolve):
+async def test_conninfo_attempts_async(conninfo, want, env, fake_resolve):
     params = conninfo_to_dict(conninfo)
-    params = await resolve_hostaddr_async(params)
-    assert conninfo_to_dict(want) == params
+    attempts = await conninfo_attempts_async(params)
+    want = list(map(conninfo_to_dict, want))
+    assert want == attempts
 
 
 @pytest.mark.parametrize(
@@ -415,29 +501,105 @@ async def test_resolve_hostaddr_async(conninfo, want, env, fake_resolve):
     ],
 )
 @pytest.mark.anyio
-async def test_resolve_hostaddr_async_bad(setpgenv, conninfo, env, fake_resolve):
+async def test_conninfo_attempts_async_bad(setpgenv, conninfo, env, fake_resolve):
     setpgenv(env)
     params = conninfo_to_dict(conninfo)
     with pytest.raises(psycopg.Error):
-        await resolve_hostaddr_async(params)
+        await conninfo_attempts_async(params)
+
+
+@pytest.mark.parametrize(
+    "conninfo, env",
+    [
+        ("host=foo.com port=1,2", None),
+        ("host=1.1.1.1,2.2.2.2 port=5432,5433,5434", None),
+        ("host=1.1.1.1,2.2.2.2", {"PGPORT": "1,2,3"}),
+    ],
+)
+@pytest.mark.anyio
+def test_conninfo_attempts_bad(setpgenv, conninfo, env):
+    setpgenv(env)
+    params = conninfo_to_dict(conninfo)
+    with pytest.raises(psycopg.Error):
+        conninfo_attempts(params)
+
+
+def test_conninfo_random():
+    hosts = [f"host{n:02d}" for n in range(50)]
+    args = {"host": ",".join(hosts)}
+    ahosts = [att["host"] for att in conninfo_attempts(args)]
+    assert ahosts == hosts
+
+    args["load_balance_hosts"] = "disable"
+    ahosts = [att["host"] for att in conninfo_attempts(args)]
+    assert ahosts == hosts
+
+    args["load_balance_hosts"] = "random"
+    ahosts = [att["host"] for att in conninfo_attempts(args)]
+    assert ahosts != hosts
+    ahosts.sort()
+    assert ahosts == hosts
+
+
+@pytest.mark.anyio
+async def test_conninfo_random_async(fake_resolve):
+    args = {"host": "alot.com"}
+    hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)]
+    assert len(hostaddrs) == 20
+    assert hostaddrs == sorted(hostaddrs)
+
+    args["load_balance_hosts"] = "disable"
+    hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)]
+    assert hostaddrs == sorted(hostaddrs)
+
+    args["load_balance_hosts"] = "random"
+    hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)]
+    assert hostaddrs != sorted(hostaddrs)
+
+
+@pytest.mark.parametrize(
+    "conninfo, want, env",
+    [
+        ("", _DEFAULT_CONNECT_TIMEOUT, None),
+        ("host=foo", _DEFAULT_CONNECT_TIMEOUT, None),
+        ("connect_timeout=-1", _DEFAULT_CONNECT_TIMEOUT, None),
+        ("connect_timeout=0", _DEFAULT_CONNECT_TIMEOUT, None),
+        ("connect_timeout=1", 2, None),
+        ("connect_timeout=10", 10, None),
+        ("", 15, {"PGCONNECT_TIMEOUT": "15"}),
+    ],
+)
+def test_timeout(setpgenv, conninfo, want, env):
+    setpgenv(env)
+    params = conninfo_to_dict(conninfo)
+    timeout = timeout_from_conninfo(params)
+    assert timeout == want
 
 
 @pytest.fixture
 async def fake_resolve(monkeypatch):
     fake_hosts = {
-        "localhost": "127.0.0.1",
-        "foo.com": "1.1.1.1",
-        "qux.com": "2.2.2.2",
+        "localhost": ["127.0.0.1"],
+        "foo.com": ["1.1.1.1"],
+        "qux.com": ["2.2.2.2"],
+        "dup.com": ["3.3.3.3", "3.3.3.4"],
+        "alot.com": [f"4.4.4.{n}" for n in range(10, 30)],
     }
 
+    def family(host):
+        return socket.AF_INET6 if ":" in host else socket.AF_INET
+
     async def fake_getaddrinfo(host, port, **kwargs):
         assert isinstance(port, int) or (isinstance(port, str) and port.isdigit())
         try:
-            addr = fake_hosts[host]
+            addrs = fake_hosts[host]
         except KeyError:
             raise OSError(f"unknown test host: {host}")
         else:
-            return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", (addr, 432))]
+            return [
+                (family(addr), socket.SOCK_STREAM, 6, "", (addr, port))
+                for addr in addrs
+            ]
 
     monkeypatch.setattr(asyncio.get_running_loop(), "getaddrinfo", fake_getaddrinfo)
 
index 4b3e182e65dd524f308d44af310225bbc933db3f..fda854e60bea1adeee1eea27bb3e7b6f0888f5ef 100644 (file)
@@ -20,7 +20,7 @@ from psycopg.types import TypeInfo
 from psycopg.types.hstore import register_hstore
 from psycopg.types.numeric import Int4
 
-from .utils import eur, gc_collect, gc_count
+from .utils import eur
 from ._test_copy import sample_text, sample_binary, sample_binary_rows  # noqa
 from ._test_copy import sample_values, sample_records, sample_tabledef
 from ._test_copy import ensure_table, py_to_raw, special_chars, FileWriter
@@ -677,7 +677,7 @@ def test_connection_writer(conn, format, buffer):
     "fmt, set_types", [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)]
 )
 @pytest.mark.parametrize("method", ["read", "iter", "row", "rows"])
-def test_copy_to_leaks(conn_cls, dsn, faker, fmt, set_types, method):
+def test_copy_to_leaks(conn_cls, dsn, faker, fmt, set_types, method, gc):
     faker.format = PyFormat.from_pq(fmt)
     faker.choose_schema(ncols=20)
     faker.make_records(20)
@@ -717,12 +717,12 @@ def test_copy_to_leaks(conn_cls, dsn, faker, fmt, set_types, method):
                     elif method == "rows":
                         list(copy.rows())
 
-    gc_collect()
+    gc.collect()
     n = []
     for i in range(3):
         work()
-        gc_collect()
-        n.append(gc_count())
+        gc.collect()
+        n.append(gc.count())
 
     assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
 
@@ -731,7 +731,7 @@ def test_copy_to_leaks(conn_cls, dsn, faker, fmt, set_types, method):
 @pytest.mark.parametrize(
     "fmt, set_types", [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)]
 )
-def test_copy_from_leaks(conn_cls, dsn, faker, fmt, set_types):
+def test_copy_from_leaks(conn_cls, dsn, faker, fmt, set_types, gc):
     faker.format = PyFormat.from_pq(fmt)
     faker.choose_schema(ncols=20)
     faker.make_records(20)
@@ -759,12 +759,12 @@ def test_copy_from_leaks(conn_cls, dsn, faker, fmt, set_types):
                 for got, want in zip(recs, faker.records):
                     faker.assert_record(got, want)
 
-    gc_collect()
+    gc.collect()
     n = []
     for i in range(3):
         work()
-        gc_collect()
-        n.append(gc_count())
+        gc.collect()
+        n.append(gc.count())
 
     assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
 
index 68c21d8ef87adf3193011f1b9f70f66255e5fbb6..f52cc614431dbe5cc93b3f39dfd2ca8dbe568fce 100644 (file)
@@ -17,7 +17,7 @@ from psycopg.types import TypeInfo
 from psycopg.types.hstore import register_hstore
 from psycopg.types.numeric import Int4
 
-from .utils import eur, gc_collect, gc_count
+from .utils import eur
 from .acompat import alist
 from ._test_copy import sample_text, sample_binary, sample_binary_rows  # noqa
 from ._test_copy import sample_values, sample_records, sample_tabledef
@@ -695,7 +695,7 @@ async def test_connection_writer(aconn, format, buffer):
     [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)],
 )
 @pytest.mark.parametrize("method", ["read", "iter", "row", "rows"])
-async def test_copy_to_leaks(aconn_cls, dsn, faker, fmt, set_types, method):
+async def test_copy_to_leaks(aconn_cls, dsn, faker, fmt, set_types, method, gc):
     faker.format = PyFormat.from_pq(fmt)
     faker.choose_schema(ncols=20)
     faker.make_records(20)
@@ -735,12 +735,12 @@ async def test_copy_to_leaks(aconn_cls, dsn, faker, fmt, set_types, method):
                     elif method == "rows":
                         await alist(copy.rows())
 
-    gc_collect()
+    gc.collect()
     n = []
     for i in range(3):
         await work()
-        gc_collect()
-        n.append(gc_count())
+        gc.collect()
+        n.append(gc.count())
 
     assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
 
@@ -750,7 +750,7 @@ async def test_copy_to_leaks(aconn_cls, dsn, faker, fmt, set_types, method):
     "fmt, set_types",
     [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)],
 )
-async def test_copy_from_leaks(aconn_cls, dsn, faker, fmt, set_types):
+async def test_copy_from_leaks(aconn_cls, dsn, faker, fmt, set_types, gc):
     faker.format = PyFormat.from_pq(fmt)
     faker.choose_schema(ncols=20)
     faker.make_records(20)
@@ -778,12 +778,12 @@ async def test_copy_from_leaks(aconn_cls, dsn, faker, fmt, set_types):
                 for got, want in zip(recs, faker.records):
                     faker.assert_record(got, want)
 
-    gc_collect()
+    gc.collect()
     n = []
     for i in range(3):
         await work()
-        gc_collect()
-        n.append(gc_count())
+        gc.collect()
+        n.append(gc.count())
 
     assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
 
index 86d2fd7ef1a9783a71e135e7e19b2000777a5c6a..28b96a7b4d270a7c328cac07acd78744fb170ac1 100644 (file)
@@ -10,8 +10,6 @@ import psycopg
 from psycopg import pq, rows, errors as e
 from psycopg.adapt import PyFormat
 
-from .utils import gc_collect, gc_count
-
 
 def test_default_cursor(conn):
     cur = conn.cursor()
@@ -69,7 +67,7 @@ def test_query_params_executemany(conn):
 @pytest.mark.parametrize("fmt_out", pq.Format)
 @pytest.mark.parametrize("fetch", ["one", "many", "all", "iter"])
 @pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
-def test_leak(conn_cls, dsn, faker, fmt, fmt_out, fetch, row_factory):
+def test_leak(conn_cls, dsn, faker, fmt, fmt_out, fetch, row_factory, gc):
     faker.format = fmt
     faker.choose_schema(ncols=5)
     faker.make_records(10)
@@ -101,10 +99,10 @@ def test_leak(conn_cls, dsn, faker, fmt, fmt_out, fetch, row_factory):
                         pass
 
     n = []
-    gc_collect()
+    gc.collect()
     for i in range(3):
         work()
-        gc_collect()
-        n.append(gc_count())
+        gc.collect()
+        n.append(gc.count())
 
     assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
index 1ebc827cda249f956ae047159d0687299d198a36..9c5c3ebdd96e364db3adb29f163a188358a64b0c 100644 (file)
@@ -7,8 +7,6 @@ import psycopg
 from psycopg import pq, rows, errors as e
 from psycopg.adapt import PyFormat
 
-from .utils import gc_collect, gc_count
-
 
 async def test_default_cursor(aconn):
     cur = aconn.cursor()
@@ -68,7 +66,7 @@ async def test_query_params_executemany(aconn):
 @pytest.mark.parametrize("fmt_out", pq.Format)
 @pytest.mark.parametrize("fetch", ["one", "many", "all", "iter"])
 @pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
-async def test_leak(aconn_cls, dsn, faker, fmt, fmt_out, fetch, row_factory):
+async def test_leak(aconn_cls, dsn, faker, fmt, fmt_out, fetch, row_factory, gc):
     faker.format = fmt
     faker.choose_schema(ncols=5)
     faker.make_records(10)
@@ -102,10 +100,10 @@ async def test_leak(aconn_cls, dsn, faker, fmt, fmt_out, fetch, row_factory):
                         pass
 
     n = []
-    gc_collect()
+    gc.collect()
     for i in range(3):
         await work()
-        gc_collect()
-        n.append(gc_count())
+        gc.collect()
+        n.append(gc.count())
 
     assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
index bac6567ca3a19d9e4ad80bf0c424d3e63abfa6a1..f9d9d97447860138805322c81d80ff24f7482714 100644 (file)
@@ -7,7 +7,6 @@ import pytest
 import psycopg
 from psycopg import rows
 
-from .utils import gc_collect, gc_count
 from .fix_crdb import crdb_encoding
 
 
@@ -80,7 +79,7 @@ def test_query_params_executemany(conn):
 @pytest.mark.slow
 @pytest.mark.parametrize("fetch", ["one", "many", "all", "iter"])
 @pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
-def test_leak(conn_cls, dsn, faker, fetch, row_factory):
+def test_leak(conn_cls, dsn, faker, fetch, row_factory, gc):
     faker.choose_schema(ncols=5)
     faker.make_records(10)
     row_factory = getattr(rows, row_factory)
@@ -111,11 +110,11 @@ def test_leak(conn_cls, dsn, faker, fetch, row_factory):
                         pass
 
     n = []
-    gc_collect()
+    gc.collect()
     for i in range(3):
         work()
-        gc_collect()
-        n.append(gc_count())
+        gc.collect()
+        n.append(gc.count())
 
     assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
 
index d1abbe84915cfd407c5eec9690af866b020e1f03..63fb5d5af934c2d0221d89bd48dcadf6dcb3298e 100644 (file)
@@ -4,7 +4,6 @@ import pytest
 import psycopg
 from psycopg import rows
 
-from .utils import gc_collect, gc_count
 from .fix_crdb import crdb_encoding
 
 
@@ -79,7 +78,7 @@ async def test_query_params_executemany(aconn):
 @pytest.mark.slow
 @pytest.mark.parametrize("fetch", ["one", "many", "all", "iter"])
 @pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
-async def test_leak(aconn_cls, dsn, faker, fetch, row_factory):
+async def test_leak(aconn_cls, dsn, faker, fetch, row_factory, gc):
     faker.choose_schema(ncols=5)
     faker.make_records(10)
     row_factory = getattr(rows, row_factory)
@@ -112,11 +111,11 @@ async def test_leak(aconn_cls, dsn, faker, fetch, row_factory):
                         pass
 
     n = []
-    gc_collect()
+    gc.collect()
     for i in range(3):
         await work()
-        gc_collect()
-        n.append(gc_count())
+        gc.collect()
+        n.append(gc.count())
 
     assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
 
index 535aa01535c2ff242ea43e12f666a12b214e54a0..159e67cb1e92de995418ffe24bb6622c4af0c42f 100644 (file)
@@ -17,7 +17,7 @@ from psycopg import sql, rows
 from psycopg.adapt import PyFormat
 from psycopg.types import TypeInfo
 
-from .utils import gc_collect, raiseif
+from .utils import raiseif
 from .acompat import closing
 from .fix_crdb import crdb_encoding
 from ._test_cursor import my_row_factory, ph
@@ -123,7 +123,7 @@ def test_context(conn):
 
 
 @pytest.mark.slow
-def test_weakref(conn):
+def test_weakref(conn, gc_collect):
     cur = conn.cursor()
     w = weakref.ref(cur)
     cur.close()
index 4268fdd70606442a479e815e719180e84adb9a74..840de65ca47c84c7e846fe60c73743993bf0b6ef 100644 (file)
@@ -14,7 +14,7 @@ from psycopg import sql, rows
 from psycopg.adapt import PyFormat
 from psycopg.types import TypeInfo
 
-from .utils import gc_collect, raiseif
+from .utils import raiseif
 from .acompat import aclosing, alist, anext
 from .fix_crdb import crdb_encoding
 from ._test_cursor import my_row_factory, ph
@@ -121,7 +121,7 @@ async def test_context(aconn):
 
 
 @pytest.mark.slow
-async def test_weakref(aconn):
+async def test_weakref(aconn, gc_collect):
     cur = aconn.cursor()
     w = weakref.ref(cur)
     await cur.close()
index 683aa6446b40a8725bcb301542737f04f7fe4165..059d22d1c02d4acdf1ca3fb887eed29c5912ec77 100644 (file)
@@ -7,7 +7,6 @@ from psycopg import pq, rows, errors as e
 from psycopg.adapt import PyFormat
 
 from ._test_cursor import ph
-from .utils import gc_collect, gc_count
 
 
 @pytest.fixture
@@ -75,7 +74,7 @@ def test_query_params_executemany(conn):
 @pytest.mark.parametrize("fmt_out", pq.Format)
 @pytest.mark.parametrize("fetch", ["one", "many", "all", "iter"])
 @pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
-def test_leak(conn_cls, dsn, faker, fmt, fmt_out, fetch, row_factory):
+def test_leak(conn_cls, dsn, faker, fmt, fmt_out, fetch, row_factory, gc):
     faker.format = fmt
     faker.choose_schema(ncols=5)
     faker.make_records(10)
@@ -108,9 +107,9 @@ def test_leak(conn_cls, dsn, faker, fmt, fmt_out, fetch, row_factory):
                             pass
 
     n = []
-    gc_collect()
+    gc.collect()
     for i in range(3):
         work()
-        gc_collect()
-        n.append(gc_count())
+        gc.collect()
+        n.append(gc.count())
     assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
index b207b28c83c53924e611a94379b61b430f846b09..e6e85124a87ac897fc560cfd6b7abdb378b2a8a6 100644 (file)
@@ -4,7 +4,6 @@ from psycopg import pq, rows, errors as e
 from psycopg.adapt import PyFormat
 
 from ._test_cursor import ph
-from .utils import gc_collect, gc_count
 
 
 @pytest.fixture
@@ -72,7 +71,7 @@ async def test_query_params_executemany(aconn):
 @pytest.mark.parametrize("fmt_out", pq.Format)
 @pytest.mark.parametrize("fetch", ["one", "many", "all", "iter"])
 @pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
-async def test_leak(aconn_cls, dsn, faker, fmt, fmt_out, fetch, row_factory):
+async def test_leak(aconn_cls, dsn, faker, fmt, fmt_out, fetch, row_factory, gc):
     faker.format = fmt
     faker.choose_schema(ncols=5)
     faker.make_records(10)
@@ -105,9 +104,9 @@ async def test_leak(aconn_cls, dsn, faker, fmt, fmt_out, fetch, row_factory):
                             pass
 
     n = []
-    gc_collect()
+    gc.collect()
     for i in range(3):
         await work()
-        gc_collect()
-        n.append(gc_count())
+        gc.collect()
+        n.append(gc.count())
     assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
index 03dc7af537d18e15938f2f21afad5c608bb9d409..b4259abb2f3bc127f9283e088817092514194189 100644 (file)
@@ -255,11 +255,12 @@ def test_close_no_clobber(conn):
             cur.fetchall()
 
 
-def test_warn_close(conn, recwarn):
+def test_warn_close(conn, recwarn, gc_collect):
     recwarn.clear()
     cur = conn.cursor("foo")
     cur.execute("select generate_series(1, 10) as bar")
     del cur
+    gc_collect()
     assert ".close()" in str(recwarn.pop(ResourceWarning).message)
 
 
index 7317c52d8bdc78731fc1ed8478f17848b03f8569..0f0efa2a96c69797e83bd6c97d69f3e5e192c763 100644 (file)
@@ -6,6 +6,7 @@ from psycopg.pq import Format
 
 from .acompat import alist
 
+
 pytestmark = pytest.mark.crdb_skip("server-side cursor")
 
 
@@ -261,11 +262,12 @@ async def test_close_no_clobber(aconn):
             await cur.fetchall()
 
 
-async def test_warn_close(aconn, recwarn):
+async def test_warn_close(aconn, recwarn, gc_collect):
     recwarn.clear()
     cur = aconn.cursor("foo")
     await cur.execute("select generate_series(1, 10) as bar")
     del cur
+    gc_collect()
     assert ".close()" in str(recwarn.pop(ResourceWarning).message)
 
 
index b1e8891155f523f3730d1078d2808c3e6a495544..a83aaeb6694165ebd0c257c650ad7361063185f0 100644 (file)
@@ -33,7 +33,6 @@ async def test_resolve_hostaddr_async_warning(recwarn):
     params = await psycopg._dns.resolve_hostaddr_async(  # type: ignore[attr-defined]
         params
     )
-    assert conninfo_to_dict(conninfo) == params
     assert "resolve_hostaddr_async" in str(recwarn.pop(DeprecationWarning).message)
 
 
index ddf57513ce78218f2731708ffcfe6f82957e1d78..a5016ae32dc95da28d5bd208f671d1e5378b4c71 100644 (file)
@@ -9,7 +9,7 @@ import psycopg
 from psycopg import pq
 from psycopg import errors as e
 
-from .utils import eur, gc_collect
+from .utils import eur
 from .fix_crdb import is_crdb
 
 
@@ -187,7 +187,7 @@ def test_diag_pickle(conn):
     (pq.__impl__ in ("c", "binary") and sys.version_info[:2] == (3, 12)),
     reason="Something with Exceptions, C, Python 3.12",
 )
-def test_diag_survives_cursor(conn):
+def test_diag_survives_cursor(conn, gc_collect):
     cur = conn.cursor()
     with pytest.raises(e.Error) as exc:
         cur.execute("select * from nosuchtable")
diff --git a/tests/test_gevent.py b/tests/test_gevent.py
new file mode 100644 (file)
index 0000000..befbeb2
--- /dev/null
@@ -0,0 +1,84 @@
+import sys
+import json
+import subprocess as sp
+
+import pytest
+import psycopg
+
+pytest.importorskip("gevent")
+
+pytestmark = [pytest.mark.gevent]
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_gevent(dsn):
+    TICK = 0.1
+    script = f"""\
+import gevent.monkey
+gevent.monkey.patch_all()
+
+import json
+import time
+import gevent
+import psycopg
+
+TICK = {TICK!r}
+dts = []
+queried = False
+
+def ticker():
+    t0 = time.time()
+    for i in range(5):
+        time.sleep(TICK)
+        t = time.time()
+        dts.append(t - t0)
+        t0 = t
+
+def querier():
+    time.sleep(TICK * 2)
+    with psycopg.connect({dsn!r}) as conn:
+        conn.execute("select pg_sleep(0.3)")
+
+    global queried
+    queried = True
+
+jobs = [gevent.spawn(ticker), gevent.spawn(querier)]
+gevent.joinall(jobs, timeout=3)
+print(json.dumps(dts))
+"""
+    cmdline = [sys.executable, "-c", script]
+    rv = sp.run(cmdline, check=True, text=True, stdout=sp.PIPE)
+    dts = json.loads(rv.stdout)
+
+    for dt in dts:
+        assert TICK <= dt < TICK * 1.1
+
+
+@pytest.mark.skipif("not psycopg._cmodule._psycopg")
+def test_patched_dont_use_wait_c():
+    if psycopg.waiting.wait is not psycopg.waiting.wait_c:
+        pytest.skip("wait_c not normally in use")
+
+    script = """
+import gevent.monkey
+gevent.monkey.patch_all()
+
+import psycopg
+assert psycopg.waiting.wait is not psycopg.waiting.wait_c
+"""
+    sp.check_call([sys.executable, "-c", script])
+
+
+@pytest.mark.skipif("not psycopg._cmodule._psycopg")
+def test_unpatched_still_use_wait_c():
+    if psycopg.waiting.wait is not psycopg.waiting.wait_c:
+        pytest.skip("wait_c not normally in use")
+
+    script = """
+import gevent.monkey
+
+import psycopg
+assert psycopg.waiting.wait is psycopg.waiting.wait_c
+"""
+    sp.check_call([sys.executable, "-c", script])
index c6b3e08e312cfd16950166af255fb62943f087db..9b144d7d6be31bc0edb32884db1c89e39b69737c 100644 (file)
@@ -1,10 +1,11 @@
 import pytest
 
 from psycopg._cmodule import _psycopg
+from psycopg.conninfo import conninfo_to_dict
 
 
 @pytest.mark.parametrize(
-    "args, kwargs, want_conninfo",
+    "args, kwargs, want",
     [
         ((), {}, ""),
         (("dbname=foo",), {"user": "bar"}, "dbname=foo user=bar"),
@@ -12,24 +13,25 @@ from psycopg._cmodule import _psycopg
         ((), {"user": "foo", "dbname": None}, "user=foo"),
     ],
 )
-def test_connect(monkeypatch, dsn, args, kwargs, want_conninfo):
+def test_connect(monkeypatch, dsn_env, args, kwargs, want, setpgenv):
     # Check the main args passing from psycopg.connect to the conn generator
     # Details of the params manipulation are in test_conninfo.
     import psycopg.connection
 
     orig_connect = psycopg.generators.connect
 
-    got_conninfo = None
+    got_conninfo: str
 
     def mock_connect(conninfo):
         nonlocal got_conninfo
         got_conninfo = conninfo
-        return orig_connect(dsn)
+        return orig_connect(dsn_env)
 
+    setpgenv({})
     monkeypatch.setattr(psycopg.generators, "connect", mock_connect)
 
     conn = psycopg.connect(*args, **kwargs)
-    assert got_conninfo == want_conninfo
+    assert conninfo_to_dict(got_conninfo) == conninfo_to_dict(want)
     conn.close()
 
 
index 69d4e8d8aa1004d0c50358fe905fd16895e6e027..2e429eac95c62a5d5ff93448651ffae5e75c8de2 100644 (file)
@@ -125,25 +125,26 @@ def test_time_from_ticks(ticks, want):
         (("host=foo user=bar",), {}, "host=foo user=bar"),
         (("host=foo",), {"user": "baz"}, "host=foo user=baz"),
         (
-            ("host=foo port=5432",),
+            ("host=foo port=5433",),
             {"host": "qux", "user": "joe"},
-            "host=qux user=joe port=5432",
+            "host=qux user=joe port=5433",
         ),
         (("host=foo",), {"user": None}, "host=foo"),
     ],
 )
-def test_connect_args(monkeypatch, pgconn, args, kwargs, want):
-    the_conninfo: str
+def test_connect_args(monkeypatch, pgconn, args, kwargs, want, setpgenv):
+    got_conninfo: str
 
     def fake_connect(conninfo):
-        nonlocal the_conninfo
-        the_conninfo = conninfo
+        nonlocal got_conninfo
+        got_conninfo = conninfo
         return pgconn
         yield
 
+    setpgenv({})
     monkeypatch.setattr(psycopg.generators, "connect", fake_connect)
     conn = psycopg.connect(*args, **kwargs)
-    assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want)
+    assert conninfo_to_dict(got_conninfo) == conninfo_to_dict(want)
     conn.close()
 
 
index 5f84c5a192828e8fc5850ad9830981bee860b2fa..4da6a21490f0edf31ccffde0a4ea7f1ac33d7adf 100644 (file)
@@ -13,8 +13,6 @@ from psycopg.types import TypeInfo
 from psycopg.postgres import types as builtins
 from psycopg.types.array import register_array
 
-from ..utils import gc_collect
-
 
 tests_str = [
     ([[[[[["a"]]]]]], "{{{{{{a}}}}}}"),
@@ -342,7 +340,7 @@ def test_all_chars_with_bounds(conn, fmt_out):
 
 
 @pytest.mark.slow
-def test_register_array_leak(conn):
+def test_register_array_leak(conn, gc_collect):
     info = TypeInfo.fetch(conn, "date")
     ntypes = []
     for i in range(2):
index 88a6bb9a03479963f2dfbce792ced4c1663ab9c7..8eef880bd33cd2a2f13280d22308782af22de7c8 100644 (file)
@@ -1,4 +1,3 @@
-import gc
 import re
 import sys
 import operator
@@ -156,46 +155,6 @@ class VersionCheck:
         return (ver_maj, ver_min, ver_fix)
 
 
-def gc_collect():
-    """
-    gc.collect(), but more insisting.
-    """
-    for i in range(3):
-        gc.collect()
-
-
-NO_COUNT_TYPES: Tuple[type, ...] = ()
-
-if sys.version_info[:2] == (3, 10):
-    # On my laptop there are occasional creations of a single one of these objects
-    # with empty content, which might be some Decimal caching.
-    # Keeping the guard as strict as possible, to be extended if other types
-    # or versions are necessary.
-    try:
-        from _contextvars import Context  # type: ignore
-    except ImportError:
-        pass
-    else:
-        NO_COUNT_TYPES += (Context,)
-
-
-def gc_count() -> int:
-    """
-    len(gc.get_objects()), with subtleties.
-    """
-    if not NO_COUNT_TYPES:
-        return len(gc.get_objects())
-
-    # Note: not using a list comprehension because it pollutes the objects list.
-    rv = 0
-    for obj in gc.get_objects():
-        if isinstance(obj, NO_COUNT_TYPES):
-            continue
-        rv += 1
-
-    return rv
-
-
 @contextmanager
 def raiseif(cond, *args, **kwargs):
     """
index 741c21861f216ca0935d9ba9d3af58390c3309d4..ece9d9bad2c901afd2ae0c10d74f706b6d310049 100755 (executable)
@@ -2,6 +2,11 @@
 """Convert async code in the project to sync code.
 
 Note: the version of Python used to run this script affects the output.
+
+Hint: in order to explore the AST of a module you can run:
+
+    python -m ast path/to/module.py
+
 """
 
 from __future__ import annotations
@@ -14,22 +19,10 @@ from copy import deepcopy
 from typing import Any, Literal
 from pathlib import Path
 from argparse import ArgumentParser, Namespace, RawDescriptionHelpFormatter
+from importlib.metadata import version
 
 import ast_comments as ast
 
-# ast_comment versions 1.1.0, 1.1.1 have an import:
-#
-#   from typing import Dict, List, Tuple, Union
-#
-# which shadows some of the types defined in ast.
-#
-# Reported in https://github.com/t3rn0/ast-comments/issues/22
-import ast as ast_orig
-
-ast.Dict = ast_orig.Dict
-ast.List = ast_orig.List
-ast.Tuple = ast_orig.Tuple
-
 # The version of Python officially used for the conversion.
 # Output may differ in other versions.
 # Should be consistent with the Python version used in lint.yml
@@ -79,7 +72,12 @@ def main() -> int:
             current_ver,
         )
         logger.warning(
-            " You might get spurious changes that will be rejected by the CI linter."
+            "You might get spurious changes that will be rejected by the CI linter."
+        )
+        logger.warning(
+            "(use %s {--docker | --podman} to run it with Python %s in a container)",
+            sys.argv[0],
+            PYVER,
         )
 
     outputs = []
@@ -136,7 +134,7 @@ def run_in_container(engine: Literal["docker", "podman"]) -> int:
     """
     Build an image and run the script in a container.
     """
-    tag = f"async-to-sync:{PYVER}"
+    tag = f"async-to-sync:{version('ast_comments')}-{PYVER}"
 
     # Check if the image we want is present.
     cmdline = [engine, "inspect", tag, "-f", "{{ .Id }}"]
@@ -182,10 +180,6 @@ def tree_to_str(tree: ast.AST, filepath: Path) -> str:
     return rv
 
 
-# Hint: in order to explore the AST of a module you can run:
-# python -m ast path/tp/module.py
-
-
 class AsyncToSync(ast.NodeTransformer):
     def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AST:
         new_node = ast.FunctionDef(**node.__dict__)
@@ -308,6 +302,7 @@ class RenameAsyncToSync(ast.NodeTransformer):
         "aspawn": "spawn",
         "asynccontextmanager": "contextmanager",
         "connection_async": "connection",
+        "conninfo_attempts_async": "conninfo_attempts",
         "current_task_name": "current_thread_name",
         "cursor_async": "cursor",
         "ensure_table_async": "ensure_table",
index a545d4a12cc33a3bc83cbabb2c51b5957ae398c1..df6aa589ca8847a15c1b72c962f1b6b69597d68c 100755 (executable)
@@ -9,7 +9,7 @@
 
 set -euo pipefail
 
-python_versions="3.8.18 3.9.18 3.10.13 3.11.6 3.12.0"
+python_versions="3.8.10 3.9.13 3.10.11 3.11.6 3.12.0"
 pg_version=16
 
 function log {