From: Denis Laxalde Date: Fri, 12 Feb 2021 09:31:30 +0000 (+0100) Subject: Add row_factory as connection attribute and connect argument X-Git-Tag: 3.0.dev0~106^2~14 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=f4e874ca04acec37a37e46b9b6a228a463c0cf95;p=thirdparty%2Fpsycopg.git Add row_factory as connection attribute and connect argument When passing 'row_factory' to connect(), respective attribute will be set on the connection instance. This will be used as default at cursor creation and can be overridden with conn.cursor(row_factory=...) or conn.execute(row_factory=...). We use a '_null_row_factory' marker to handle None-value passed to .cursor() or .execute() for disabling the default row factory. --- diff --git a/psycopg3/psycopg3/connection.py b/psycopg3/psycopg3/connection.py index fae588238..d35eb68fd 100644 --- a/psycopg3/psycopg3/connection.py +++ b/psycopg3/psycopg3/connection.py @@ -78,6 +78,9 @@ NoticeHandler = Callable[[e.Diagnostic], None] NotifyHandler = Callable[[Notify], None] +_null_row_factory: RowFactory = object() # type: ignore[assignment] + + class BaseConnection(AdaptContext): """ Base class for different types of connections. @@ -102,6 +105,8 @@ class BaseConnection(AdaptContext): ConnStatus = pq.ConnStatus TransactionStatus = pq.TransactionStatus + row_factory: Optional[RowFactory] = None + def __init__(self, pgconn: "PGconn"): self.pgconn = pgconn # TODO: document this self._autocommit = False @@ -312,6 +317,7 @@ class BaseConnection(AdaptContext): conninfo: str = "", *, autocommit: bool = False, + row_factory: Optional[RowFactory] = None, **kwargs: Any, ) -> PQGenConn[ConnectionType]: """Generator to connect to the database and create a new instance.""" @@ -319,6 +325,7 @@ class BaseConnection(AdaptContext): pgconn = yield from connect(conninfo) conn = cls(pgconn) conn._autocommit = autocommit + conn.row_factory = row_factory return conn def _exec_command(self, command: Query) -> PQGen["PGresult"]: @@ -405,7 +412,12 @@ class Connection(BaseConnection): @classmethod def connect( - cls, conninfo: str = "", *, autocommit: bool = False, **kwargs: Any + cls, + conninfo: str = "", + *, + autocommit: bool = False, + row_factory: Optional[RowFactory] = None, + **kwargs: Any, ) -> "Connection": """ Connect to a database server and return a new `Connection` instance. @@ -413,7 +425,12 @@ class Connection(BaseConnection): TODO: connection_timeout to be implemented. """ return cls._wait_conn( - cls._connect_gen(conninfo, autocommit=autocommit, **kwargs) + cls._connect_gen( + conninfo, + autocommit=autocommit, + row_factory=row_factory, + **kwargs, + ) ) def __enter__(self) -> "Connection": @@ -465,12 +482,14 @@ class Connection(BaseConnection): name: str = "", *, binary: bool = False, - row_factory: Optional[RowFactory] = None, + row_factory: Optional[RowFactory] = _null_row_factory, ) -> Union[Cursor, ServerCursor]: """ Return a new cursor to send commands and queries to the connection. """ format = Format.BINARY if binary else Format.TEXT + if row_factory is _null_row_factory: + row_factory = self.row_factory if name: return ServerCursor( self, name=name, format=format, row_factory=row_factory @@ -484,9 +503,11 @@ class Connection(BaseConnection): params: Optional[Params] = None, *, prepare: Optional[bool] = None, - row_factory: Optional[RowFactory] = None, + row_factory: Optional[RowFactory] = _null_row_factory, ) -> Cursor: """Execute a query and return a cursor to read its results.""" + if row_factory is _null_row_factory: + row_factory = self.row_factory cur = self.cursor(row_factory=row_factory) return cur.execute(query, params, prepare=prepare) @@ -569,10 +590,20 @@ class AsyncConnection(BaseConnection): @classmethod async def connect( - cls, conninfo: str = "", *, autocommit: bool = False, **kwargs: Any + cls, + conninfo: str = "", + *, + autocommit: bool = False, + row_factory: Optional[RowFactory] = None, + **kwargs: Any, ) -> "AsyncConnection": return await cls._wait_conn( - cls._connect_gen(conninfo, autocommit=autocommit, **kwargs) + cls._connect_gen( + conninfo, + autocommit=autocommit, + row_factory=row_factory, + **kwargs, + ) ) async def __aenter__(self) -> "AsyncConnection": @@ -623,12 +654,14 @@ class AsyncConnection(BaseConnection): name: str = "", *, binary: bool = False, - row_factory: Optional[RowFactory] = None, + row_factory: Optional[RowFactory] = _null_row_factory, ) -> Union[AsyncCursor, AsyncServerCursor]: """ Return a new `AsyncCursor` to send commands and queries to the connection. """ format = Format.BINARY if binary else Format.TEXT + if row_factory is _null_row_factory: + row_factory = self.row_factory if name: return AsyncServerCursor( self, name=name, format=format, row_factory=row_factory @@ -642,8 +675,10 @@ class AsyncConnection(BaseConnection): params: Optional[Params] = None, *, prepare: Optional[bool] = None, - row_factory: Optional[RowFactory] = None, + row_factory: Optional[RowFactory] = _null_row_factory, ) -> AsyncCursor: + if row_factory is _null_row_factory: + row_factory = self.row_factory cur = self.cursor(row_factory=row_factory) return await cur.execute(query, params, prepare=prepare) diff --git a/tests/test_connection.py b/tests/test_connection.py index aef5825ac..c6b33b2fa 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -12,6 +12,7 @@ from psycopg3 import encodings from psycopg3 import Connection, Notify from psycopg3.errors import UndefinedTable from psycopg3.conninfo import conninfo_to_dict +from .test_cursor import my_row_factory def test_connect(dsn): @@ -485,6 +486,29 @@ def test_execute(conn): assert cur.fetchone() == {1, 2} +def test_row_factory(dsn): + conn = Connection.connect(dsn, row_factory=my_row_factory) + assert conn.row_factory + + cur = conn.execute("select 'a' as ve") + assert cur.fetchone() == ["Ave"] + + cur = conn.execute("select 'a' as ve", row_factory=None) + assert cur.fetchone() == ("a",) + + with conn.cursor(row_factory=lambda c: set) as cur: + cur.execute("select 1, 1, 2") + assert cur.fetchall() == [{1, 2}] + + with conn.cursor(row_factory=None) as cur: + cur.execute("select 1, 1, 2") + assert cur.fetchall() == [(1, 1, 2)] + + conn.row_factory = None + cur = conn.execute("select 'vale'") + assert cur.fetchone() == ("vale",) + + def test_str(conn): assert "[IDLE]" in str(conn) conn.close() diff --git a/tests/test_connection_async.py b/tests/test_connection_async.py index e694e6f97..1e090cd51 100644 --- a/tests/test_connection_async.py +++ b/tests/test_connection_async.py @@ -11,6 +11,7 @@ from psycopg3 import encodings from psycopg3 import AsyncConnection, Notify from psycopg3.errors import UndefinedTable from psycopg3.conninfo import conninfo_to_dict +from .test_cursor import my_row_factory pytestmark = pytest.mark.asyncio @@ -503,6 +504,29 @@ async def test_execute(aconn): assert await cur.fetchone() == {1, 2} +async def test_row_factory(dsn): + conn = await AsyncConnection.connect(dsn, row_factory=my_row_factory) + assert conn.row_factory + + cur = await conn.execute("select 'a' as ve") + assert await cur.fetchone() == ["Ave"] + + cur = await conn.execute("select 'a' as ve", row_factory=None) + assert await cur.fetchone() == ("a",) + + async with conn.cursor(row_factory=lambda c: set) as cur: + await cur.execute("select 1, 1, 2") + assert await cur.fetchall() == [{1, 2}] + + async with conn.cursor(row_factory=None) as cur: + await cur.execute("select 1, 1, 2") + assert await cur.fetchall() == [(1, 1, 2)] + + conn.row_factory = None + cur = await conn.execute("select 'vale'") + assert await cur.fetchone() == ("vale",) + + async def test_str(aconn): assert "[IDLE]" in str(aconn) await aconn.close()