]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Using generics to describe sync/async types
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 12 Nov 2020 19:09:42 +0000 (19:09 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 12 Nov 2020 19:09:42 +0000 (19:09 +0000)
psycopg3/psycopg3/adapt.py
psycopg3/psycopg3/connection.py
psycopg3/psycopg3/copy.py
psycopg3/psycopg3/cursor.py
psycopg3/psycopg3/pq/_pq_ctypes.pyi
psycopg3/psycopg3/proto.py

index 93f26a774961af9766e868f2c4e17c4870009d1a..7055232d31de32725b5763ebbfc9f4a8fb67a343 100644 (file)
@@ -4,7 +4,7 @@ Entry point into the adaptation system.
 
 # Copyright (C) 2020 The Psycopg Team
 
-from typing import Any, Callable, Optional, Type, Union
+from typing import Any, cast, Callable, Optional, Type, Union
 
 from . import pq
 from . import proto
@@ -139,7 +139,7 @@ def _connection_from_context(
     elif isinstance(context, BaseConnection):
         return context
     elif isinstance(context, BaseCursor):
-        return context.connection
+        return cast(BaseConnection, context.connection)
     elif isinstance(context, Transformer):
         return context.connection
     else:
index 7737098e2ed528a58f3de11cbafdcbd4fe17af2c..0ebbf00c404ef47522dd56fda72684dd74de5c7a 100644 (file)
@@ -9,16 +9,16 @@ import asyncio
 import threading
 from types import TracebackType
 from typing import Any, AsyncIterator, Callable, Iterator, List, NamedTuple
-from typing import Optional, Type, cast
+from typing import Optional, Type, TYPE_CHECKING, Union
 from weakref import ref, ReferenceType
 from functools import partial
 
 from . import pq
-from . import proto
 from . import cursor
 from . import errors as e
 from . import encodings
 from .pq import TransactionStatus, ExecStatus
+from .proto import DumpersMap, LoadersMap, PQGen, RV
 from .waiting import wait, wait_async
 from .conninfo import make_conninfo
 from .generators import notifies
@@ -26,8 +26,12 @@ from .generators import notifies
 logger = logging.getLogger(__name__)
 package_logger = logging.getLogger("psycopg3")
 
-connect: Callable[[str], proto.PQGen[pq.proto.PGconn]]
-execute: Callable[[pq.proto.PGconn], proto.PQGen[List[pq.proto.PGresult]]]
+connect: Callable[[str], PQGen["PGconn"]]
+execute: Callable[["PGconn"], PQGen[List["PGresult"]]]
+
+if TYPE_CHECKING:
+    from .pq.proto import PGconn, PGresult
+    from .cursor import Cursor, AsyncCursor
 
 if pq.__impl__ == "c":
     from psycopg3_c import _psycopg3
@@ -83,12 +87,13 @@ class BaseConnection:
     ConnStatus = pq.ConnStatus
     TransactionStatus = pq.TransactionStatus
 
-    def __init__(self, pgconn: pq.proto.PGconn):
+    cursor_factory: Union[Type["Cursor"], Type["AsyncCursor"]]
+
+    def __init__(self, pgconn: "PGconn"):
         self.pgconn = pgconn  # TODO: document this
-        self.cursor_factory = cursor.BaseCursor
         self._autocommit = False
-        self.dumpers: proto.DumpersMap = {}
-        self.loaders: proto.LoadersMap = {}
+        self.dumpers: DumpersMap = {}
+        self.loaders: LoadersMap = {}
         self._notice_handlers: List[NoticeHandler] = []
         self._notify_handlers: List[NotifyHandler] = []
 
@@ -122,13 +127,6 @@ class BaseConnection:
             )
         self._autocommit = value
 
-    def _cursor(
-        self, name: str = "", format: pq.Format = pq.Format.TEXT
-    ) -> cursor.BaseCursor:
-        if name:
-            raise NotImplementedError
-        return self.cursor_factory(self, format=format)
-
     @property
     def client_encoding(self) -> str:
         """The Python codec name of the connection's client encoding."""
@@ -161,7 +159,7 @@ class BaseConnection:
 
     @staticmethod
     def _notice_handler(
-        wself: "ReferenceType[BaseConnection]", res: pq.proto.PGresult
+        wself: "ReferenceType[BaseConnection]", res: "PGresult"
     ) -> None:
         self = wself()
         if not (self and self._notice_handler):
@@ -209,7 +207,7 @@ class Connection(BaseConnection):
 
     cursor_factory: Type[cursor.Cursor]
 
-    def __init__(self, pgconn: pq.proto.PGconn):
+    def __init__(self, pgconn: "PGconn"):
         super().__init__(pgconn)
         self.lock = threading.Lock()
         self.cursor_factory = cursor.Cursor
@@ -257,8 +255,10 @@ class Connection(BaseConnection):
         """
         Return a new `Cursor` to send commands and queries to the connection.
         """
-        cur = self._cursor(name, format=format)
-        return cast(cursor.Cursor, cur)
+        if name:
+            raise NotImplementedError
+
+        return self.cursor_factory(self, format=format)
 
     def _start_query(self) -> None:
         # the function is meant to be called by a cursor once the lock is taken
@@ -301,9 +301,7 @@ class Connection(BaseConnection):
             )
 
     @classmethod
