]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Drop RowConn from proto
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 30 Apr 2021 00:36:33 +0000 (02:36 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 30 Apr 2021 12:48:55 +0000 (14:48 +0200)
Use a more local CursorRow definition.

psycopg3/psycopg3/connection.py
psycopg3/psycopg3/proto.py

index 5d77e69d090581c025b2ccfe98bdca41cd1a40fd..fff67965cc38b09507500f686e2566a940205e5e 100644 (file)
@@ -10,7 +10,8 @@ import warnings
 import threading
 from types import TracebackType
 from typing import Any, AsyncIterator, Callable, Generic, Iterator, List
-from typing import NamedTuple, Optional, Type, Union, TYPE_CHECKING, overload
+from typing import NamedTuple, Optional, Type, TypeVar, Union
+from typing import overload, TYPE_CHECKING
 from weakref import ref, ReferenceType
 from functools import partial
 from contextlib import contextmanager
@@ -24,7 +25,7 @@ from .pq import ConnStatus, ExecStatus, TransactionStatus, Format
 from .sql import Composable
 from .rows import tuple_row, TupleRow
 from .proto import AdaptContext, ConnectionType, Params, PQGen, PQGenConn
-from .proto import Query, Row, RowConn, RowFactory, RV
+from .proto import Query, Row, RowFactory, RV
 from .cursor import Cursor, AsyncCursor
 from .conninfo import make_conninfo, ConnectionInfo
 from .generators import notifies
@@ -38,6 +39,10 @@ logger = logging.getLogger("psycopg3")
 connect: Callable[[str], PQGenConn["PGconn"]]
 execute: Callable[["PGconn"], PQGen[List["PGresult"]]]
 
+# Row Type variable for Cursor (when it needs to be distinguished from the
+# connection's one)
+CursorRow = TypeVar("CursorRow")
+
 if TYPE_CHECKING:
     from .pq.proto import PGconn, PGresult
     from .pool.base import BasePool
@@ -74,7 +79,7 @@ NoticeHandler = Callable[[e.Diagnostic], None]
 NotifyHandler = Callable[[Notify], None]
 
 
-class BaseConnection(AdaptContext, Generic[RowConn]):
+class BaseConnection(AdaptContext, Generic[Row]):
     """
     Base class for different types of connections.
 
@@ -98,7 +103,7 @@ class BaseConnection(AdaptContext, Generic[RowConn]):
     ConnStatus = pq.ConnStatus
     TransactionStatus = pq.TransactionStatus
 
-    def __init__(self, pgconn: "PGconn", row_factory: RowFactory[RowConn]):
+    def __init__(self, pgconn: "PGconn", row_factory: RowFactory[Row]):
         self.pgconn = pgconn  # TODO: document this
         self._row_factory = row_factory
         self._autocommit = False
@@ -224,17 +229,17 @@ class BaseConnection(AdaptContext, Generic[RowConn]):
         return self._adapters
 
     @property
-    def connection(self) -> "BaseConnection[RowConn]":
+    def connection(self) -> "BaseConnection[Row]":
         # implement the AdaptContext protocol
         return self
 
     @property
-    def row_factory(self) -> RowFactory[RowConn]:
+    def row_factory(self) -> RowFactory[Row]:
         """Writable attribute to control how result rows are formed."""
         return self._row_factory
 
     @row_factory.setter
-    def row_factory(self, row_factory: RowFactory[RowConn]) -> None:
+    def row_factory(self, row_factory: RowFactory[Row]) -> None:
         self._row_factory = row_factory
 
     def fileno(self) -> int:
@@ -265,7 +270,7 @@ class BaseConnection(AdaptContext, Generic[RowConn]):
 
     @staticmethod
     def _notice_handler(
-        wself: "ReferenceType[BaseConnection[RowConn]]", res: "PGresult"
+        wself: "ReferenceType[BaseConnection[Row]]", res: "PGresult"
     ) -> None:
         self = wself()
         if not (self and self._notice_handler):
@@ -294,7 +299,7 @@ class BaseConnection(AdaptContext, Generic[RowConn]):
 
     @staticmethod
     def _notify_handler(
-        wself: "ReferenceType[BaseConnection[RowConn]]", pgn: pq.PGnotify
+        wself: "ReferenceType[BaseConnection[Row]]", pgn: pq.PGnotify
     ) -> None:
         self = wself()
         if not (self and self._notify_handlers):
@@ -435,14 +440,14 @@ class BaseConnection(AdaptContext, Generic[RowConn]):
         yield from self._exec_command(b"rollback")
 
 
-class Connection(BaseConnection[RowConn]):
+class Connection(BaseConnection[Row]):
     """
     Wrapper for a connection to the database.
     """
 
     __module__ = "psycopg3"
 
-    def __init__(self, pgconn: "PGconn", row_factory: RowFactory[RowConn]):
+    def __init__(self, pgconn: "PGconn", row_factory: RowFactory[Row]):
         super().__init__(pgconn, row_factory)
         self.lock = threading.Lock()
 
@@ -453,9 +458,9 @@ class Connection(BaseConnection[RowConn]):
         conninfo: str = "",
         *,
         autocommit: bool = False,
-        row_factory: RowFactory[RowConn],
+        row_factory: RowFactory[Row],
         **kwargs: Union[None, int, str],
-    ) -> "Connection[RowConn]":
+    ) -> "Connection[Row]":
         ...
 
     @overload
@@ -475,7 +480,7 @@ class Connection(BaseConnection[RowConn]):
         conninfo: str = "",
         *,
         autocommit: bool = False,
-        row_factory: Optional[RowFactory[RowConn]] = None,
+        row_factory: Optional[RowFactory[Row]] = None,
         **kwargs: Any,
     ) -> "Connection[Any]":
         """
@@ -492,7 +497,7 @@ class Connection(BaseConnection[RowConn]):
             )
         )
 
