]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add cursor_factory and server_cursor_factory attributes
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 17 Jul 2021 00:14:47 +0000 (02:14 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 23 Jul 2021 14:38:39 +0000 (16:38 +0200)
docs/api/connections.rst
psycopg/psycopg/connection.py
tests/test_connection.py
tests/test_connection_async.py

index 6f5abdf5337aab7282df627e6e2775684182d5bb..c30f10350c668e782cc8a3055feab3f57cd3b659 100644 (file)
@@ -64,7 +64,6 @@ The `!Connection` class
     .. autoattribute:: closed
     .. autoattribute:: broken
 
-
     .. method:: cursor(*, binary: bool = False, row_factory: Optional[RowFactory] = None) -> Cursor
     .. method:: cursor(name: str, *, binary: bool = False, row_factory: Optional[RowFactory] = None) -> ServerCursor
         :noindex:
@@ -79,10 +78,25 @@ The `!Connection` class
                        loader. See :ref:`binary-data` for details.
         :param row_factory: If specified override the `row_factory` set on the
                             connection. See :ref:`row-factories` for details.
+        :return: A cursor of the class specified by `cursor_factory` (or
+                 `server_cursor_factory` if *name* is specified).
 
         .. note:: You can use :ref:`with conn.cursor(): ...<usage>`
             to close the cursor automatically when the block is exited.
 
+    .. autoattribute:: cursor_factory
+
+        The type, of factory function, returned by `cursor()` and `execute()`.
+
+        Default is `psycopg.Cursor`.
+
+    .. autoattribute:: server_cursor_factory
+
+        The type, of factory function, returned by `cursor()` when a name is
+        specified.
+
+        Default is `psycopg.ServerCursor`.
+
     .. automethod:: execute(query, params=None, prepare=None) -> Cursor
 
         :param query: The query to execute.
@@ -225,6 +239,14 @@ The `!AsyncConnection` class
         .. note:: You can use ``async with conn.cursor() as cur: ...`` to
             close the cursor automatically when the block is exited.
 
+    .. autoattribute:: cursor_factory
+
+        Default is `psycopg.AsyncCursor`.
+
+    .. autoattribute:: server_cursor_factory
+
+        Default is `psycopg.AsyncServerCursor`.
+
     .. automethod:: execute(query, params=None, prepare=None) -> AsyncCursor
     .. automethod:: commit
     .. automethod:: rollback
index 9f8f08988a10699097be8957eb8ff056df1f6ba4..6eb9d8991f4dba821c61c7d0d0c7fd09ceaf8804 100644 (file)
@@ -444,9 +444,14 @@ class Connection(BaseConnection[Row]):
 
     __module__ = "psycopg"
 
+    cursor_factory: Type[Cursor[Row]]
+    server_cursor_factory: Type[ServerCursor[Row]]
+
     def __init__(self, pgconn: "PGconn", row_factory: RowFactory[Row]):
         super().__init__(pgconn, row_factory)
         self.lock = threading.Lock()
+        self.cursor_factory = Cursor
+        self.server_cursor_factory = ServerCursor
 
     @overload
     @classmethod
@@ -566,9 +571,11 @@ class Connection(BaseConnection[Row]):
 
         cur: Union[Cursor[Any], ServerCursor[Any]]
         if name:
-            cur = ServerCursor(self, name=name, row_factory=row_factory)
+            cur = self.server_cursor_factory(
+                self, name=name, row_factory=row_factory
+            )
         else:
-            cur = Cursor(self, row_factory=row_factory)
+            cur = self.cursor_factory(self, row_factory=row_factory)
 
         if binary:
             cur.format = Format.BINARY
@@ -661,9 +668,14 @@ class AsyncConnection(BaseConnection[Row]):
 
     __module__ = "psycopg"
 
+    cursor_factory: Type[AsyncCursor[Row]]
+    server_cursor_factory: Type[AsyncServerCursor[Row]]
+
     def __init__(self, pgconn: "PGconn", row_factory: RowFactory[Row]):
         super().__init__(pgconn, row_factory)
         self.lock = asyncio.Lock()
+        self.cursor_factory = AsyncCursor
+        self.server_cursor_factory = AsyncServerCursor
 
     @overload
     @classmethod
@@ -781,9 +793,11 @@ class AsyncConnection(BaseConnection[Row]):
 
         cur: Union[AsyncCursor[Any], AsyncServerCursor[Any]]
         if name:
-            cur = AsyncServerCursor(self, name=name, row_factory=row_factory)
+            cur = self.server_cursor_factory(
+                self, name=name, row_factory=row_factory
+            )
         else:
-            cur = AsyncCursor(self, row_factory=row_factory)
+            cur = self.cursor_factory(self, row_factory=row_factory)
 
         if binary:
             cur.format = Format.BINARY
index 826dad4341c751e6d4e7f04fa11fc7ec781db5f9..3bf87b3dc41fe93f413019ab5acd184fee8036b8 100644 (file)
@@ -541,3 +541,28 @@ def test_fileno(conn):
     conn.close()
     with pytest.raises(psycopg.OperationalError):
         conn.fileno()
+
+
+def test_cursor_factory(conn):
+    assert conn.cursor_factory is psycopg.Cursor
+
+    class MyCursor(psycopg.Cursor):
+        pass
+
+    conn.cursor_factory = MyCursor
+    with conn.cursor() as cur:
+        assert isinstance(cur, MyCursor)
+
+    with conn.execute("select 1") as cur:
+        assert isinstance(cur, MyCursor)
+
+
+def test_server_cursor_factory(conn):
+    assert conn.server_cursor_factory is psycopg.ServerCursor
+
+    class MyServerCursor(psycopg.ServerCursor):
+        pass
+
+    conn.server_cursor_factory = MyServerCursor
+    with conn.cursor(name="n") as cur:
+        assert isinstance(cur, MyServerCursor)
index 845844d67678086f366e6acf1895094062943a41..eb95787833df5f4dc7cfd55e9c1cefaefca09c34 100644 (file)
@@ -559,3 +559,28 @@ async def test_fileno(aconn):
     await aconn.close()
     with pytest.raises(psycopg.OperationalError):
         aconn.fileno()
+
+
+async def test_cursor_factory(aconn):
+    assert aconn.cursor_factory is psycopg.AsyncCursor
+
+    class MyCursor(psycopg.AsyncCursor):
+        pass
+
+    aconn.cursor_factory = MyCursor
+    async with aconn.cursor() as cur:
+        assert isinstance(cur, MyCursor)
+
+    async with (await aconn.execute("select 1")) as cur:
+        assert isinstance(cur, MyCursor)
+
+
+async def test_server_cursor_factory(aconn):
+    assert aconn.server_cursor_factory is psycopg.AsyncServerCursor
+
+    class MyServerCursor(psycopg.AsyncServerCursor):
+        pass
+
+    aconn.server_cursor_factory = MyServerCursor
+    async with aconn.cursor(name="n") as cur:
+        assert isinstance(cur, MyServerCursor)