]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
feat: add cursor_factory parameter to `connect()`.
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 12 May 2022 12:48:09 +0000 (14:48 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 12 May 2022 13:36:41 +0000 (15:36 +0200)
This makes easier to use ClientCursor and port more easily code running
on psycopg2 dong more DDL operations and caring less about performance.

docs/api/connections.rst
docs/news.rst
psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py
tests/test_connection.py
tests/test_connection_async.py

index 8679271c4f5ddf199c3c0181698ddbd5b01048cc..c3a91561eae2119d1c5252b21829541a0b3ed906 100644 (file)
@@ -32,26 +32,25 @@ The `!Connection` class
     .. automethod:: connect
 
         :param conninfo: The `connection string`__ (a ``postgresql://`` url or
-                         a list of ``key=value`` pairs) to specify where and
-                         how to connect.
+            a list of ``key=value`` pairs) to specify where and how to connect.
         :param kwargs: Further parameters specifying the connection string.
-                       They override the ones specified in ``conninfo``.
+            They override the ones specified in ``conninfo``.
         :param autocommit: If `!True` don't start transactions automatically.
-                           See :ref:`transactions` for details.
+            See :ref:`transactions` for details.
         :param row_factory: The row factory specifying what type of records
-                            to create fetching data (default:
-                            `~psycopg.rows.tuple_row()`). See
-                            :ref:`row-factories` for details.
-        :param prepare_threshold: Set the `prepare_threshold` attribute of the
-                                  connection.
+            to create fetching data (default: `~psycopg.rows.tuple_row()`). See
+            :ref:`row-factories` for details.
+        :param cursor_factory: Initial value for the `cursor_factory` attribute
+            of the connection (new in Psycopg 3.1).
+        :param prepare_threshold: Initial value for the `prepare_threshold`
+            attribute of the connection (new in Psycopg 3.1).
 
         More specialized use:
 
         :param context: A context to copy the initial adapters configuration
-                        from. It might be an `~psycopg.adapt.AdaptersMap` with
-                        customized loaders and dumpers, used as a template to
-                        create several connections. See :ref:`adaptation` for
-                        further details.
+            from. It might be an `~psycopg.adapt.AdaptersMap` with customized
+            loaders and dumpers, used as a template to create several connections.
+            See :ref:`adaptation` for further details.
 
         .. __: https://www.postgresql.org/docs/current/libpq-connect.html
             #LIBPQ-CONNSTRING
@@ -67,7 +66,7 @@ The `!Connection` class
             .. __: https://www.postgresql.org/docs/current/libpq-envars.html
 
         .. versionchanged:: 3.1
-            added ``prepare_threshold`` parameter.
+            added ``prepare_threshold`` and ``cursor_factory`` parameters.
 
     .. automethod:: close
 
index 83762a50ffa3014a1f18060e49a82736d8a3ad13..089c536a4cc4e359a563142e74cdff53772bdc76 100644 (file)
@@ -22,6 +22,7 @@ Psycopg 3.1 (unreleased)
   (:ticket:`#145`).
 - Add `pq.PGconn.trace()` and related trace functions (:ticket:`#167`).
 - Add ``prepare_threshold`` parameter to `Connection` init (:ticket:`#200`).
+- Add ``cursor_factory`` parameter to `Connection` init.
 - Add `Error.pgconn` and `Error.pgresult` attributes (:ticket:`#242`).
 - Add explicit type cast to values converted by `sql.Literal` (:ticket:`#205`).
 - Drop support for Python 3.6.
index ff4d1ea7e5c930139e9aa087b97adff1ed4e0213..e8d06de72e732a9d393d5e45c457b1d6e11a9a3a 100644 (file)
@@ -651,7 +651,7 @@ class Connection(BaseConnection[Row]):
         super().__init__(pgconn)
         self.row_factory = row_factory
         self.lock = threading.Lock()
-        self.cursor_factory = cast("Type[Cursor[Row]]", Cursor)
+        self.cursor_factory = Cursor
         self.server_cursor_factory = ServerCursor
 
     @overload
@@ -663,6 +663,7 @@ class Connection(BaseConnection[Row]):
         autocommit: bool = False,
         row_factory: RowFactory[Row],
         prepare_threshold: Optional[int] = 5,
+        cursor_factory: Optional[Type[Cursor[Row]]] = None,
         context: Optional[AdaptContext] = None,
         **kwargs: Union[None, int, str],
     ) -> "Connection[Row]":
@@ -676,6 +677,7 @@ class Connection(BaseConnection[Row]):
         *,
         autocommit: bool = False,
         prepare_threshold: Optional[int] = 5,