-    def __enter__(self) -> "Connection[RowConn]":
+    def __enter__(self) -> "Connection[Row]":
         return self
 
     def __exit__(
@@ -529,25 +534,27 @@ class Connection(BaseConnection[RowConn]):
         self.pgconn.finish()
 
     @overload
-    def cursor(self, *, binary: bool = False) -> Cursor[RowConn]:
+    def cursor(self, *, binary: bool = False) -> Cursor[Row]:
         ...
 
     @overload
     def cursor(
-        self, *, binary: bool = False, row_factory: RowFactory[Row]
-    ) -> Cursor[Row]:
+        self, *, binary: bool = False, row_factory: RowFactory[CursorRow]
+    ) -> Cursor[CursorRow]:
         ...
 
     @overload
-    def cursor(
-        self, name: str, *, binary: bool = False
-    ) -> ServerCursor[RowConn]:
+    def cursor(self, name: str, *, binary: bool = False) -> ServerCursor[Row]:
         ...
 
     @overload
     def cursor(
-        self, name: str, *, binary: bool = False, row_factory: RowFactory[Row]
-    ) -> ServerCursor[Row]:
+        self,
+        name: str,
+        *,
+        binary: bool = False,
+        row_factory: RowFactory[CursorRow],
+    ) -> ServerCursor[CursorRow]:
         ...
 
     def cursor(
@@ -576,7 +583,7 @@ class Connection(BaseConnection[RowConn]):
         params: Optional[Params] = None,
         *,
         prepare: Optional[bool] = None,
-    ) -> Cursor[RowConn]:
+    ) -> Cursor[Row]:
         """Execute a query and return a cursor to read its results."""
         cur = self.cursor()
         try:
@@ -651,14 +658,14 @@ class Connection(BaseConnection[RowConn]):
             self.wait(self._set_client_encoding_gen(name))
 
 
-class AsyncConnection(BaseConnection[RowConn]):
+class AsyncConnection(BaseConnection[Row]):
     """
     Asynchronous wrapper for a connection to the database.
     """
 
     __module__ = "psycopg3"
 
-    def __init__(self, pgconn: "PGconn", row_factory: RowFactory[RowConn]):
+    def __init__(self, pgconn: "PGconn", row_factory: RowFactory[Row]):
         super().__init__(pgconn, row_factory)
         self.lock = asyncio.Lock()
 
@@ -669,9 +676,9 @@ class AsyncConnection(BaseConnection[RowConn]):
         conninfo: str = "",
         *,
         autocommit: bool = False,
-        row_factory: RowFactory[RowConn],
+        row_factory: RowFactory[Row],
         **kwargs: Union[None, int, str],
-    ) -> "AsyncConnection[RowConn]":
+    ) -> "AsyncConnection[Row]":
         ...
 
     @overload
@@ -691,7 +698,7 @@ class AsyncConnection(BaseConnection[RowConn]):
         conninfo: str = "",
         *,
         autocommit: bool = False,
-        row_factory: Optional[RowFactory[RowConn]] = None,
+        row_factory: Optional[RowFactory[Row]] = None,
         **kwargs: Any,
     ) -> "AsyncConnection[Any]":
         return await cls._wait_conn(
@@ -703,7 +710,7 @@ class AsyncConnection(BaseConnection[RowConn]):
             )
         )
 
-    async def __aenter__(self) -> "AsyncConnection[RowConn]":
+    async def __aenter__(self) -> "AsyncConnection[Row]":
         return self
 
     async def __aexit__(
@@ -739,25 +746,29 @@ class AsyncConnection(BaseConnection[RowConn]):
         self.pgconn.finish()
 
     @overload
-    def cursor(self, *, binary: bool = False) -> AsyncCursor[RowConn]:
+    def cursor(self, *, binary: bool = False) -> AsyncCursor[Row]:
         ...
 
     @overload
     def cursor(
-        self, *, binary: bool = False, row_factory: RowFactory[Row]
-    ) -> AsyncCursor[Row]:
+        self, *, binary: bool = False, row_factory: RowFactory[CursorRow]
+    ) -> AsyncCursor[CursorRow]:
         ...
 
     @overload
     def cursor(
         self, name: str, *, binary: bool = False
-    ) -> AsyncServerCursor[RowConn]:
+    ) -> AsyncServerCursor[Row]:
         ...
 
     @overload
     def cursor(
-        self, name: str, *, binary: bool = False, row_factory: RowFactory[Row]
-    ) -> AsyncServerCursor[Row]:
+        self,
+        name: str,
+        *,
+        binary: bool = False,
+        row_factory: RowFactory[CursorRow],
+    ) -> AsyncServerCursor[CursorRow]:
         ...
 
     def cursor(
@@ -786,7 +797,7 @@ class AsyncConnection(BaseConnection[RowConn]):
         params: Optional[Params] = None,
         *,
         prepare: Optional[bool] = None,
-    ) -> AsyncCursor[RowConn]:
+    ) -> AsyncCursor[Row]:
         cur = self.cursor()
         try:
             return await cur.execute(query, params, prepare=prepare)
index 657780967f53741fef6694956639dfde2d18f25e..6f6652a982b7650ead5afeb0b98e0939deeca84a 100644 (file)
@@ -49,8 +49,6 @@ Wait states.
 
 Row = TypeVar("Row")
 Row_co = TypeVar("Row_co", covariant=True)
-# Type variable for Connection (other are for Cursor).
-RowConn = TypeVar("RowConn")
 
 
 class RowMaker(Protocol[Row_co]):