]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add row_factory as connection attribute and connect argument
authorDenis Laxalde <denis.laxalde@dalibo.com>
Fri, 12 Feb 2021 09:31:30 +0000 (10:31 +0100)
committerDenis Laxalde <denis.laxalde@dalibo.com>
Fri, 12 Feb 2021 10:34:35 +0000 (11:34 +0100)
When passing 'row_factory' to connect(), respective attribute will be
set on the connection instance. This will be used as default at cursor
creation and can be overridden with conn.cursor(row_factory=...) or
conn.execute(row_factory=...).

We use a '_null_row_factory' marker to handle None-value passed to
.cursor() or .execute() for disabling the default row factory.

psycopg3/psycopg3/connection.py
tests/test_connection.py
tests/test_connection_async.py

index fae5882381a66f5d12f465402f529f33bef5bd6a..d35eb68fd29558ada60da65e24b0271dad7e259d 100644 (file)
@@ -78,6 +78,9 @@ NoticeHandler = Callable[[e.Diagnostic], None]
 NotifyHandler = Callable[[Notify], None]
 
 
+_null_row_factory: RowFactory = object()  # type: ignore[assignment]
+
+
 class BaseConnection(AdaptContext):
     """
     Base class for different types of connections.
@@ -102,6 +105,8 @@ class BaseConnection(AdaptContext):
     ConnStatus = pq.ConnStatus
     TransactionStatus = pq.TransactionStatus
 
+    row_factory: Optional[RowFactory] = None
+
     def __init__(self, pgconn: "PGconn"):
         self.pgconn = pgconn  # TODO: document this
         self._autocommit = False
@@ -312,6 +317,7 @@ class BaseConnection(AdaptContext):
         conninfo: str = "",
         *,
         autocommit: bool = False,
+        row_factory: Optional[RowFactory] = None,
         **kwargs: Any,
     ) -> PQGenConn[ConnectionType]:
         """Generator to connect to the database and create a new instance."""
@@ -319,6 +325,7 @@ class BaseConnection(AdaptContext):
         pgconn = yield from connect(conninfo)
         conn = cls(pgconn)
         conn._autocommit = autocommit
+        conn.row_factory = row_factory
         return conn
 
     def _exec_command(self, command: Query) -> PQGen["PGresult"]:
@@ -405,7 +412,12 @@ class Connection(BaseConnection):
 
     @classmethod
     def connect(
-        cls, conninfo: str = "", *, autocommit: bool = False, **kwargs: Any
+        cls,
+        conninfo: str = "",
+        *,
+        autocommit: bool = False,
+        row_factory: Optional[RowFactory] = None,
+        **kwargs: Any,
     ) -> "Connection":
         """
         Connect to a database server and return a new `Connection` instance.
