]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Drop row_factory from base connection class
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 1 Aug 2021 15:40:01 +0000 (17:40 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 1 Aug 2021 20:37:31 +0000 (22:37 +0200)
Move it to the sync/async concrete classes, specifying the right type.

psycopg/psycopg/connection.py

index 1a32c0561a866793e9ed9ea28df9578e99bd64d7..e57a862a22caa42260a65fc186d9a56f26ea4853 100644 (file)
@@ -103,13 +103,8 @@ class BaseConnection(Generic[Row]):
     ConnStatus = pq.ConnStatus
     TransactionStatus = pq.TransactionStatus
 
-    def __init__(
-        self,
-        pgconn: "PGconn",
-        row_factory: Union[RowFactory[Row], AsyncRowFactory[Row]],
-    ):
+    def __init__(self, pgconn: "PGconn"):
         self.pgconn = pgconn  # TODO: document this
-        self._row_factory = row_factory
         self._autocommit = False
         self._adapters = adapt.AdaptersMap(postgres.adapters)
         self._notice_handlers: List[NoticeHandler] = []
@@ -416,13 +411,10 @@ class BaseConnection(Generic[Row]):
         conninfo: str = "",
         *,
         autocommit: bool = False,
-        row_factory: Optional[RowFactory[Any]] = None,
     ) -> PQGenConn[ConnectionType]:
         """Generator to connect to the database and create a new instance."""
         pgconn = yield from connect(conninfo)
-        if not row_factory:
-            row_factory = tuple_row
-        conn = cls(pgconn, row_factory)
+        conn = cls(pgconn)
         conn._autocommit = bool(autocommit)
         return conn
 
@@ -528,9 +520,13 @@ class Connection(BaseConnection[Row]):
 
     cursor_factory: Type[Cursor[Row]]
     server_cursor_factory: Type[ServerCursor[Row]]
+    row_factory: RowFactory[Row]
 
-    def __init__(self, pgconn: "PGconn", row_factory: RowFactory[Row]):
-        super().__init__(pgconn, row_factory)
+    def __init__(
+        self, pgconn: "PGconn", row_factory: Optional[RowFactory[Row]] = None
+    ):
+        super().__init__(pgconn)
+        self.row_factory = row_factory or cast(RowFactory[Row], tuple_row)
         self.lock = threading.Lock()
         self.cursor_factory = Cursor
         self.server_cursor_factory = ServerCursor
@@ -571,12 +567,13 @@ class Connection(BaseConnection[Row]):
         Connect to a database server and return a new `Connection` instance.
         """
         conninfo, timeout = _conninfo_connect_timeout(conninfo, **kwargs)
-        return cls._wait_conn(
-            cls._connect_gen(
-                conninfo, autocommit=autocommit, row_factory=row_factory
-            ),
+        rv = cls._wait_conn(
+            cls._connect_gen(conninfo, autocommit=autocommit),
             timeout,
         )
+        if row_factory:
+            rv.row_factory = row_factory
+        return rv
 
     def __enter__(self) -> "Connection[Row]":
         return self
@@ -614,15 +611,6 @@ class Connection(BaseConnection[Row]):
         self._closed = True
         self.pgconn.finish()
 
-    @property
-    def row_factory(self) -> RowFactory[Row]:
-        """Writable attribute to control how result rows are formed."""
-        return cast(RowFactory[Row], self._row_factory)
-
-    @row_factory.setter
-    def row_factory(self, row_factory: RowFactory[Row]) -> None:
-        self._row_factory = row_factory
-
     @overload
     def cursor(self, *, binary: bool = False) -> Cursor[Row]:
         ...
@@ -788,9 +776,15 @@ class AsyncConnection(BaseConnection[Row]):
 
     cursor_factory: Type[AsyncCursor[Row]]
     server_cursor_factory: Type[AsyncServerCursor[Row]]
+    row_factory: AsyncRowFactory[Row]
 
-    def __init__(self, pgconn: "PGconn", row_factory: AsyncRowFactory[Row]):
-        super().__init__(pgconn, row_factory)
+    def __init__(
+        self,
+        pgconn: "PGconn",
+        row_factory: Optional[AsyncRowFactory[Row]] = None,
+    ):
+        super().__init__(pgconn)
+        self.row_factory = row_factory or cast(AsyncRowFactory[Row], tuple_row)
         self.lock = asyncio.Lock()
         self.cursor_factory = AsyncCursor
         self.server_cursor_factory = AsyncServerCursor
@@ -824,16 +818,17 @@ class AsyncConnection(BaseConnection[Row]):
         conninfo: str = "",
         *,
         autocommit: bool = False,
-        row_factory: Optional[RowFactory[Row]] = None,
+        row_factory: Optional[AsyncRowFactory[Row]] = None,
         **kwargs: Any,
     ) -> "AsyncConnection[Any]":
         conninfo, timeout = _conninfo_connect_timeout(conninfo, **kwargs)
-        return await cls._wait_conn(
-            cls._connect_gen(
-                conninfo, autocommit=autocommit, row_factory=row_factory
-            ),
+        rv = await cls._wait_conn(
+            cls._connect_gen(conninfo, autocommit=autocommit),
             timeout,
         )
+        if row_factory:
+            rv.row_factory = row_factory
+        return rv
 
     async def __aenter__(self) -> "AsyncConnection[Row]":
         return self
@@ -870,15 +865,6 @@ class AsyncConnection(BaseConnection[Row]):
         self._closed = True
         self.pgconn.finish()
 
-    @property
-    def row_factory(self) -> AsyncRowFactory[Row]:
-        """Writable attribute to control how result rows are formed."""
-        return cast(AsyncRowFactory[Row], self._row_factory)
-
-    @row_factory.setter
-    def row_factory(self, row_factory: AsyncRowFactory[Row]) -> None:
-        self._row_factory = row_factory
-
     @overload
     def cursor(self, *, binary: bool = False) -> AsyncCursor[Row]:
         ...