-    def wait(
-        cls, gen: proto.PQGen[proto.RV], timeout: Optional[float] = 0.1
-    ) -> proto.RV:
+    def wait(cls, gen: PQGen[RV], timeout: Optional[float] = 0.1) -> RV:
         return wait(gen, timeout=timeout)
 
     def _set_client_encoding(self, name: str) -> None:
@@ -345,7 +343,7 @@ class AsyncConnection(BaseConnection):
 
     cursor_factory: Type[cursor.AsyncCursor]
 
-    def __init__(self, pgconn: pq.proto.PGconn):
+    def __init__(self, pgconn: "PGconn"):
         super().__init__(pgconn)
         self.lock = asyncio.Lock()
         self.cursor_factory = cursor.AsyncCursor
@@ -386,8 +384,10 @@ class AsyncConnection(BaseConnection):
         """
         Return a new `AsyncCursor` to send commands and queries to the connection.
         """
-        cur = self._cursor(name, format=format)
-        return cast(cursor.AsyncCursor, cur)
+        if name:
+            raise NotImplementedError
+
+        return self.cursor_factory(self, format=format)
 
     async def _start_query(self) -> None:
         # the function is meant to be called by a cursor once the lock is taken
@@ -428,7 +428,7 @@ class AsyncConnection(BaseConnection):
             )
 
     @classmethod
-    async def wait(cls, gen: proto.PQGen[proto.RV]) -> proto.RV:
+    async def wait(cls, gen: PQGen[RV]) -> RV:
         return await wait_async(gen)
 
     def _set_client_encoding(self, name: str) -> None:
index f80900787e63246e44dd95936f5c73e25d749a30..c6201287091809953784c4b77e76245c662207c3 100644 (file)
@@ -6,36 +6,35 @@ psycopg3 copy support
 
 import re
 import struct
-from typing import TYPE_CHECKING, AsyncIterator, Iterator
+from typing import TYPE_CHECKING, AsyncIterator, Iterator, Generic
 from typing import Any, Dict, List, Match, Optional, Sequence, Type, Union
 from types import TracebackType
 
-from . import pq
-from .proto import AdaptContext
+from .pq import Format
+from .proto import ConnectionType, Transformer
 from .generators import copy_from, copy_to, copy_end
 
 if TYPE_CHECKING:
-    from .connection import BaseConnection, Connection, AsyncConnection
+    from .pq.proto import PGresult
+    from .connection import Connection, AsyncConnection  # noqa: F401
 
 
-class BaseCopy:
+class BaseCopy(Generic[ConnectionType]):
     def __init__(
         self,
-        context: AdaptContext,
-        result: Optional[pq.proto.PGresult],
-        format: pq.Format = pq.Format.TEXT,
+        connection: ConnectionType,
+        transformer: Transformer,
+        result: "PGresult",
     ):
-        from .adapt import Transformer
-
-        self._connection: Optional["BaseConnection"] = None
-        self._transformer = Transformer(context)
-        self.format = format
+        self.connection = connection
+        self._transformer = transformer
         self.pgresult = result
+        self.format = result.binary_tuples
         self._first_row = True
         self._finished = False
         self._encoding: str = ""
 
-        if format == pq.Format.TEXT:
+        if self.format == Format.TEXT:
             self._format_row = self._format_row_text
         else:
             self._format_row = self._format_row_binary
@@ -45,22 +44,11 @@ class BaseCopy:
         return self._finished
 
     @property
-    def connection(self) -> "BaseConnection":
-        if self._connection:
-            return self._connection
-
-        self._connection = conn = self._transformer.connection
-        if conn:
-            return conn
-
-        raise ValueError("no connection available")
-
-    @property
-    def pgresult(self) -> Optional[pq.proto.PGresult]:
+    def pgresult(self) -> Optional["PGresult"]:
         return self._pgresult
 
     @pgresult.setter
-    def pgresult(self, result: Optional[pq.proto.PGresult]) -> None:
+    def pgresult(self, result: Optional["PGresult"]) -> None:
         self._pgresult = result
         self._transformer.pgresult = result
 
