]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Allow to specify a custom adapt context on connect
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 23 Sep 2021 15:23:39 +0000 (17:23 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 26 Sep 2021 17:45:29 +0000 (19:45 +0200)
Close #83

docs/advanced/adapt.rst
docs/api/connections.rst
psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py
tests/test_connection.py
tests/test_connection_async.py

index 5d40e0c78b65352db2f4298726e12f83fad3a54f..891c923d514cad1653e1b6594c9034e55ee31aff 100644 (file)
@@ -23,9 +23,13 @@ returned.
 
 - Every context object derived from another context inherits its adapters
   mapping: cursors created from a connection inherit the connection's
-  configuration. Connections obtain an adapters map from the global map
+  configuration.
+
+  By default, connections obtain an adapters map from the global map
   exposed as `psycopg.adapters`: changing the content of this object will
-  affect every connection created afterwards.
+  affect every connection created afterwards. You may specify a different
+  template adapters map using the *context* parameter on
+  `~psycopg.Connection.connect()`.
 
   .. image:: ../pictures/adapt.svg
      :align: center
index 9bb1901007065928fc720fd372dbb03889049fa0..8fff068d549e36c7e162227541140026fffceb21 100644 (file)
@@ -43,6 +43,14 @@ The `!Connection` class
                             `~psycopg.rows.tuple_row()`). See
                             :ref:`row-factories` for details.
 
+        More specialized use:
+
+        :param context: A context to copy the initial adapters configuration
+                        from. It might be an `~psycopg.adapt.AdaptersMap` with
+                        customized loaders and dumpers, used as a template to
+                        create several connections. See :ref:`adaptation` for
+                        further details.
+
         .. __: https://www.postgresql.org/docs/current/libpq-connect.html
             #LIBPQ-CONNSTRING
 
index f1c879cc198915dfac38eed5609f77d7f9682849..419f0193999ef06ef646892d38bc318b66aca03f 100644 (file)
@@ -16,15 +16,16 @@ from functools import partial
 from contextlib import contextmanager
 
 from . import pq
-from . import adapt
 from . import errors as e
 from . import waiting
 from . import postgres
 from . import encodings
 from .pq import ConnStatus, ExecStatus, TransactionStatus, Format
-from .abc import ConnectionType, Params, PQGen, PQGenConn, Query, RV
+from .abc import AdaptContext, ConnectionType, Params, Query, RV
+from .abc import PQGen, PQGenConn
 from .sql import Composable
 from .rows import Row, RowFactory, tuple_row, TupleRow
+from .adapt import AdaptersMap
 from ._enums import IsolationLevel
 from .cursor import Cursor
 from ._cmodule import _psycopg
@@ -104,7 +105,7 @@ class BaseConnection(Generic[Row]):
     def __init__(self, pgconn: "PGconn"):
         self.pgconn = pgconn
         self._autocommit = False
-        self._adapters = adapt.AdaptersMap(postgres.adapters)
+        self._adapters = AdaptersMap(postgres.adapters)
         self._notice_handlers: List[NoticeHandler] = []
         self._notify_handlers: List[NotifyHandler] = []
 
@@ -286,7 +287,7 @@ class BaseConnection(Generic[Row]):
         return ConnectionInfo(self.pgconn)
 
     @property
-    def adapters(self) -> adapt.AdaptersMap:
+    def adapters(self) -> AdaptersMap:
         return self._adapters
 
     @property
@@ -548,6 +549,7 @@ class Connection(BaseConnection[Row]):
         *,
         autocommit: bool = False,
         row_factory: RowFactory[Row],
+        context: Optional[AdaptContext] = None,
         **kwargs: Union[None, int, str],
     ) -> "Connection[Row]":
         ...
@@ -559,6 +561,7 @@ class Connection(BaseConnection[Row]):
         conninfo: str = "",
         *,
         autocommit: bool = False,
+        context: Optional[AdaptContext] = None,
         **kwargs: Union[None, int, str],
     ) -> "Connection[TupleRow]":
         ...
@@ -570,6 +573,7 @@ class Connection(BaseConnection[Row]):
         *,
         autocommit: bool = False,
         row_factory: Optional[RowFactory[Row]] = None,
