From: Daniele Varrazzo Date: Sat, 17 Jul 2021 00:14:47 +0000 (+0200) Subject: Add cursor_factory and server_cursor_factory attributes X-Git-Tag: 3.0.dev2~51 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=118bdef3be60a39c5dbca01a8bf37171740c36d2;p=thirdparty%2Fpsycopg.git Add cursor_factory and server_cursor_factory attributes --- diff --git a/docs/api/connections.rst b/docs/api/connections.rst index 6f5abdf53..c30f10350 100644 --- a/docs/api/connections.rst +++ b/docs/api/connections.rst @@ -64,7 +64,6 @@ The `!Connection` class .. autoattribute:: closed .. autoattribute:: broken - .. method:: cursor(*, binary: bool = False, row_factory: Optional[RowFactory] = None) -> Cursor .. method:: cursor(name: str, *, binary: bool = False, row_factory: Optional[RowFactory] = None) -> ServerCursor :noindex: @@ -79,10 +78,25 @@ The `!Connection` class loader. See :ref:`binary-data` for details. :param row_factory: If specified override the `row_factory` set on the connection. See :ref:`row-factories` for details. + :return: A cursor of the class specified by `cursor_factory` (or + `server_cursor_factory` if *name* is specified). .. note:: You can use :ref:`with conn.cursor(): ...` to close the cursor automatically when the block is exited. + .. autoattribute:: cursor_factory + + The type, of factory function, returned by `cursor()` and `execute()`. + + Default is `psycopg.Cursor`. + + .. autoattribute:: server_cursor_factory + + The type, of factory function, returned by `cursor()` when a name is + specified. + + Default is `psycopg.ServerCursor`. + .. automethod:: execute(query, params=None, prepare=None) -> Cursor :param query: The query to execute. @@ -225,6 +239,14 @@ The `!AsyncConnection` class .. note:: You can use ``async with conn.cursor() as cur: ...`` to close the cursor automatically when the block is exited. + .. autoattribute:: cursor_factory + + Default is `psycopg.AsyncCursor`. + + .. autoattribute:: server_cursor_factory + + Default is `psycopg.AsyncServerCursor`. + .. automethod:: execute(query, params=None, prepare=None) -> AsyncCursor .. automethod:: commit .. automethod:: rollback diff --git a/psycopg/psycopg/connection.py b/psycopg/psycopg/connection.py index 9f8f08988..6eb9d8991 100644 --- a/psycopg/psycopg/connection.py +++ b/psycopg/psycopg/connection.py @@ -444,9 +444,14 @@ class Connection(BaseConnection[Row]): __module__ = "psycopg" + cursor_factory: Type[Cursor[Row]] + server_cursor_factory: Type[ServerCursor[Row]] + def __init__(self, pgconn: "PGconn", row_factory: RowFactory[Row]): super().__init__(pgconn, row_factory) self.lock = threading.Lock() + self.cursor_factory = Cursor + self.server_cursor_factory = ServerCursor @overload @classmethod @@ -566,9 +571,11 @@ class Connection(BaseConnection[Row]): cur: Union[Cursor[Any], ServerCursor[Any]] if name: - cur = ServerCursor(self, name=name, row_factory=row_factory) + cur = self.server_cursor_factory( + self, name=name, row_factory=row_factory + ) else: - cur = Cursor(self, row_factory=row_factory) + cur = self.cursor_factory(self, row_factory=row_factory) if binary: cur.format = Format.BINARY @@ -661,9 +668,14 @@ class AsyncConnection(BaseConnection[Row]): __module__ = "psycopg" + cursor_factory: Type[AsyncCursor[Row]] + server_cursor_factory: Type[AsyncServerCursor[Row]] + def __init__(self, pgconn: "PGconn", row_factory: RowFactory[Row]): super().__init__(pgconn, row_factory) self.lock = asyncio.Lock() + self.cursor_factory = AsyncCursor + self.server_cursor_factory = AsyncServerCursor @overload @classmethod @@ -781,9 +793,11 @@ class AsyncConnection(BaseConnection[Row]): cur: Union[AsyncCursor[Any], AsyncServerCursor[Any]] if name: - cur = AsyncServerCursor(self, name=name, row_factory=row_factory) + cur = self.server_cursor_factory( + self, name=name, row_factory=row_factory + ) else: - cur = AsyncCursor(self, row_factory=row_factory) + cur = self.cursor_factory(self, row_factory=row_factory) if binary: cur.format = Format.BINARY diff --git a/tests/test_connection.py b/tests/test_connection.py index 826dad434..3bf87b3dc 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -541,3 +541,28 @@ def test_fileno(conn): conn.close() with pytest.raises(psycopg.OperationalError): conn.fileno() + + +def test_cursor_factory(conn): + assert conn.cursor_factory is psycopg.Cursor + + class MyCursor(psycopg.Cursor): + pass + + conn.cursor_factory = MyCursor + with conn.cursor() as cur: + assert isinstance(cur, MyCursor) + + with conn.execute("select 1") as cur: + assert isinstance(cur, MyCursor) + + +def test_server_cursor_factory(conn): + assert conn.server_cursor_factory is psycopg.ServerCursor + + class MyServerCursor(psycopg.ServerCursor): + pass + + conn.server_cursor_factory = MyServerCursor + with conn.cursor(name="n") as cur: + assert isinstance(cur, MyServerCursor) diff --git a/tests/test_connection_async.py b/tests/test_connection_async.py index 845844d67..eb9578783 100644 --- a/tests/test_connection_async.py +++ b/tests/test_connection_async.py @@ -559,3 +559,28 @@ async def test_fileno(aconn): await aconn.close() with pytest.raises(psycopg.OperationalError): aconn.fileno() + + +async def test_cursor_factory(aconn): + assert aconn.cursor_factory is psycopg.AsyncCursor + + class MyCursor(psycopg.AsyncCursor): + pass + + aconn.cursor_factory = MyCursor + async with aconn.cursor() as cur: + assert isinstance(cur, MyCursor) + + async with (await aconn.execute("select 1")) as cur: + assert isinstance(cur, MyCursor) + + +async def test_server_cursor_factory(aconn): + assert aconn.server_cursor_factory is psycopg.AsyncServerCursor + + class MyServerCursor(psycopg.AsyncServerCursor): + pass + + aconn.server_cursor_factory = MyServerCursor + async with aconn.cursor(name="n") as cur: + assert isinstance(cur, MyServerCursor)