@@ -74,7 +62,7 @@ class BaseCopy:
 
             if (
                 self.pgresult is None
-                or self.pgresult.binary_tuples == pq.Format.BINARY
+                or self.pgresult.binary_tuples == Format.BINARY
             ):
                 raise TypeError(
                     "cannot copy str data in binary mode: use bytes instead"
@@ -151,15 +139,7 @@ def _bsrepl_sub(
 _bsrepl_re = re.compile(b"[\b\t\n\v\f\r\\\\]")
 
 
-class Copy(BaseCopy):
-    _connection: Optional["Connection"]
-
-    @property
-    def connection(self) -> "Connection":
-        # TODO: mypy error: "Callable[[BaseCopy], BaseConnection]" has no
-        # attribute "fget"
-        return BaseCopy.connection.fget(self)  # type: ignore
-
+class Copy(BaseCopy["Connection"]):
     def read(self) -> Optional[bytes]:
         if self._finished:
             return None
@@ -195,7 +175,7 @@ class Copy(BaseCopy):
         exc_tb: Optional[TracebackType],
     ) -> None:
         if exc_val is None:
-            if self.format == pq.Format.BINARY and not self._first_row:
+            if self.format == Format.BINARY and not self._first_row:
                 # send EOF only if we copied binary rows (_first_row is False)
                 self.write(b"\xff\xff")
             self.finish()
@@ -210,13 +190,7 @@ class Copy(BaseCopy):
             yield data
 
 
-class AsyncCopy(BaseCopy):
-    _connection: Optional["AsyncConnection"]
-
-    @property
-    def connection(self) -> "AsyncConnection":
-        return BaseCopy.connection.fget(self)  # type: ignore
-
+class AsyncCopy(BaseCopy["AsyncConnection"]):
     async def read(self) -> Optional[bytes]:
         if self._finished:
             return None
@@ -252,7 +226,7 @@ class AsyncCopy(BaseCopy):
         exc_tb: Optional[TracebackType],
     ) -> None:
         if exc_val is None:
-            if self.format == pq.Format.BINARY and not self._first_row:
+            if self.format == Format.BINARY and not self._first_row:
                 # send EOF only if we copied binary rows (_first_row is False)
                 await self.write(b"\xff\xff")
             await self.finish()
index 18c76afd895588b580ad57069885285b619011c4..66a490f786e91393bb437cd552d18b937fe582de 100644 (file)
@@ -5,23 +5,32 @@ psycopg3 cursor objects
 # Copyright (C) 2020 The Psycopg Team
 
 from types import TracebackType
-from typing import Any, AsyncIterator, Callable, Iterator, List, Mapping
+from typing import (
+    Any,
+    AsyncIterator,
+    Callable,
+    Generic,
+    Iterator,
+    List,
+    Mapping,
+)
 from typing import Optional, Sequence, Type, TYPE_CHECKING, Union
 from operator import attrgetter
 
 from . import errors as e
 from . import pq
 from . import sql
-from . import proto
 from .oids import builtins
 from .copy import Copy, AsyncCopy
-from .proto import Query, Params, DumpersMap, LoadersMap, PQGen
+from .proto import ConnectionType, Query, Params, DumpersMap, LoadersMap, PQGen
 from .utils.queries import PostgresQuery
 
 if TYPE_CHECKING:
-    from .connection import BaseConnection, Connection, AsyncConnection
+    from .proto import Transformer
+    from .pq.proto import PGconn, PGresult
+    from .connection import Connection, AsyncConnection  # noqa: F401
 
-execute: Callable[[pq.proto.PGconn], PQGen[List[pq.proto.PGresult]]]
+execute: Callable[["PGconn"], PQGen[List["PGresult"]]]
 
 if pq.__impl__ == "c":
     from psycopg3_c import _psycopg3
@@ -35,7 +44,7 @@ else:
 
 
 class Column(Sequence[Any]):
-    def __init__(self, pgresult: pq.proto.PGresult, index: int, encoding: str):
+    def __init__(self, pgresult: "PGresult", index: int, encoding: str):
         self._pgresult = pgresult
         self._index = index
         self._encoding = encoding
@@ -150,13 +159,15 @@ class Column(Sequence[Any]):
         return None
 
 
-class BaseCursor:
+class BaseCursor(Generic[ConnectionType]):
     ExecStatus = pq.ExecStatus
 
-    _transformer: proto.Transformer
+    _transformer: "Transformer"
 
     def __init__(
-        self, connection: "BaseConnection", format: pq.Format = pq.Format.TEXT
+        self,
+        connection: ConnectionType,
+        format: pq.Format = pq.Format.TEXT,
     ):
         self.connection = connection
         self.format = format
@@ -167,7 +178,7 @@ class BaseCursor:
         self._closed = False
 
     def _reset(self) -> None:
-        self._results: List[pq.proto.PGresult] = []
+        self._results: List["PGresult"] = []
         self.pgresult = None
         self._pos = 0
         self._iresult = 0
