From: Daniele Varrazzo Date: Thu, 23 Sep 2021 15:23:39 +0000 (+0200) Subject: Allow to specify a custom adapt context on connect X-Git-Tag: 3.0~62 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=f71c6f678269bd957f5b1a6bf5fa90ffd12ea7db;p=thirdparty%2Fpsycopg.git Allow to specify a custom adapt context on connect Close #83 --- diff --git a/docs/advanced/adapt.rst b/docs/advanced/adapt.rst index 5d40e0c78..891c923d5 100644 --- a/docs/advanced/adapt.rst +++ b/docs/advanced/adapt.rst @@ -23,9 +23,13 @@ returned. - Every context object derived from another context inherits its adapters mapping: cursors created from a connection inherit the connection's - configuration. Connections obtain an adapters map from the global map + configuration. + + By default, connections obtain an adapters map from the global map exposed as `psycopg.adapters`: changing the content of this object will - affect every connection created afterwards. + affect every connection created afterwards. You may specify a different + template adapters map using the *context* parameter on + `~psycopg.Connection.connect()`. .. image:: ../pictures/adapt.svg :align: center diff --git a/docs/api/connections.rst b/docs/api/connections.rst index 9bb190100..8fff068d5 100644 --- a/docs/api/connections.rst +++ b/docs/api/connections.rst @@ -43,6 +43,14 @@ The `!Connection` class `~psycopg.rows.tuple_row()`). See :ref:`row-factories` for details. + More specialized use: + + :param context: A context to copy the initial adapters configuration + from. It might be an `~psycopg.adapt.AdaptersMap` with + customized loaders and dumpers, used as a template to + create several connections. See :ref:`adaptation` for + further details. + .. __: https://www.postgresql.org/docs/current/libpq-connect.html #LIBPQ-CONNSTRING diff --git a/psycopg/psycopg/connection.py b/psycopg/psycopg/connection.py index f1c879cc1..419f01939 100644 --- a/psycopg/psycopg/connection.py +++ b/psycopg/psycopg/connection.py @@ -16,15 +16,16 @@ from functools import partial from contextlib import contextmanager from . import pq -from . import adapt from . import errors as e from . import waiting from . import postgres from . import encodings from .pq import ConnStatus, ExecStatus, TransactionStatus, Format -from .abc import ConnectionType, Params, PQGen, PQGenConn, Query, RV +from .abc import AdaptContext, ConnectionType, Params, Query, RV +from .abc import PQGen, PQGenConn from .sql import Composable from .rows import Row, RowFactory, tuple_row, TupleRow +from .adapt import AdaptersMap from ._enums import IsolationLevel from .cursor import Cursor from ._cmodule import _psycopg @@ -104,7 +105,7 @@ class BaseConnection(Generic[Row]): def __init__(self, pgconn: "PGconn"): self.pgconn = pgconn self._autocommit = False - self._adapters = adapt.AdaptersMap(postgres.adapters) + self._adapters = AdaptersMap(postgres.adapters) self._notice_handlers: List[NoticeHandler] = [] self._notify_handlers: List[NotifyHandler] = [] @@ -286,7 +287,7 @@ class BaseConnection(Generic[Row]): return ConnectionInfo(self.pgconn) @property - def adapters(self) -> adapt.AdaptersMap: + def adapters(self) -> AdaptersMap: return self._adapters @property @@ -548,6 +549,7 @@ class Connection(BaseConnection[Row]): *, autocommit: bool = False, row_factory: RowFactory[Row], + context: Optional[AdaptContext] = None, **kwargs: Union[None, int, str], ) -> "Connection[Row]": ... @@ -559,6 +561,7 @@ class Connection(BaseConnection[Row]): conninfo: str = "", *, autocommit: bool = False, + context: Optional[AdaptContext] = None, **kwargs: Union[None, int, str], ) -> "Connection[TupleRow]": ... @@ -570,6 +573,7 @@ class Connection(BaseConnection[Row]): *, autocommit: bool = False, row_factory: Optional[RowFactory[Row]] = None, + context: Optional[AdaptContext] = None, **kwargs: Any, ) -> "Connection[Any]": """ @@ -584,6 +588,8 @@ class Connection(BaseConnection[Row]): ) if row_factory: rv.row_factory = row_factory + if context: + rv._adapters = AdaptersMap(context.adapters) return rv def __enter__(self) -> "Connection[Row]": diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py index 38c032ff1..44bd6d040 100644 --- a/psycopg/psycopg/connection_async.py +++ b/psycopg/psycopg/connection_async.py @@ -14,8 +14,9 @@ from typing import cast, overload, TYPE_CHECKING from . import errors as e from . import waiting from .pq import Format -from .abc import Params, PQGen, PQGenConn, Query, RV +from .abc import AdaptContext, Params, PQGen, PQGenConn, Query, RV from .rows import Row, AsyncRowFactory, tuple_row, TupleRow +from .adapt import AdaptersMap from ._enums import IsolationLevel from ._compat import asynccontextmanager from .conninfo import make_conninfo, conninfo_to_dict @@ -62,6 +63,7 @@ class AsyncConnection(BaseConnection[Row]): *, autocommit: bool = False, row_factory: AsyncRowFactory[Row], + context: Optional[AdaptContext] = None, **kwargs: Union[None, int, str], ) -> "AsyncConnection[Row]": ... @@ -73,6 +75,7 @@ class AsyncConnection(BaseConnection[Row]): conninfo: str = "", *, autocommit: bool = False, + context: Optional[AdaptContext] = None, **kwargs: Union[None, int, str], ) -> "AsyncConnection[TupleRow]": ... @@ -83,6 +86,7 @@ class AsyncConnection(BaseConnection[Row]): conninfo: str = "", *, autocommit: bool = False, + context: Optional[AdaptContext] = None, row_factory: Optional[AsyncRowFactory[Row]] = None, **kwargs: Any, ) -> "AsyncConnection[Any]": @@ -95,6 +99,8 @@ class AsyncConnection(BaseConnection[Row]): ) if row_factory: rv.row_factory = row_factory + if context: + rv._adapters = AdaptersMap(context.adapters) return rv async def __aenter__(self) -> "AsyncConnection[Row]": diff --git a/tests/test_connection.py b/tests/test_connection.py index be6917650..97b094cda 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -15,6 +15,7 @@ from psycopg.conninfo import conninfo_to_dict, make_conninfo from .utils import gc_collect from .test_cursor import my_row_factory +from .test_adapt import make_bin_dumper, make_dumper def test_connect(dsn): @@ -732,3 +733,28 @@ def test_get_connection_params(dsn, kwargs, exp): conninfo = make_conninfo(**params) assert conninfo_to_dict(conninfo) == exp[0] assert params.get("connect_timeout") == exp[1] + + +def test_connect_context(dsn): + ctx = psycopg.adapt.AdaptersMap(psycopg.adapters) + ctx.register_dumper(str, make_bin_dumper("b")) + ctx.register_dumper(str, make_dumper("t")) + + conn = psycopg.connect(dsn, context=ctx) + + cur = conn.execute("select %s", ["hello"]) + assert cur.fetchone()[0] == "hellot" + cur = conn.execute("select %b", ["hello"]) + assert cur.fetchone()[0] == "hellob" + + +def test_connect_context_copy(dsn, conn): + conn.adapters.register_dumper(str, make_bin_dumper("b")) + conn.adapters.register_dumper(str, make_dumper("t")) + + conn2 = psycopg.connect(dsn, context=conn) + + cur = conn2.execute("select %s", ["hello"]) + assert cur.fetchone()[0] == "hellot" + cur = conn2.execute("select %b", ["hello"]) + assert cur.fetchone()[0] == "hellob" diff --git a/tests/test_connection_async.py b/tests/test_connection_async.py index 7c71cb0b2..7423c1466 100644 --- a/tests/test_connection_async.py +++ b/tests/test_connection_async.py @@ -15,6 +15,7 @@ from psycopg.conninfo import conninfo_to_dict, make_conninfo from .utils import gc_collect from .test_cursor import my_row_factory from .test_connection import tx_params, tx_values_map, conninfo_params_timeout +from .test_adapt import make_bin_dumper, make_dumper pytestmark = pytest.mark.asyncio @@ -707,3 +708,28 @@ async def test_get_connection_params(dsn, kwargs, exp): conninfo = make_conninfo(**params) assert conninfo_to_dict(conninfo) == exp[0] assert params["connect_timeout"] == exp[1] + + +async def test_connect_context_adapters(dsn): + ctx = psycopg.adapt.AdaptersMap(psycopg.adapters) + ctx.register_dumper(str, make_bin_dumper("b")) + ctx.register_dumper(str, make_dumper("t")) + + conn = await psycopg.AsyncConnection.connect(dsn, context=ctx) + + cur = await conn.execute("select %s", ["hello"]) + assert (await cur.fetchone())[0] == "hellot" + cur = await conn.execute("select %b", ["hello"]) + assert (await cur.fetchone())[0] == "hellob" + + +async def test_connect_context_copy(dsn, aconn): + aconn.adapters.register_dumper(str, make_bin_dumper("b")) + aconn.adapters.register_dumper(str, make_dumper("t")) + + aconn2 = await psycopg.AsyncConnection.connect(dsn, context=aconn) + + cur = await aconn2.execute("select %s", ["hello"]) + assert (await cur.fetchone())[0] == "hellot" + cur = await aconn2.execute("select %b", ["hello"]) + assert (await cur.fetchone())[0] == "hellob"