+        context: Optional[AdaptContext] = None,
         **kwargs: Any,
     ) -> "Connection[Any]":
         """
@@ -584,6 +588,8 @@ class Connection(BaseConnection[Row]):
         )
         if row_factory:
             rv.row_factory = row_factory
+        if context:
+            rv._adapters = AdaptersMap(context.adapters)
         return rv
 
     def __enter__(self) -> "Connection[Row]":
index 38c032ff13374660257736641feb2cb890aab476..44bd6d040d69a5eea0c4a25f98b190b2da19fc72 100644 (file)
@@ -14,8 +14,9 @@ from typing import cast, overload, TYPE_CHECKING
 from . import errors as e
 from . import waiting
 from .pq import Format
-from .abc import Params, PQGen, PQGenConn, Query, RV
+from .abc import AdaptContext, Params, PQGen, PQGenConn, Query, RV
 from .rows import Row, AsyncRowFactory, tuple_row, TupleRow
+from .adapt import AdaptersMap
 from ._enums import IsolationLevel
 from ._compat import asynccontextmanager
 from .conninfo import make_conninfo, conninfo_to_dict
@@ -62,6 +63,7 @@ class AsyncConnection(BaseConnection[Row]):
         *,
         autocommit: bool = False,
         row_factory: AsyncRowFactory[Row],
+        context: Optional[AdaptContext] = None,
         **kwargs: Union[None, int, str],
     ) -> "AsyncConnection[Row]":
         ...
@@ -73,6 +75,7 @@ class AsyncConnection(BaseConnection[Row]):
         conninfo: str = "",
         *,
         autocommit: bool = False,
+        context: Optional[AdaptContext] = None,
         **kwargs: Union[None, int, str],
     ) -> "AsyncConnection[TupleRow]":
         ...
@@ -83,6 +86,7 @@ class AsyncConnection(BaseConnection[Row]):
         conninfo: str = "",
         *,
         autocommit: bool = False,
+        context: Optional[AdaptContext] = None,
         row_factory: Optional[AsyncRowFactory[Row]] = None,
         **kwargs: Any,
     ) -> "AsyncConnection[Any]":
@@ -95,6 +99,8 @@ class AsyncConnection(BaseConnection[Row]):
         )
         if row_factory:
             rv.row_factory = row_factory
+        if context:
+            rv._adapters = AdaptersMap(context.adapters)
         return rv
 
     async def __aenter__(self) -> "AsyncConnection[Row]":
index be6917650965cc6f0f01a74e2aa4ed62c11e5472..97b094cda40d51150427210442ab58319a88e095 100644 (file)
@@ -15,6 +15,7 @@ from psycopg.conninfo import conninfo_to_dict, make_conninfo
 
 from .utils import gc_collect
 from .test_cursor import my_row_factory
+from .test_adapt import make_bin_dumper, make_dumper
 
 
 def test_connect(dsn):
@@ -732,3 +733,28 @@ def test_get_connection_params(dsn, kwargs, exp):
     conninfo = make_conninfo(**params)
     assert conninfo_to_dict(conninfo) == exp[0]
     assert params.get("connect_timeout") == exp[1]
+
+
+def test_connect_context(dsn):
+    ctx = psycopg.adapt.AdaptersMap(psycopg.adapters)
+    ctx.register_dumper(str, make_bin_dumper("b"))
+    ctx.register_dumper(str, make_dumper("t"))
+
+    conn = psycopg.connect(dsn, context=ctx)
+
+    cur = conn.execute("select %s", ["hello"])
+    assert cur.fetchone()[0] == "hellot"
+    cur = conn.execute("select %b", ["hello"])
+    assert cur.fetchone()[0] == "hellob"
+
+
+def test_connect_context_copy(dsn, conn):
+    conn.adapters.register_dumper(str, make_bin_dumper("b"))
+    conn.adapters.register_dumper(str, make_dumper("t"))
+
+    conn2 = psycopg.connect(dsn, context=conn)
+
+    cur = conn2.execute("select %s", ["hello"])
+    assert cur.fetchone()[0] == "hellot"
+    cur = conn2.execute("select %b", ["hello"])
+    assert cur.fetchone()[0] == "hellob"
index 7c71cb0b26688418383a95bb8312b5b549281170..7423c1466b4807edc66174102fa85e050bd9df89 100644 (file)
@@ -15,6 +15,7 @@ from psycopg.conninfo import conninfo_to_dict, make_conninfo
 from .utils import gc_collect
 from .test_cursor import my_row_factory
 from .test_connection import tx_params, tx_values_map, conninfo_params_timeout
+from .test_adapt import make_bin_dumper, make_dumper
 
 pytestmark = pytest.mark.asyncio
 
@@ -707,3 +708,28 @@ async def test_get_connection_params(dsn, kwargs, exp):
     conninfo = make_conninfo(**params)
     assert conninfo_to_dict(conninfo) == exp[0]
     assert params["connect_timeout"] == exp[1]
+
+
+async def test_connect_context_adapters(dsn):
+    ctx = psycopg.adapt.AdaptersMap(psycopg.adapters)
+    ctx.register_dumper(str, make_bin_dumper("b"))
+    ctx.register_dumper(str, make_dumper("t"))
+
+    conn = await psycopg.AsyncConnection.connect(dsn, context=ctx)
+
+    cur = await conn.execute("select %s", ["hello"])
+    assert (await cur.fetchone())[0] == "hellot"
+    cur = await conn.execute("select %b", ["hello"])
+    assert (await cur.fetchone())[0] == "hellob"
+
+
+async def test_connect_context_copy(dsn, aconn):
+    aconn.adapters.register_dumper(str, make_bin_dumper("b"))
+    aconn.adapters.register_dumper(str, make_dumper("t"))
+
+    aconn2 = await psycopg.AsyncConnection.connect(dsn, context=aconn)
+
+    cur = await aconn2.execute("select %s", ["hello"])
+    assert (await cur.fetchone())[0] == "hellot"
+    cur = await aconn2.execute("select %b", ["hello"])
+    assert (await cur.fetchone())[0] == "hellob"