import threading
from types import TracebackType
from typing import Any, AsyncIterator, Callable, Generic, Iterator, List
-from typing import NamedTuple, Optional, Type, Union, TYPE_CHECKING, overload
+from typing import NamedTuple, Optional, Type, TypeVar, Union
+from typing import overload, TYPE_CHECKING
from weakref import ref, ReferenceType
from functools import partial
from contextlib import contextmanager
from .sql import Composable
from .rows import tuple_row, TupleRow
from .proto import AdaptContext, ConnectionType, Params, PQGen, PQGenConn
-from .proto import Query, Row, RowConn, RowFactory, RV
+from .proto import Query, Row, RowFactory, RV
from .cursor import Cursor, AsyncCursor
from .conninfo import make_conninfo, ConnectionInfo
from .generators import notifies
connect: Callable[[str], PQGenConn["PGconn"]]
execute: Callable[["PGconn"], PQGen[List["PGresult"]]]
+# Row Type variable for Cursor (when it needs to be distinguished from the
+# connection's one)
+CursorRow = TypeVar("CursorRow")
+
if TYPE_CHECKING:
from .pq.proto import PGconn, PGresult
from .pool.base import BasePool
NotifyHandler = Callable[[Notify], None]
-class BaseConnection(AdaptContext, Generic[RowConn]):
+class BaseConnection(AdaptContext, Generic[Row]):
"""
Base class for different types of connections.
ConnStatus = pq.ConnStatus
TransactionStatus = pq.TransactionStatus
- def __init__(self, pgconn: "PGconn", row_factory: RowFactory[RowConn]):
+ def __init__(self, pgconn: "PGconn", row_factory: RowFactory[Row]):
self.pgconn = pgconn # TODO: document this
self._row_factory = row_factory
self._autocommit = False
return self._adapters
@property
- def connection(self) -> "BaseConnection[RowConn]":
+ def connection(self) -> "BaseConnection[Row]":
# implement the AdaptContext protocol
return self
@property
- def row_factory(self) -> RowFactory[RowConn]:
+ def row_factory(self) -> RowFactory[Row]:
"""Writable attribute to control how result rows are formed."""
return self._row_factory
@row_factory.setter
- def row_factory(self, row_factory: RowFactory[RowConn]) -> None:
+ def row_factory(self, row_factory: RowFactory[Row]) -> None:
self._row_factory = row_factory
def fileno(self) -> int:
@staticmethod
def _notice_handler(
- wself: "ReferenceType[BaseConnection[RowConn]]", res: "PGresult"
+ wself: "ReferenceType[BaseConnection[Row]]", res: "PGresult"
) -> None:
self = wself()
if not (self and self._notice_handler):
@staticmethod
def _notify_handler(
- wself: "ReferenceType[BaseConnection[RowConn]]", pgn: pq.PGnotify
+ wself: "ReferenceType[BaseConnection[Row]]", pgn: pq.PGnotify
) -> None:
self = wself()
if not (self and self._notify_handlers):
yield from self._exec_command(b"rollback")
-class Connection(BaseConnection[RowConn]):
+class Connection(BaseConnection[Row]):
"""
Wrapper for a connection to the database.
"""
__module__ = "psycopg3"
- def __init__(self, pgconn: "PGconn", row_factory: RowFactory[RowConn]):
+ def __init__(self, pgconn: "PGconn", row_factory: RowFactory[Row]):
super().__init__(pgconn, row_factory)
self.lock = threading.Lock()
conninfo: str = "",
*,
autocommit: bool = False,
- row_factory: RowFactory[RowConn],
+ row_factory: RowFactory[Row],
**kwargs: Union[None, int, str],
- ) -> "Connection[RowConn]":
+ ) -> "Connection[Row]":
...
@overload
conninfo: str = "",
*,
autocommit: bool = False,
- row_factory: Optional[RowFactory[RowConn]] = None,
+ row_factory: Optional[RowFactory[Row]] = None,
**kwargs: Any,
) -> "Connection[Any]":
"""
)
)
- def __enter__(self) -> "Connection[RowConn]":
+ def __enter__(self) -> "Connection[Row]":
return self
def __exit__(
self.pgconn.finish()
@overload
- def cursor(self, *, binary: bool = False) -> Cursor[RowConn]:
+ def cursor(self, *, binary: bool = False) -> Cursor[Row]:
...
@overload
def cursor(
- self, *, binary: bool = False, row_factory: RowFactory[Row]
- ) -> Cursor[Row]:
+ self, *, binary: bool = False, row_factory: RowFactory[CursorRow]
+ ) -> Cursor[CursorRow]:
...
@overload
- def cursor(
- self, name: str, *, binary: bool = False
- ) -> ServerCursor[RowConn]:
+ def cursor(self, name: str, *, binary: bool = False) -> ServerCursor[Row]:
...
@overload
def cursor(
- self, name: str, *, binary: bool = False, row_factory: RowFactory[Row]
- ) -> ServerCursor[Row]:
+ self,
+ name: str,
+ *,
+ binary: bool = False,
+ row_factory: RowFactory[CursorRow],
+ ) -> ServerCursor[CursorRow]:
...
def cursor(
params: Optional[Params] = None,
*,
prepare: Optional[bool] = None,
- ) -> Cursor[RowConn]:
+ ) -> Cursor[Row]:
"""Execute a query and return a cursor to read its results."""
cur = self.cursor()
try:
self.wait(self._set_client_encoding_gen(name))
-class AsyncConnection(BaseConnection[RowConn]):
+class AsyncConnection(BaseConnection[Row]):
"""
Asynchronous wrapper for a connection to the database.
"""
__module__ = "psycopg3"
- def __init__(self, pgconn: "PGconn", row_factory: RowFactory[RowConn]):
+ def __init__(self, pgconn: "PGconn", row_factory: RowFactory[Row]):
super().__init__(pgconn, row_factory)
self.lock = asyncio.Lock()
conninfo: str = "",
*,
autocommit: bool = False,
- row_factory: RowFactory[RowConn],
+ row_factory: RowFactory[Row],
**kwargs: Union[None, int, str],
- ) -> "AsyncConnection[RowConn]":
+ ) -> "AsyncConnection[Row]":
...
@overload
conninfo: str = "",
*,
autocommit: bool = False,
- row_factory: Optional[RowFactory[RowConn]] = None,
+ row_factory: Optional[RowFactory[Row]] = None,
**kwargs: Any,
) -> "AsyncConnection[Any]":
return await cls._wait_conn(
)
)
- async def __aenter__(self) -> "AsyncConnection[RowConn]":
+ async def __aenter__(self) -> "AsyncConnection[Row]":
return self
async def __aexit__(
self.pgconn.finish()
@overload
- def cursor(self, *, binary: bool = False) -> AsyncCursor[RowConn]:
+ def cursor(self, *, binary: bool = False) -> AsyncCursor[Row]:
...
@overload
def cursor(
- self, *, binary: bool = False, row_factory: RowFactory[Row]
- ) -> AsyncCursor[Row]:
+ self, *, binary: bool = False, row_factory: RowFactory[CursorRow]
+ ) -> AsyncCursor[CursorRow]:
...
@overload
def cursor(
self, name: str, *, binary: bool = False
- ) -> AsyncServerCursor[RowConn]:
+ ) -> AsyncServerCursor[Row]:
...
@overload
def cursor(
- self, name: str, *, binary: bool = False, row_factory: RowFactory[Row]
- ) -> AsyncServerCursor[Row]:
+ self,
+ name: str,
+ *,
+ binary: bool = False,
+ row_factory: RowFactory[CursorRow],
+ ) -> AsyncServerCursor[CursorRow]:
...
def cursor(
params: Optional[Params] = None,
*,
prepare: Optional[bool] = None,
- ) -> AsyncCursor[RowConn]:
+ ) -> AsyncCursor[Row]:
cur = self.cursor()
try:
return await cur.execute(query, params, prepare=prepare)