@@ -413,7 +425,12 @@ class Connection(BaseConnection):
         TODO: connection_timeout to be implemented.
         """
         return cls._wait_conn(
-            cls._connect_gen(conninfo, autocommit=autocommit, **kwargs)
+            cls._connect_gen(
+                conninfo,
+                autocommit=autocommit,
+                row_factory=row_factory,
+                **kwargs,
+            )
         )
 
     def __enter__(self) -> "Connection":
@@ -465,12 +482,14 @@ class Connection(BaseConnection):
         name: str = "",
         *,
         binary: bool = False,
-        row_factory: Optional[RowFactory] = None,
+        row_factory: Optional[RowFactory] = _null_row_factory,
     ) -> Union[Cursor, ServerCursor]:
         """
         Return a new cursor to send commands and queries to the connection.
         """
         format = Format.BINARY if binary else Format.TEXT
+        if row_factory is _null_row_factory:
+            row_factory = self.row_factory
         if name:
             return ServerCursor(
                 self, name=name, format=format, row_factory=row_factory
@@ -484,9 +503,11 @@ class Connection(BaseConnection):
         params: Optional[Params] = None,
         *,
         prepare: Optional[bool] = None,
-        row_factory: Optional[RowFactory] = None,
+        row_factory: Optional[RowFactory] = _null_row_factory,
     ) -> Cursor:
         """Execute a query and return a cursor to read its results."""
+        if row_factory is _null_row_factory:
+            row_factory = self.row_factory
         cur = self.cursor(row_factory=row_factory)
         return cur.execute(query, params, prepare=prepare)
 
@@ -569,10 +590,20 @@ class AsyncConnection(BaseConnection):
 
     @classmethod
     async def connect(
-        cls, conninfo: str = "", *, autocommit: bool = False, **kwargs: Any
+        cls,
+        conninfo: str = "",
+        *,
+        autocommit: bool = False,
+        row_factory: Optional[RowFactory] = None,
+        **kwargs: Any,
     ) -> "AsyncConnection":
         return await cls._wait_conn(
-            cls._connect_gen(conninfo, autocommit=autocommit, **kwargs)
+            cls._connect_gen(
+                conninfo,
+                autocommit=autocommit,
+                row_factory=row_factory,
+                **kwargs,
+            )
         )
 
     async def __aenter__(self) -> "AsyncConnection":
@@ -623,12 +654,14 @@ class AsyncConnection(BaseConnection):
         name: str = "",
         *,
         binary: bool = False,
-        row_factory: Optional[RowFactory] = None,
+        row_factory: Optional[RowFactory] = _null_row_factory,
     ) -> Union[AsyncCursor, AsyncServerCursor]:
         """
         Return a new `AsyncCursor` to send commands and queries to the connection.
         """
         format = Format.BINARY if binary else Format.TEXT
+        if row_factory is _null_row_factory:
+            row_factory = self.row_factory
         if name:
             return AsyncServerCursor(
                 self, name=name, format=format, row_factory=row_factory
@@ -642,8 +675,10 @@ class AsyncConnection(BaseConnection):
         params: Optional[Params] = None,
         *,
         prepare: Optional[bool] = None,
-        row_factory: Optional[RowFactory] = None,
+        row_factory: Optional[RowFactory] = _null_row_factory,
     ) -> AsyncCursor:
+        if row_factory is _null_row_factory:
+            row_factory = self.row_factory
         cur = self.cursor(row_factory=row_factory)
         return await cur.execute(query, params, prepare=prepare)
 
index aef5825ac34ae60fd34ffb81387ee27a72ac04be..c6b33b2fa4145b04da53927730ef117356905ee7 100644 (file)
@@ -12,6 +12,7 @@ from psycopg3 import encodings
 from psycopg3 import Connection, Notify
 from psycopg3.errors import UndefinedTable
 from psycopg3.conninfo import conninfo_to_dict
+from .test_cursor import my_row_factory
 
 
 def test_connect(dsn):
@@ -485,6 +486,29 @@ def test_execute(conn):
     assert cur.fetchone() == {1, 2}
 
 
+def test_row_factory(dsn):
+    conn = Connection.connect(dsn, row_factory=my_row_factory)
+    assert conn.row_factory
+
+    cur = conn.execute("select 'a' as ve")
+    assert cur.fetchone() == ["Ave"]
+
+    cur = conn.execute("select 'a' as ve", row_factory=None)
+    assert cur.fetchone() == ("a",)
+
+    with conn.cursor(row_factory=lambda c: set) as cur:
+        cur.execute("select 1, 1, 2")
+        assert cur.fetchall() == [{1, 2}]
+
+    with conn.cursor(row_factory=None) as cur:
+        cur.execute("select 1, 1, 2")
+        assert cur.fetchall() == [(1, 1, 2)]
+
+    conn.row_factory = None
+    cur = conn.execute("select 'vale'")
+    assert cur.fetchone() == ("vale",)
+
+
 def test_str(conn):
     assert "[IDLE]" in str(conn)
     conn.close()
index e694e6f971004155ee6d6faa36bae6d70485f616..1e090cd51b87d26563bc0ff01f27b67e0d1567a7 100644 (file)
@@ -11,6 +11,7 @@ from psycopg3 import encodings
 from psycopg3 import AsyncConnection, Notify
 from psycopg3.errors import UndefinedTable
 from psycopg3.conninfo import conninfo_to_dict
+from .test_cursor import my_row_factory
 
 pytestmark = pytest.mark.asyncio
 
@@ -503,6 +504,29 @@ async def test_execute(aconn):
     assert await cur.fetchone() == {1, 2}
 
 
+async def test_row_factory(dsn):
+    conn = await AsyncConnection.connect(dsn, row_factory=my_row_factory)
+    assert conn.row_factory
+
+    cur = await conn.execute("select 'a' as ve")
+    assert await cur.fetchone() == ["Ave"]
+
+    cur = await conn.execute("select 'a' as ve", row_factory=None)
+    assert await cur.fetchone() == ("a",)
+
+    async with conn.cursor(row_factory=lambda c: set) as cur:
+        await cur.execute("select 1, 1, 2")
+        assert await cur.fetchall() == [{1, 2}]
+
+    async with conn.cursor(row_factory=None) as cur:
+        await cur.execute("select 1, 1, 2")
+        assert await cur.fetchall() == [(1, 1, 2)]
+
+    conn.row_factory = None
+    cur = await conn.execute("select 'vale'")
+    assert await cur.fetchone() == ("vale",)
+
+
 async def test_str(aconn):
     assert "[IDLE]" in str(aconn)
     await aconn.close()