@@ -185,12 +196,12 @@ class BaseCursor:
         return res.status if res else None
 
     @property
-    def pgresult(self) -> Optional[pq.proto.PGresult]:
+    def pgresult(self) -> Optional["PGresult"]:
         """The `~psycopg3.pq.PGresult` exposed by the cursor."""
         return self._pgresult
 
     @pgresult.setter
-    def pgresult(self, result: Optional[pq.proto.PGresult]) -> None:
+    def pgresult(self, result: Optional["PGresult"]) -> None:
         self._pgresult = result
         if result and self._transformer:
             self._transformer.pgresult = result
@@ -236,7 +247,7 @@ class BaseCursor:
             return None
 
     def _start_query(self) -> None:
-        from .adapt import Transformer
+        from . import adapt
 
         if self.closed:
             raise e.InterfaceError("the cursor is closed")
@@ -251,7 +262,7 @@ class BaseCursor:
             )
 
         self._reset()
-        self._transformer = Transformer(self)
+        self._transformer = adapt.Transformer(self)
 
     def _execute_send(
         self, query: Query, vars: Optional[Params], no_pqexec: bool = False
@@ -275,7 +286,7 @@ class BaseCursor:
             # one query in one go
             self.connection.pgconn.send_query(pgq.query)
 
-    def _execute_results(self, results: Sequence[pq.proto.PGresult]) -> None:
+    def _execute_results(self, results: Sequence["PGresult"]) -> None:
         """
         Implement part of execute() after waiting common to sync and async
         """
@@ -393,9 +404,7 @@ class BaseCursor:
         qparts.append(sql.SQL(")"))
         return sql.Composed(qparts)
 
-    def _check_copy_results(
-        self, results: Sequence[pq.proto.PGresult]
-    ) -> None:
+    def _check_copy_results(self, results: Sequence["PGresult"]) -> None:
         """
         Check that the value returned in a copy() operation is a legit COPY.
         """
@@ -419,14 +428,7 @@ class BaseCursor:
             )
 
 
-class Cursor(BaseCursor):
-    connection: "Connection"
-
-    def __init__(
-        self, connection: "Connection", format: pq.Format = pq.Format.TEXT
-    ):
-        super().__init__(connection, format=format)
-
+class Cursor(BaseCursor["Connection"]):
     def __enter__(self) -> "Cursor":
         return self
 
@@ -563,22 +565,16 @@ class Cursor(BaseCursor):
             self._execute_send(statement, vars, no_pqexec=True)
             gen = execute(self.connection.pgconn)
             results = self.connection.wait(gen)
-            tx = self._transformer
 
         self._check_copy_results(results)
         return Copy(
-            context=tx, result=results[0], format=results[0].binary_tuples
+            connection=self.connection,
+            transformer=self._transformer,
+            result=results[0],
         )
 
 
-class AsyncCursor(BaseCursor):
-    connection: "AsyncConnection"
-
-    def __init__(
-        self, connection: "AsyncConnection", format: pq.Format = pq.Format.TEXT
-    ):
-        super().__init__(connection, format=format)
-
+class AsyncCursor(BaseCursor["AsyncConnection"]):
     async def __aenter__(self) -> "AsyncCursor":
         return self
 
@@ -700,11 +696,12 @@ class AsyncCursor(BaseCursor):
             self._execute_send(statement, vars, no_pqexec=True)
             gen = execute(self.connection.pgconn)
             results = await self.connection.wait(gen)
-            tx = self._transformer
 
         self._check_copy_results(results)
         return AsyncCopy(
-            context=tx, result=results[0], format=results[0].binary_tuples
+            connection=self.connection,
+            transformer=self._transformer,
+            result=results[0],
         )
 
 
index ea443727bb4dc91033de9a6600df6a6bb5b67bda..8fb7a13500f00e37e13b18800be7a9fac834e7ef 100644 (file)
@@ -4,7 +4,7 @@ types stub for ctypes functions
 
 # Copyright (C) 2020 The Psycopg Team
 
-from typing import Any, Callable, Optional, Sequence, NewType
+from typing import Any, Callable, Optional, Sequence
 from ctypes import Array, pointer
 from ctypes import c_char, c_char_p, c_int, c_ubyte, c_uint, c_ulong
 
index c0bdfd66453777d92792df5df4b9e71383967a3c..1e26e528af49afe207a6353a96f847cf1e5c0d64 100644 (file)
@@ -21,6 +21,7 @@ if TYPE_CHECKING:
 
 Query = Union[str, bytes, "Composable"]
 Params = Union[Sequence[Any], Mapping[str, Any]]
+ConnectionType = TypeVar("ConnectionType", bound="BaseConnection")
 
 
 # Waiting protocol types