+        cursor_factory: Optional[Type[Cursor[Any]]] = None,
         context: Optional[AdaptContext] = None,
         **kwargs: Union[None, int, str],
     ) -> "Connection[TupleRow]":
@@ -689,6 +691,7 @@ class Connection(BaseConnection[Row]):
         autocommit: bool = False,
         prepare_threshold: Optional[int] = 5,
         row_factory: Optional[RowFactory[Row]] = None,
+        cursor_factory: Optional[Type[Cursor[Row]]] = None,
         context: Optional[AdaptContext] = None,
         **kwargs: Any,
     ) -> "Connection[Any]":
@@ -708,6 +711,8 @@ class Connection(BaseConnection[Row]):
 
         if row_factory:
             rv.row_factory = row_factory
+        if cursor_factory:
+            rv.cursor_factory = cursor_factory
         if context:
             rv._adapters = AdaptersMap(context.adapters)
         rv.prepare_threshold = prepare_threshold
index 2cbaa745d80a3fac4d44ef3f9dfb19f4ecf06b31..3411dc6de7cbaead37b0d480b173dfbbaa1e83e4 100644 (file)
@@ -56,7 +56,7 @@ class AsyncConnection(BaseConnection[Row]):
         super().__init__(pgconn)
         self.row_factory = row_factory
         self.lock = asyncio.Lock()
-        self.cursor_factory = cast("Type[AsyncCursor[Row]]", AsyncCursor)
+        self.cursor_factory = AsyncCursor
         self.server_cursor_factory = AsyncServerCursor
 
     @overload
@@ -68,6 +68,7 @@ class AsyncConnection(BaseConnection[Row]):
         autocommit: bool = False,
         prepare_threshold: Optional[int] = 5,
         row_factory: AsyncRowFactory[Row],
+        cursor_factory: Optional[Type[AsyncCursor[Row]]] = None,
         context: Optional[AdaptContext] = None,
         **kwargs: Union[None, int, str],
     ) -> "AsyncConnection[Row]":
@@ -81,6 +82,7 @@ class AsyncConnection(BaseConnection[Row]):
         *,
         autocommit: bool = False,
         prepare_threshold: Optional[int] = 5,
+        cursor_factory: Optional[Type[AsyncCursor[Any]]] = None,
         context: Optional[AdaptContext] = None,
         **kwargs: Union[None, int, str],
     ) -> "AsyncConnection[TupleRow]":
@@ -95,6 +97,7 @@ class AsyncConnection(BaseConnection[Row]):
         prepare_threshold: Optional[int] = 5,
         context: Optional[AdaptContext] = None,
         row_factory: Optional[AsyncRowFactory[Row]] = None,
+        cursor_factory: Optional[Type[AsyncCursor[Row]]] = None,
         **kwargs: Any,
     ) -> "AsyncConnection[Any]":
 
@@ -121,6 +124,8 @@ class AsyncConnection(BaseConnection[Row]):
 
         if row_factory:
             rv.row_factory = row_factory
+        if cursor_factory:
+            rv.cursor_factory = cursor_factory
         if context:
             rv._adapters = AdaptersMap(context.adapters)
         rv.prepare_threshold = prepare_threshold
index 663bd1c153b0818aaf1d480a9acb13d50a919352..bdd3117b7eaaa47eda8294f7b0007d54c5609640 100644 (file)
@@ -551,6 +551,16 @@ def test_cursor_factory(conn):
         assert isinstance(cur, MyCursor)
 
 
+def test_cursor_factory_connect(dsn):
+    class MyCursor(psycopg.Cursor[psycopg.rows.Row]):
+        pass
+
+    with psycopg.connect(dsn, cursor_factory=MyCursor) as conn:
+        assert conn.cursor_factory is MyCursor
+        cur = conn.cursor()
+        assert type(cur) is MyCursor
+
+
 def test_server_cursor_factory(conn):
     assert conn.server_cursor_factory is psycopg.ServerCursor
 
index 912b4c90935b44135de96b0e068eca838646b028..8d48ed69eaa0a2fd0d03ccc68321bd4b3b5ad2cd 100644 (file)
@@ -556,6 +556,18 @@ async def test_cursor_factory(aconn):
         assert isinstance(cur, MyCursor)
 
 
+async def test_cursor_factory_connect(dsn):
+    class MyCursor(psycopg.AsyncCursor[psycopg.rows.Row]):
+        pass
+
+    async with await psycopg.AsyncConnection.connect(
+        dsn, cursor_factory=MyCursor
+    ) as conn:
+        assert conn.cursor_factory is MyCursor
+        cur = conn.cursor()
+        assert type(cur) is MyCursor
+
+
 async def test_server_cursor_factory(aconn):
     assert aconn.server_cursor_factory is psycopg.AsyncServerCursor