]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added most type annotations
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 29 Mar 2020 14:39:21 +0000 (03:39 +1300)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 29 Mar 2020 15:37:55 +0000 (04:37 +1300)
mypy almost passes in --strict mode

17 files changed:
psycopg3/adaptation.py
psycopg3/connection.py
psycopg3/conninfo.py
psycopg3/cursor.py
psycopg3/exceptions.py
psycopg3/pq/__init__.py
psycopg3/pq/_pq_ctypes.py
psycopg3/pq/misc.py
psycopg3/pq/pq_ctypes.py
psycopg3/types/numeric.py
psycopg3/types/oids.py
psycopg3/types/text.py
psycopg3/utils/queries.py
psycopg3/utils/typing.py [new file with mode: 0644]
psycopg3/waiting.py
tests/test_query.py
tox.ini

index 710c30f2f72133770133e187fbe6d1eb326a1384..bf7ab98d6dadee78c87de792d0397af90d84ea98 100644 (file)
@@ -5,28 +5,59 @@ Entry point into the adaptation system.
 # Copyright (C) 2020 The Psycopg Team
 
 import codecs
-from typing import Dict, Tuple
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    Generator,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)
 from functools import partial
 
 from . import exceptions as exc
-from .pq import Format
+from .pq import Format, PGresult
 from .cursor import BaseCursor
 from .types.oids import type_oid, INVALID_OID
 from .connection import BaseConnection
+from .utils.typing import DecodeFunc
+
+
+# Type system
+
+AdaptContext = Union[BaseConnection, BaseCursor]
+
+MaybeOid = Union[Optional[bytes], Tuple[Optional[bytes], int]]
+AdapterFunc = Callable[[Any], MaybeOid]
+AdapterType = Union["Adapter", AdapterFunc]
+AdaptersMap = Dict[Tuple[type, Format], AdapterType]
+
+TypecasterFunc = Callable[[bytes], Any]
+TypecasterType = Union["Typecaster", TypecasterFunc]
+TypecastersMap = Dict[Tuple[int, Format], TypecasterType]
 
 
 class Adapter:
-    globals: Dict[Tuple[type, Format], "Adapter"] = {}  # TODO: incomplete type
+    globals: AdaptersMap = {}
 
-    def __init__(self, cls, conn):
+    def __init__(self, cls: type, conn: BaseConnection):
         self.cls = cls
         self.conn = conn
 
-    def adapt(self, obj):
+    def adapt(self, obj: Any) -> Union[bytes, Tuple[bytes, int]]:
         raise NotImplementedError()
 
     @staticmethod
-    def register(cls, adapter=None, context=None, format=Format.TEXT):
+    def register(
+        cls: type,
+        adapter: Optional[AdapterType] = None,
+        context: Optional[AdaptContext] = None,
+        format: Format = Format.TEXT,
+    ) -> AdapterType:
         if adapter is None:
             # used as decorator
             return partial(Adapter.register, cls, format=format)
@@ -58,23 +89,31 @@ class Adapter:
         return adapter
 
     @staticmethod
-    def register_binary(cls, adapter=None, context=None):
+    def register_binary(
+        cls: type,
+        adapter: Optional[AdapterType] = None,
+        context: Optional[AdaptContext] = None,
+    ) -> AdapterType:
         return Adapter.register(cls, adapter, context, format=Format.BINARY)
 
 
 class Typecaster:
-    # TODO: incomplete type
-    globals: Dict[Tuple[type, Format], "Typecaster"] = {}
+    globals: TypecastersMap = {}
 
-    def __init__(self, oid, conn):
+    def __init__(self, oid: int, conn: Optional[BaseConnection]):
         self.oid = oid
         self.conn = conn
 
-    def cast(self, data):
+    def cast(self, data: bytes) -> Any:
         raise NotImplementedError()
 
     @staticmethod
-    def register(oid, caster=None, context=None, format=Format.TEXT):
+    def register(
+        oid: int,
+        caster: Optional[TypecasterType] = None,
+        context: Optional[AdaptContext] = None,
+        format: Format = Format.TEXT,
+    ) -> TypecasterType:
         if caster is None:
             # used as decorator
             return partial(Typecaster.register, oid, format=format)
@@ -101,12 +140,16 @@ class Typecaster:
                 f" got {caster} instead"
             )
 
-        where = context.adapters if context is not None else Typecaster.globals
+        where = context.casters if context is not None else Typecaster.globals
         where[oid, format] = caster
         return caster
 
     @staticmethod
-    def register_binary(oid, caster=None, context=None):
+    def register_binary(
+        oid: int,
+        caster: Optional[TypecasterType] = None,
+        context: Optional[AdaptContext] = None,
+    ) -> TypecasterType:
         return Typecaster.register(oid, caster, context, format=Format.BINARY)
 
 
@@ -119,7 +162,10 @@ class Transformer:
     state so adapting several values of the same type can use optimisations.
     """
 
-    def __init__(self, context):
+    connection: Optional[BaseConnection]
+    cursor: Optional[BaseCursor]
+
+    def __init__(self, context: Optional[AdaptContext]):
         if context is None:
             self.connection = None
             self.cursor = None
@@ -136,24 +182,24 @@ class Transformer:
             )
 
         # mapping class, fmt -> adaptation function
-        self._adapt_funcs = {}
+        self._adapt_funcs: Dict[Tuple[type, Format], AdapterFunc] = {}
 
         # mapping oid, fmt -> cast function
-        self._cast_funcs = {}
+        self._cast_funcs: Dict[Tuple[int, Format], TypecasterFunc] = {}
 
         # The result to return values from
-        self._result = None
+        self._result: Optional[PGresult] = None
 
         # sequence of cast function from value to python
         # the length of the result columns
-        self._row_casters = None
+        self._row_casters: List[TypecasterFunc] = []
 
     @property
-    def result(self):
+    def result(self) -> Optional[PGresult]:
         return self._result
 
     @result.setter
-    def result(self, result):
+    def result(self, result: PGresult) -> None:
         if self._result is result:
             return
 
@@ -164,7 +210,9 @@ class Transformer:
             func = self.get_cast_function(oid, fmt)
             rc.append(func)
 
-    def adapt_sequence(self, objs, fmts):
+    def adapt_sequence(
+        self, objs: Sequence[Any], fmts: Sequence[Format]
+    ) -> Tuple[List[Optional[bytes]], List[int]]:
         out = []
         types = []
 
@@ -181,7 +229,7 @@ class Transformer:
 
         return out, types
 
-    def adapt(self, obj, fmt):
+    def adapt(self, obj: None, fmt: Format) -> MaybeOid:
         if obj is None:
             return None, type_oid["text"]
 
@@ -189,7 +237,7 @@ class Transformer:
         func = self.get_adapt_function(cls, fmt)
         return func(obj)
 
-    def get_adapt_function(self, cls, fmt):
+    def get_adapt_function(self, cls: type, fmt: Format) -> AdapterFunc:
         try:
             return self._adapt_funcs[cls, fmt]
         except KeyError:
@@ -197,11 +245,11 @@ class Transformer:
 
         adapter = self.lookup_adapter(cls, fmt)
         if isinstance(adapter, type):
-            adapter = adapter(cls, self.connection).adapt
-
-        return adapter
+            return adapter(cls, self.connection).adapt
+        else:
+            return cast(AdapterFunc, adapter)
 
-    def lookup_adapter(self, cls, fmt):
+    def lookup_adapter(self, cls: type, fmt: Format) -> AdapterType:
         key = (cls, fmt)
 
         cur = self.cursor
@@ -219,7 +267,7 @@ class Transformer:
             f"cannot adapt type {cls.__name__} to format {Format(fmt).name}"
         )
 
-    def cast_row(self, result, n):
+    def cast_row(self, result: PGresult, n: int) -> Generator[Any, None, None]:
         self.result = result
 
         for col, func in enumerate(self._row_casters):
@@ -228,7 +276,7 @@ class Transformer:
                 v = func(v)
             yield v
 
-    def get_cast_function(self, oid, fmt):
+    def get_cast_function(self, oid: int, fmt: Format) -> TypecasterFunc:
         try:
             return self._cast_funcs[oid, fmt]
         except KeyError:
@@ -236,11 +284,11 @@ class Transformer:
 
         caster = self.lookup_caster(oid, fmt)
         if isinstance(caster, type):
-            caster = caster(oid, self.connection).cast
-
-        return caster
+            return caster(oid, self.connection).cast
+        else:
+            return cast(TypecasterFunc, caster)
 
-    def lookup_caster(self, oid, fmt):
+    def lookup_caster(self, oid: int, fmt: Format) -> TypecasterType:
         key = (oid, fmt)
 
         cur = self.cursor
@@ -263,17 +311,18 @@ class UnknownCaster(Typecaster):
     Fallback object to convert unknown types to Python
     """
 
-    def __init__(self, oid, conn):
+    def __init__(self, oid: int, conn: Optional[BaseConnection]):
         super().__init__(oid, conn)
+        self.decode: DecodeFunc
         if conn is not None:
             self.decode = conn.codec.decode
         else:
             self.decode = codecs.lookup("utf8").decode
 
-    def cast(self, data):
+    def cast(self, data: bytes) -> str:
         return self.decode(data)[0]
 
 
 @Typecaster.register_binary(INVALID_OID)
-def cast_unknown(data):
+def cast_unknown(data: bytes) -> bytes:
     return data
index 21e714ddf4f9c1e30acc3cff6dccf4a250eb9ba8..7d89c1b7ba6de9018bb7e9e7fba62e7a029053d9 100644 (file)
@@ -8,6 +8,17 @@ import codecs
 import logging
 import asyncio
 import threading
+from typing import (
+    cast,
+    Any,
+    Generator,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    TYPE_CHECKING,
+)
 
 from . import pq
 from . import exceptions as exc
@@ -17,6 +28,13 @@ from .waiting import wait_select, wait_async, Wait, Ready
 
 logger = logging.getLogger(__name__)
 
+ConnectGen = Generator[Tuple[int, Wait], Ready, pq.PGconn]
+QueryGen = Generator[Tuple[int, Wait], Ready, List[pq.PGresult]]
+RV = TypeVar("RV")
+
+if TYPE_CHECKING:
+    from .adaptation import AdaptersMap, TypecastersMap
+
 
 class BaseConnection:
     """
@@ -26,38 +44,40 @@ class BaseConnection:
     allow different interfaces (sync/async).
     """
 
-    def __init__(self, pgconn):
+    def __init__(self, pgconn: pq.PGconn):
         self.pgconn = pgconn
-        self.cursor_factory = None
-        self.adapters = {}
-        self.casters = {}
+        self.cursor_factory = cursor.BaseCursor
+        self.adapters: AdaptersMap = {}
+        self.casters: TypecastersMap = {}
         # name of the postgres encoding (in bytes)
-        self.pgenc = None
+        self.pgenc = b""
 
-    def cursor(self, name=None, binary=False):
+    def cursor(
+        self, name: Optional[str] = None, binary: bool = False
+    ) -> cursor.BaseCursor:
         if name is not None:
             raise NotImplementedError
         return self.cursor_factory(self, binary=binary)
 
     @property
-    def codec(self):
+    def codec(self) -> codecs.CodecInfo:
         # TODO: utf8 fastpath?
         pgenc = self.pgconn.parameter_status(b"client_encoding")
         if self.pgenc != pgenc:
             # for unknown encodings and SQL_ASCII be strict and use ascii
-            pyenc = pq.py_codecs.get(pgenc.decode("ascii"), "ascii")
+            pyenc = pq.py_codecs.get(pgenc.decode("ascii")) or "ascii"
             self._codec = codecs.lookup(pyenc)
             self.pgenc = pgenc
         return self._codec
 
-    def encode(self, s):
+    def encode(self, s: str) -> bytes:
         return self.codec.encode(s)[0]
 
-    def decode(self, b):
+    def decode(self, b: bytes) -> str:
         return self.codec.decode(b)[0]
 
     @classmethod
-    def _connect_gen(cls, conninfo):
+    def _connect_gen(cls, conninfo: str) -> ConnectGen:
         """
         Generator to create a database connection without blocking.
 
@@ -65,9 +85,7 @@ class BaseConnection:
         generator can be restarted sending the appropriate `Ready` state when
         the file descriptor is ready.
         """
-        conninfo = conninfo.encode("utf8")
-
-        conn = pq.PGconn.connect_start(conninfo)
+        conn = pq.PGconn.connect_start(conninfo.encode("utf8"))
         logger.debug("connection started, status %s", conn.status.name)
         while 1:
             if conn.status == pq.ConnStatus.BAD:
@@ -94,7 +112,7 @@ class BaseConnection:
         return conn
 
     @classmethod
-    def _exec_gen(cls, pgconn):
+    def _exec_gen(cls, pgconn: pq.PGconn) -> QueryGen:
         """
         Generator returning query results without blocking.
 
@@ -109,7 +127,7 @@ class BaseConnection:
         Return the list of results returned by the database (whether success
         or error).
         """
-        results = []
+        results: List[pq.PGresult] = []
 
         while 1:
             f = pgconn.flush()
@@ -148,13 +166,17 @@ class Connection(BaseConnection):
     This class implements a DBAPI-compliant interface.
     """
 
-    def __init__(self, pgconn):
+    cursor_factory: Type[cursor.Cursor]
+
+    def __init__(self, pgconn: pq.PGconn):
         super().__init__(pgconn)
         self.lock = threading.Lock()
         self.cursor_factory = cursor.Cursor
 
     @classmethod
-    def connect(cls, conninfo, connection_factory=None, **kwargs):
+    def connect(
+        cls, conninfo: str, connection_factory: Any = None, **kwargs: Any
+    ) -> "Connection":
         if connection_factory is not None:
             raise NotImplementedError()
         conninfo = make_conninfo(conninfo, **kwargs)
@@ -162,13 +184,19 @@ class Connection(BaseConnection):
         pgconn = cls.wait(gen)
         return cls(pgconn)
 
-    def commit(self):
+    def cursor(
+        self, name: Optional[str] = None, binary: bool = False
+    ) -> cursor.Cursor:
+        cur = super().cursor(name, binary)
+        return cast(cursor.Cursor, cur)
+
+    def commit(self) -> None:
         self._exec_commit_rollback(b"commit")
 
-    def rollback(self):
+    def rollback(self) -> None:
         self._exec_commit_rollback(b"rollback")
 
-    def _exec_commit_rollback(self, command):
+    def _exec_commit_rollback(self, command: bytes) -> None:
         with self.lock:
             status = self.pgconn.transaction_status
             if status == pq.TransactionStatus.IDLE:
@@ -183,7 +211,7 @@ class Connection(BaseConnection):
                 )
 
     @classmethod
-    def wait(cls, gen):
+    def wait(cls, gen: Generator[Tuple[int, Wait], Ready, RV]) -> RV:
         return wait_select(gen)
 
 
@@ -195,25 +223,33 @@ class AsyncConnection(BaseConnection):
     methods implemented as coroutines.
     """
 
-    def __init__(self, pgconn):
+    cursor_factory: Type[cursor.AsyncCursor]
+
+    def __init__(self, pgconn: pq.PGconn):
         super().__init__(pgconn)
         self.lock = asyncio.Lock()
         self.cursor_factory = cursor.AsyncCursor
 
     @classmethod
-    async def connect(cls, conninfo, **kwargs):
+    async def connect(cls, conninfo: str, **kwargs: Any) -> "AsyncConnection":
         conninfo = make_conninfo(conninfo, **kwargs)
         gen = cls._connect_gen(conninfo)
         pgconn = await cls.wait(gen)
         return cls(pgconn)
 
-    async def commit(self):
+    def cursor(
+        self, name: Optional[str] = None, binary: bool = False
+    ) -> cursor.AsyncCursor:
+        cur = super().cursor(name, binary)
+        return cast(cursor.AsyncCursor, cur)
+
+    async def commit(self) -> None:
         await self._exec_commit_rollback(b"commit")
 
-    async def rollback(self):
+    async def rollback(self) -> None:
         await self._exec_commit_rollback(b"rollback")
 
-    async def _exec_commit_rollback(self, command):
+    async def _exec_commit_rollback(self, command: bytes) -> None:
         async with self.lock:
             status = self.pgconn.transaction_status
             if status == pq.TransactionStatus.IDLE:
@@ -228,5 +264,5 @@ class AsyncConnection(BaseConnection):
                 )
 
     @classmethod
-    async def wait(cls, gen):
+    async def wait(cls, gen: Generator[Tuple[int, Wait], Ready, RV]) -> RV:
         return await wait_async(gen)
index 365c3e3d9c36a3f124b15b1856e9658a4778e1db..81308a2a2ac42c0fc55efe513e61b8ad3612a573 100644 (file)
@@ -1,16 +1,17 @@
 import re
+from typing import Any, Dict, List
 
 from . import pq
 from . import exceptions as exc
 
 
-def make_conninfo(conninfo=None, **kwargs):
+def make_conninfo(conninfo: str = "", **kwargs: Any) -> str:
     """
     Merge a string and keyword params into a single conninfo string.
 
     Raise ProgrammingError if the input don't make a valid conninfo.
     """
-    if conninfo is None and not kwargs:
+    if not conninfo and not kwargs:
         return ""
 
     # If no kwarg is specified don't mung the conninfo but check if it's correct
@@ -37,7 +38,7 @@ def make_conninfo(conninfo=None, **kwargs):
     return conninfo
 
 
-def conninfo_to_dict(conninfo):
+def conninfo_to_dict(conninfo: str) -> Dict[str, str]:
     """
     Convert the *conninfo* string into a dictionary of parameters.
 
@@ -51,7 +52,7 @@ def conninfo_to_dict(conninfo):
     }
 
 
-def _parse_conninfo(conninfo):
+def _parse_conninfo(conninfo: str) -> List[pq.ConninfoOption]:
     """
     Verify that *conninfo* is a valid connection string.
 
@@ -65,9 +66,11 @@ def _parse_conninfo(conninfo):
         raise exc.ProgrammingError(str(e))
 
 
-def _param_escape(
-    s, re_escape=re.compile(r"([\\'])"), re_space=re.compile(r"\s")
-):
+re_escape = re.compile(r"([\\'])")
+re_space = re.compile(r"\s")
+
+
+def _param_escape(s: str) -> str:
     """
     Apply the escaping rule required by PQconnectdb
     """
index e5d63f1f9539b13b3cd23f95492be0546293d12d..ec6416820ca3d9056b1461253b59df01fe6d44e0 100644 (file)
@@ -4,29 +4,43 @@ psycopg3 cursor objects
 
 # Copyright (C) 2020 The Psycopg Team
 
+from typing import Any, List, Mapping, Optional, Sequence, Tuple, TYPE_CHECKING
+
 from . import exceptions as exc
-from .pq import error_message, DiagnosticField, ExecStatus
+from .pq import error_message, DiagnosticField, ExecStatus, PGresult, Format
 from .utils.queries import query2pg, reorder_params
+from .utils.typing import Query, Params
+
+if TYPE_CHECKING:
+    from .connection import (
+        BaseConnection,
+        Connection,
+        AsyncConnection,
+        QueryGen,
+    )
+    from .adaptation import AdaptersMap, TypecastersMap
 
 
 class BaseCursor:
-    def __init__(self, conn, binary=False):
+    def __init__(self, conn: "BaseConnection", binary: bool = False):
         self.conn = conn
         self.binary = binary
-        self.adapters = {}
-        self.casters = {}
+        self.adapters: AdaptersMap = {}
+        self.casters: TypecastersMap = {}
         self._reset()
 
-    def _reset(self):
+    def _reset(self) -> None:
         from .adaptation import Transformer
 
-        self._results = []
-        self._result = None
+        self._results: List[PGresult] = []
+        self._result: Optional[PGresult] = None
         self._pos = 0
         self._iresult = 0
         self._transformer = Transformer(self)
 
-    def _execute_send(self, query, vars):
+    def _execute_send(
+        self, query: Query, vars: Optional[Params]
+    ) -> "QueryGen":
         # Implement part of execute() before waiting common to sync and async
         self._reset()
 
@@ -40,21 +54,23 @@ class BaseCursor:
             query, formats, order = query2pg(query, vars, codec)
         if vars:
             if order is not None:
+                assert isinstance(vars, Mapping)
                 vars = reorder_params(vars, order)
+            assert isinstance(vars, Sequence)
             params, types = self._transformer.adapt_sequence(vars, formats)
             self.conn.pgconn.send_query_params(
                 query,
                 params,
                 param_formats=formats,
                 param_types=types,
-                result_format=int(self.binary),
+                result_format=Format(self.binary),
             )
         else:
             self.conn.pgconn.send_query(query)
 
         return self.conn._exec_gen(self.conn.pgconn)
 
-    def _execute_results(self, results):
+    def _execute_results(self, results: List[PGresult]) -> None:
         # Implement part of execute() after waiting common to sync and async
         if not results:
             raise exc.InternalError("got no result from the query")
@@ -89,20 +105,22 @@ class BaseCursor:
                 f" {', '.join(sorted(s.name for s in sorted(badstats)))}"
             )
 
-    def nextset(self):
+    def nextset(self) -> Optional[bool]:
         self._iresult += 1
         if self._iresult < len(self._results):
             self._result = self._results[self._iresult]
             self._pos = 0
             return True
+        else:
+            return None
 
-    def fetchone(self):
+    def fetchone(self) -> Optional[Sequence[Any]]:
         rv = self._cast_row(self._pos)
         if rv is not None:
             self._pos += 1
         return rv
 
-    def _cast_row(self, n):
+    def _cast_row(self, n: int) -> Optional[Tuple[Any, ...]]:
         if self._result is None:
             return None
         if n >= self._result.ntuples:
@@ -112,7 +130,12 @@ class BaseCursor:
 
 
 class Cursor(BaseCursor):
-    def execute(self, query, vars=None):
+    conn: "Connection"
+
+    def __init__(self, conn: "Connection", binary: bool = False):
+        super().__init__(conn, binary)
+
+    def execute(self, query: Query, vars: Optional[Params] = None) -> "Cursor":
         with self.conn.lock:
             gen = self._execute_send(query, vars)
             results = self.conn.wait(gen)
@@ -121,7 +144,14 @@ class Cursor(BaseCursor):
 
 
 class AsyncCursor(BaseCursor):
-    async def execute(self, query, vars=None):
+    conn: "AsyncConnection"
+
+    def __init__(self, conn: "AsyncConnection", binary: bool = False):
+        super().__init__(conn, binary)
+
+    async def execute(
+        self, query: Query, vars: Optional[Params] = None
+    ) -> "AsyncCursor":
         async with self.conn.lock:
             gen = self._execute_send(query, vars)
             results = await self.conn.wait(gen)
index 7baba55b150e1f1e7cd651c20deda4e6cd4c8ff2..c8ce84438264042a5e93d4b585084c01d4f32779 100644 (file)
@@ -18,6 +18,11 @@ DBAPI-defined Exceptions are defined in the following hierarchy::
 
 # Copyright (C) 2020 The Psycopg Team
 
+from typing import Any, Optional, Sequence, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from psycopg3.pq import PGresult  # noqa
+
 
 class Warning(Exception):
     """
@@ -32,7 +37,9 @@ class Error(Exception):
     Base exception for all the errors psycopg3 will raise.
     """
 
-    def __init__(self, *args, pgresult=None):
+    def __init__(
+        self, *args: Sequence[Any], pgresult: Optional["PGresult"] = None
+    ):
         super().__init__(*args)
         self.pgresult = pgresult
 
@@ -100,6 +107,6 @@ class NotSupportedError(DatabaseError):
     """
 
 
-def class_for_state(sqlstate):
+def class_for_state(sqlstate: bytes) -> type:
     # TODO: stub
     return DatabaseError
index 98575ae6024534efb2307afd3ba49f1255391c48..70a0fd0d3d39ae77b769a10fd0394a74a0e9f51a 100644 (file)
@@ -19,7 +19,7 @@ from .enums import (
     Format,
 )
 from .encodings import py_codecs
-from .misc import error_message
+from .misc import error_message, ConninfoOption
 
 from . import pq_ctypes as pq_module
 
@@ -41,6 +41,7 @@ __all__ = (
     "Conninfo",
     "PQerror",
     "error_message",
+    "ConninfoOption",
     "py_codecs",
     "version",
 )
index 31887b2418dc9233ac202aa8578f0dd58fbb7f30..86ed2a3232855ffd4c33b930de64c5bb16ce2172 100644 (file)
@@ -144,9 +144,9 @@ if libpq_version >= 120000:
     _PQhostaddr.restype = c_char_p
 
 
-def PQhostaddr(pgconn):
+def PQhostaddr(pgconn: type) -> bytes:
     if _PQhostaddr is not None:
-        return _PQhostaddr(pgconn)
+        return _PQhostaddr(pgconn)  # type: ignore
     else:
         raise NotSupportedError(
             f"PQhostaddr requires libpq from PostgreSQL 12,"
index ed1e6cca4c87b7ab760207af898252b49439e902..c6b020500c5fa7a9a9a2de5b49b55163116fe442 100644 (file)
@@ -4,8 +4,19 @@ Various functionalities to make easier to work with the libpq.
 
 # Copyright (C) 2020 The Psycopg Team
 
+from collections import namedtuple
+from typing import TYPE_CHECKING, Union
 
-def error_message(obj):
+if TYPE_CHECKING:
+    from psycopg3.pq import PGconn, PGresult  # noqa
+
+
+ConninfoOption = namedtuple(
+    "ConninfoOption", "keyword envvar compiled val label dispatcher dispsize"
+)
+
+
+def error_message(obj: Union["PGconn", "PGresult"]) -> str:
     """
     Return an error message from a PGconn or PGresult.
 
@@ -13,29 +24,33 @@ def error_message(obj):
     """
     from psycopg3 import pq
 
+    bmsg: bytes
+
     if isinstance(obj, pq.PGconn):
-        msg = obj.error_message
+        bmsg = obj.error_message
 
         # strip severity and whitespaces
-        if msg:
-            msg = msg.splitlines()[0].split(b":", 1)[-1].strip()
+        if bmsg:
+            bmsg = bmsg.splitlines()[0].split(b":", 1)[-1].strip()
 
     elif isinstance(obj, pq.PGresult):
-        msg = obj.error_field(pq.DiagnosticField.MESSAGE_PRIMARY)
-        if not msg:
-            msg = obj.error_message
+        bmsg = obj.error_field(pq.DiagnosticField.MESSAGE_PRIMARY)
+        if not bmsg:
+            bmsg = obj.error_message
 
             # strip severity and whitespaces
-            if msg:
-                msg = msg.splitlines()[0].split(b":", 1)[-1].strip()
+            if bmsg:
+                bmsg = bmsg.splitlines()[0].split(b":", 1)[-1].strip()
 
     else:
         raise TypeError(
             f"PGconn or PGresult expected, got {type(obj).__name__}"
         )
 
-    if msg:
-        msg = msg.decode("utf8", "replace")  # TODO: or in connection encoding?
+    if bmsg:
+        msg = bmsg.decode(
+            "utf8", "replace"
+        )  # TODO: or in connection encoding?
     else:
         msg = "no details available"
 
index a79bb4298f0ba0f13917ad82ee8242c50f43206e..0f1b47816e13975fb16de9977f8532f02b509824 100644 (file)
@@ -8,9 +8,9 @@ implementation.
 
 # Copyright (C) 2020 The Psycopg Team
 
-from collections import namedtuple
 from ctypes import string_at
 from ctypes import c_char_p, c_int, pointer
+from typing import Any, List, Optional, Sequence
 
 from .enums import (
     ConnStatus,
@@ -18,14 +18,16 @@ from .enums import (
     ExecStatus,
     TransactionStatus,
     Ping,
+    DiagnosticField,
+    Format,
 )
-from .misc import error_message
+from .misc import error_message, ConninfoOption
 from . import _pq_ctypes as impl
 from ..exceptions import OperationalError
 
 
-def version():
-    return impl.PQlibVersion()
+def version() -> int:
+    return impl.PQlibVersion()  # type: ignore
 
 
 class PQerror(OperationalError):
@@ -35,18 +37,16 @@ class PQerror(OperationalError):
 class PGconn:
     __slots__ = ("pgconn_ptr",)
 
-    def __init__(self, pgconn_ptr):
-        self.pgconn_ptr = pgconn_ptr
+    def __init__(self, pgconn_ptr: impl.PGconn_struct):
+        self.pgconn_ptr: Optional[impl.PGconn_struct] = pgconn_ptr
 
-    def __del__(self):
+    def __del__(self) -> None:
         self.finish()
 
     @classmethod
-    def connect(cls, conninfo):
+    def connect(cls, conninfo: bytes) -> "PGconn":
         if not isinstance(conninfo, bytes):
-            raise TypeError(
-                "bytes expected, got %s instead" % type(conninfo).__name__
-            )
+            raise TypeError(f"bytes expected, got {type(conninfo)} instead")
 
         pgconn_ptr = impl.PQconnectdb(conninfo)
         if not pgconn_ptr:
@@ -54,28 +54,26 @@ class PGconn:
         return cls(pgconn_ptr)
 
     @classmethod
-    def connect_start(cls, conninfo):
+    def connect_start(cls, conninfo: bytes) -> "PGconn":
         if not isinstance(conninfo, bytes):
-            raise TypeError(
-                "bytes expected, got %s instead" % type(conninfo).__name__
-            )
+            raise TypeError(f"bytes expected, got {type(conninfo)} instead")
 
         pgconn_ptr = impl.PQconnectStart(conninfo)
         if not pgconn_ptr:
             raise MemoryError("couldn't allocate PGconn")
         return cls(pgconn_ptr)
 
-    def connect_poll(self):
+    def connect_poll(self) -> PollingStatus:
         rv = impl.PQconnectPoll(self.pgconn_ptr)
         return PollingStatus(rv)
 
-    def finish(self):
+    def finish(self) -> None:
         self.pgconn_ptr, p = None, self.pgconn_ptr
         if p is not None:
             impl.PQfinish(p)
 
     @property
-    def info(self):
+    def info(self) -> List["ConninfoOption"]:
         opts = impl.PQconninfo(self.pgconn_ptr)
         if not opts:
             raise MemoryError("couldn't allocate connection info")
@@ -84,130 +82,124 @@ class PGconn:
         finally:
             impl.PQconninfoFree(opts)
 
-    def reset(self):
+    def reset(self) -> None:
         impl.PQreset(self.pgconn_ptr)
 
-    def reset_start(self):
+    def reset_start(self) -> None:
         if not impl.PQresetStart(self.pgconn_ptr):
             raise PQerror("couldn't reset connection")
 
-    def reset_poll(self):
+    def reset_poll(self) -> PollingStatus:
         rv = impl.PQresetPoll(self.pgconn_ptr)
         return PollingStatus(rv)
 
     @classmethod
-    def ping(self, conninfo):
+    def ping(self, conninfo: bytes) -> Ping:
         if not isinstance(conninfo, bytes):
-            raise TypeError(
-                "bytes expected, got %s instead" % type(conninfo).__name__
-            )
+            raise TypeError(f"bytes expected, got {type(conninfo)} instead")
 
         rv = impl.PQping(conninfo)
         return Ping(rv)
 
     @property
-    def db(self):
-        return impl.PQdb(self.pgconn_ptr)
+    def db(self) -> bytes:
+        return impl.PQdb(self.pgconn_ptr)  # type: ignore
 
     @property
-    def user(self):
-        return impl.PQuser(self.pgconn_ptr)
+    def user(self) -> bytes:
+        return impl.PQuser(self.pgconn_ptr)  # type: ignore
 
     @property
-    def password(self):
-        return impl.PQpass(self.pgconn_ptr)
+    def password(self) -> bytes:
+        return impl.PQpass(self.pgconn_ptr)  # type: ignore
 
     @property
-    def host(self):
-        return impl.PQhost(self.pgconn_ptr)
+    def host(self) -> bytes:
+        return impl.PQhost(self.pgconn_ptr)  # type: ignore
 
     @property
-    def hostaddr(self):
-        return impl.PQhostaddr(self.pgconn_ptr)
+    def hostaddr(self) -> bytes:
+        return impl.PQhostaddr(self.pgconn_ptr)  # type: ignore
 
     @property
-    def port(self):
-        return impl.PQport(self.pgconn_ptr)
+    def port(self) -> bytes:
+        return impl.PQport(self.pgconn_ptr)  # type: ignore
 
     @property
-    def tty(self):
-        return impl.PQtty(self.pgconn_ptr)
+    def tty(self) -> bytes:
+        return impl.PQtty(self.pgconn_ptr)  # type: ignore
 
     @property
-    def options(self):
-        return impl.PQoptions(self.pgconn_ptr)
+    def options(self) -> bytes:
+        return impl.PQoptions(self.pgconn_ptr)  # type: ignore
 
     @property
-    def status(self):
+    def status(self) -> ConnStatus:
         rv = impl.PQstatus(self.pgconn_ptr)
         return ConnStatus(rv)
 
     @property
-    def transaction_status(self):
+    def transaction_status(self) -> TransactionStatus:
         rv = impl.PQtransactionStatus(self.pgconn_ptr)
         return TransactionStatus(rv)
 
-    def parameter_status(self, name):
-        return impl.PQparameterStatus(self.pgconn_ptr, name)
+    def parameter_status(self, name: bytes) -> bytes:
+        return impl.PQparameterStatus(self.pgconn_ptr, name)  # type: ignore
 
     @property
-    def protocol_version(self):
-        return impl.PQprotocolVersion(self.pgconn_ptr)
+    def protocol_version(self) -> int:
+        return impl.PQprotocolVersion(self.pgconn_ptr)  # type: ignore
 
     @property
-    def server_version(self):
-        return impl.PQserverVersion(self.pgconn_ptr)
+    def server_version(self) -> int:
+        return impl.PQserverVersion(self.pgconn_ptr)  # type: ignore
 
     @property
-    def error_message(self):
-        return impl.PQerrorMessage(self.pgconn_ptr)
+    def error_message(self) -> bytes:
+        return impl.PQerrorMessage(self.pgconn_ptr)  # type: ignore
 
     @property
-    def socket(self):
-        return impl.PQsocket(self.pgconn_ptr)
+    def socket(self) -> int:
+        return impl.PQsocket(self.pgconn_ptr)  # type: ignore
 
     @property
-    def backend_pid(self):
-        return impl.PQbackendPID(self.pgconn_ptr)
+    def backend_pid(self) -> int:
+        return impl.PQbackendPID(self.pgconn_ptr)  # type: ignore
 
     @property
-    def needs_password(self):
+    def needs_password(self) -> bool:
         return bool(impl.PQconnectionNeedsPassword(self.pgconn_ptr))
 
     @property
-    def used_password(self):
+    def used_password(self) -> bool:
         return bool(impl.PQconnectionUsedPassword(self.pgconn_ptr))
 
     @property
-    def ssl_in_use(self):
+    def ssl_in_use(self) -> bool:
         return bool(impl.PQsslInUse(self.pgconn_ptr))
 
-    def exec_(self, command):
+    def exec_(self, command: bytes) -> "PGresult":
         if not isinstance(command, bytes):
-            raise TypeError(
-                "bytes expected, got %s instead" % type(command).__name__
-            )
+            raise TypeError(f"bytes expected, got {type(command)} instead")
         rv = impl.PQexec(self.pgconn_ptr, command)
         if not rv:
             raise MemoryError("couldn't allocate PGresult")
         return PGresult(rv)
 
-    def send_query(self, command):
+    def send_query(self, command: bytes) -> None:
         if not isinstance(command, bytes):
-            raise TypeError(
-                "bytes expected, got %s instead" % type(command).__name__
-            )
+            raise TypeError(f"bytes expected, got {type(command)} instead")
         if not impl.PQsendQuery(self.pgconn_ptr, command):
             raise PQerror(f"sending query failed: {error_message(self)}")
 
     def exec_params(
         self,
-        command,
-        param_values,
-        param_types=None,
-        param_formats=None,
-        result_format=0,
-    ):
+        command: bytes,
+        param_values: List[Optional[bytes]],
+        param_types: Optional[List[int]] = None,
+        param_formats: Optional[List[Format]] = None,
+        result_format: Format = Format.TEXT,
+    ) -> "PGresult":
         args = self._query_params_args(
             command, param_values, param_types, param_formats, result_format
         )
@@ -218,12 +210,12 @@ class PGconn:
 
     def send_query_params(
         self,
-        command,
-        param_values,
-        param_types=None,
-        param_formats=None,
-        result_format=0,
-    ):
+        command: bytes,
+        param_values: List[Optional[bytes]],
+        param_types: Optional[List[int]] = None,
+        param_formats: Optional[List[Format]] = None,
+        result_format: Format = Format.TEXT,
+    ) -> None:
         args = self._query_params_args(
             command, param_values, param_types, param_formats, result_format
         )
@@ -233,12 +225,15 @@ class PGconn:
             )
 
     def _query_params_args(
-        self, command, param_values, param_types, param_formats, result_format,
-    ):
+        self,
+        command: bytes,
+        param_values: List[Optional[bytes]],
+        param_types: Optional[List[int]] = None,
+        param_formats: Optional[List[Format]] = None,
+        result_format: Format = Format.TEXT,
+    ) -> Any:
         if not isinstance(command, bytes):
-            raise TypeError(
-                "bytes expected, got %s instead" % type(command).__name__
-            )
+            raise TypeError(f"bytes expected, got {type(command)} instead")
 
         nparams = len(param_values)
         if nparams:
@@ -247,7 +242,7 @@ class PGconn:
                 *(len(p) if p is not None else 0 for p in param_values)
             )
         else:
-            aparams = alenghts = None
+            aparams = alenghts = None  # type: ignore
 
         if param_types is None:
             atypes = None
@@ -280,16 +275,18 @@ class PGconn:
             result_format,
         )
 
-    def prepare(self, name, command, param_types=None):
+    def prepare(
+        self,
+        name: bytes,
+        command: bytes,
+        param_types: Optional[List[int]] = None,
+    ) -> "PGresult":
         if not isinstance(name, bytes):
-            raise TypeError(
-                "'name' must be bytes, got %s instead" % type(name).__name__
-            )
+            raise TypeError(f"'name' must be bytes, got {type(name)} instead")
 
         if not isinstance(command, bytes):
             raise TypeError(
-                "'command' must be bytes, got %s instead"
-                % type(command).__name__
+                f"'command' must be bytes, got {type(command)} instead"
             )
 
         if param_types is None:
@@ -305,12 +302,14 @@ class PGconn:
         return PGresult(rv)
 
     def exec_prepared(
-        self, name, param_values, param_formats=None, result_format=0
-    ):
+        self,
+        name: bytes,
+        param_values: List[bytes],
+        param_formats: Optional[List[int]] = None,
+        result_format: int = 0,
+    ) -> "PGresult":
         if not isinstance(name, bytes):
-            raise TypeError(
-                "'name' must be bytes, got %s instead" % type(name).__name__
-            )
+            raise TypeError(f"'name' must be bytes, got {type(name)} instead")
 
         nparams = len(param_values)
         if nparams:
@@ -319,7 +318,7 @@ class PGconn:
                 *(len(p) if p is not None else 0 for p in param_values)
             )
         else:
-            aparams = alenghts = None
+            aparams = alenghts = None  # type: ignore
 
         if param_formats is None:
             aformats = None
@@ -344,53 +343,49 @@ class PGconn:
             raise MemoryError("couldn't allocate PGresult")
         return PGresult(rv)
 
-    def describe_prepared(self, name):
+    def describe_prepared(self, name: bytes) -> "PGresult":
         if not isinstance(name, bytes):
-            raise TypeError(
-                "'name' must be bytes, got %s instead" % type(name).__name__
-            )
+            raise TypeError(f"'name' must be bytes, got {type(name)} instead")
         rv = impl.PQdescribePrepared(self.pgconn_ptr, name)
         if not rv:
             raise MemoryError("couldn't allocate PGresult")
         return PGresult(rv)
 
-    def describe_portal(self, name):
+    def describe_portal(self, name: bytes) -> "PGresult":
         if not isinstance(name, bytes):
-            raise TypeError(
-                "'name' must be bytes, got %s instead" % type(name).__name__
-            )
+            raise TypeError(f"'name' must be bytes, got {type(name)} instead")
         rv = impl.PQdescribePortal(self.pgconn_ptr, name)
         if not rv:
             raise MemoryError("couldn't allocate PGresult")
         return PGresult(rv)
 
-    def get_result(self):
+    def get_result(self) -> Optional["PGresult"]:
         rv = impl.PQgetResult(self.pgconn_ptr)
         return PGresult(rv) if rv else None
 
-    def consume_input(self):
+    def consume_input(self) -> None:
         if 1 != impl.PQconsumeInput(self.pgconn_ptr):
             raise PQerror(f"consuming input failed: {error_message(self)}")
 
-    def is_busy(self):
-        return impl.PQisBusy(self.pgconn_ptr)
+    def is_busy(self) -> int:
+        return impl.PQisBusy(self.pgconn_ptr)  # type: ignore
 
     @property
-    def nonblocking(self):
-        return impl.PQisnonblocking(self.pgconn_ptr)
+    def nonblocking(self) -> int:
+        return impl.PQisnonblocking(self.pgconn_ptr)  # type: ignore
 
     @nonblocking.setter
-    def nonblocking(self, arg):
+    def nonblocking(self, arg: int) -> None:
         if 0 > impl.PQsetnonblocking(self.pgconn_ptr, arg):
             raise PQerror(f"setting nonblocking failed: {error_message(self)}")
 
-    def flush(self):
-        rv = impl.PQflush(self.pgconn_ptr)
+    def flush(self) -> int:
+        rv: int = impl.PQflush(self.pgconn_ptr)
         if rv < 0:
             raise PQerror(f"flushing failed: {error_message(self)}")
         return rv
 
-    def make_empty_result(self, exec_status):
+    def make_empty_result(self, exec_status: ExecStatus) -> "PGresult":
         rv = impl.PQmakeEmptyPGresult(self.pgconn_ptr, exec_status)
         if not rv:
             raise MemoryError("couldn't allocate empty PGresult")
@@ -400,64 +395,72 @@ class PGconn:
 class PGresult:
     __slots__ = ("pgresult_ptr",)
 
-    def __init__(self, pgresult_ptr):
-        self.pgresult_ptr = pgresult_ptr
+    def __init__(self, pgresult_ptr: type):
+        self.pgresult_ptr: Optional[type] = pgresult_ptr
 
-    def __del__(self):
+    def __del__(self) -> None:
         self.clear()
 
-    def clear(self):
+    def clear(self) -> None:
         self.pgresult_ptr, p = None, self.pgresult_ptr
         if p is not None:
             impl.PQclear(p)
 
     @property
-    def status(self):
+    def status(self) -> ExecStatus:
         rv = impl.PQresultStatus(self.pgresult_ptr)
         return ExecStatus(rv)
 
     @property
-    def error_message(self):
-        return impl.PQresultErrorMessage(self.pgresult_ptr)
+    def error_message(self) -> bytes:
+        return impl.PQresultErrorMessage(self.pgresult_ptr)  # type: ignore
 
-    def error_field(self, fieldcode):
-        return impl.PQresultErrorField(self.pgresult_ptr, fieldcode)
+    def error_field(self, fieldcode: DiagnosticField) -> bytes:
+        return impl.PQresultErrorField(  # type: ignore
+            self.pgresult_ptr, fieldcode
+        )
 
     @property
-    def ntuples(self):
-        return impl.PQntuples(self.pgresult_ptr)
+    def ntuples(self) -> int:
+        return impl.PQntuples(self.pgresult_ptr)  # type: ignore
 
     @property
-    def nfields(self):
-        return impl.PQnfields(self.pgresult_ptr)
+    def nfields(self) -> int:
+        return impl.PQnfields(self.pgresult_ptr)  # type: ignore
 
-    def fname(self, column_number):
-        return impl.PQfname(self.pgresult_ptr, column_number)
+    def fname(self, column_number: int) -> int:
+        return impl.PQfname(self.pgresult_ptr, column_number)  # type: ignore
 
-    def ftable(self, column_number):
-        return impl.PQftable(self.pgresult_ptr, column_number)
+    def ftable(self, column_number: int) -> int:
+        return impl.PQftable(self.pgresult_ptr, column_number)  # type: ignore
 
-    def ftablecol(self, column_number):
-        return impl.PQftablecol(self.pgresult_ptr, column_number)
+    def ftablecol(self, column_number: int) -> int:
+        return impl.PQftablecol(  # type: ignore
+            self.pgresult_ptr, column_number
+        )
 
-    def fformat(self, column_number):
-        return impl.PQfformat(self.pgresult_ptr, column_number)
+    def fformat(self, column_number: int) -> Format:
+        return impl.PQfformat(self.pgresult_ptr, column_number)  # type: ignore
 
-    def ftype(self, column_number):
-        return impl.PQftype(self.pgresult_ptr, column_number)
+    def ftype(self, column_number: int) -> int:
+        return impl.PQftype(self.pgresult_ptr, column_number)  # type: ignore
 
-    def fmod(self, column_number):
-        return impl.PQfmod(self.pgresult_ptr, column_number)
+    def fmod(self, column_number: int) -> int:
+        return impl.PQfmod(self.pgresult_ptr, column_number)  # type: ignore
 
-    def fsize(self, column_number):
-        return impl.PQfsize(self.pgresult_ptr, column_number)
+    def fsize(self, column_number: int) -> int:
+        return impl.PQfsize(self.pgresult_ptr, column_number)  # type: ignore
 
     @property
-    def binary_tuples(self):
-        return impl.PQbinaryTuples(self.pgresult_ptr)
-
-    def get_value(self, row_number, column_number):
-        length = impl.PQgetlength(self.pgresult_ptr, row_number, column_number)
+    def binary_tuples(self) -> int:
+        return impl.PQbinaryTuples(self.pgresult_ptr)  # type: ignore
+
+    def get_value(
+        self, row_number: int, column_number: int
+    ) -> Optional[bytes]:
+        length: int = impl.PQgetlength(
+            self.pgresult_ptr, row_number, column_number
+        )
         if length:
             v = impl.PQgetvalue(self.pgresult_ptr, row_number, column_number)
             return string_at(v, length)
@@ -468,35 +471,31 @@ class PGresult:
                 return b""
 
     @property
-    def nparams(self):
-        return impl.PQnparams(self.pgresult_ptr)
+    def nparams(self) -> int:
+        return impl.PQnparams(self.pgresult_ptr)  # type: ignore
 
-    def param_type(self, param_number):
-        return impl.PQparamtype(self.pgresult_ptr, param_number)
+    def param_type(self, param_number: int) -> int:
+        return impl.PQparamtype(  # type: ignore
+            self.pgresult_ptr, param_number
+        )
 
     @property
-    def command_status(self):
-        return impl.PQcmdStatus(self.pgresult_ptr)
+    def command_status(self) -> bytes:
+        return impl.PQcmdStatus(self.pgresult_ptr)  # type: ignore
 
     @property
-    def command_tuples(self):
+    def command_tuples(self) -> Optional[int]:
         rv = impl.PQcmdTuples(self.pgresult_ptr)
-        if rv:
-            return int(rv)
+        return int(rv) if rv else None
 
     @property
-    def oid_value(self):
-        return impl.PQoidValue(self.pgresult_ptr)
-
-
-ConninfoOption = namedtuple(
-    "ConninfoOption", "keyword envvar compiled val label dispatcher dispsize"
-)
+    def oid_value(self) -> int:
+        return impl.PQoidValue(self.pgresult_ptr)  # type: ignore
 
 
 class Conninfo:
     @classmethod
-    def get_defaults(cls):
+    def get_defaults(cls) -> List[ConninfoOption]:
         opts = impl.PQconndefaults()
         if not opts:
             raise MemoryError("couldn't allocate connection defaults")
@@ -506,11 +505,9 @@ class Conninfo:
             impl.PQconninfoFree(opts)
 
     @classmethod
-    def parse(cls, conninfo):
+    def parse(cls, conninfo: bytes) -> List[ConninfoOption]:
         if not isinstance(conninfo, bytes):
-            raise TypeError(
-                "bytes expected, got %s instead" % type(conninfo).__name__
-            )
+            raise TypeError(f"bytes expected, got {type(conninfo)} instead")
 
         errmsg = c_char_p()
         rv = impl.PQconninfoParse(conninfo, pointer(errmsg))
@@ -518,7 +515,7 @@ class Conninfo:
             if not errmsg:
                 raise MemoryError("couldn't allocate on conninfo parse")
             else:
-                exc = PQerror(errmsg.value.decode("utf8", "replace"))
+                exc = PQerror((errmsg.value or b"").decode("utf8", "replace"))
                 impl.PQfreemem(errmsg)
                 raise exc
 
@@ -528,7 +525,9 @@ class Conninfo:
             impl.PQconninfoFree(rv)
 
     @classmethod
-    def _options_from_array(cls, opts):
+    def _options_from_array(
+        cls, opts: Sequence[impl.PQconninfoOption_struct]
+    ) -> List[ConninfoOption]:
         rv = []
         skws = "keyword envvar compiled val label dispatcher".split()
         for opt in opts:
index 0dd0213b8d3a19fa6c1421ef5481c1ff68910aa2..b633de9241cee1591dd73f9f7751eda17ba35a17 100644 (file)
@@ -5,6 +5,7 @@ Adapters of numeric types.
 # Copyright (C) 2020 The Psycopg Team
 
 import codecs
+from typing import Tuple
 
 from ..adaptation import Adapter, Typecaster
 from .oids import type_oid
@@ -14,10 +15,10 @@ _decode = codecs.lookup("ascii").decode
 
 
 @Adapter.register(int)
-def adapt_int(obj):
+def adapt_int(obj: int) -> Tuple[bytes, int]:
     return _encode(str(obj))[0], type_oid["numeric"]
 
 
 @Typecaster.register(type_oid["numeric"])
-def cast_int(data):
+def cast_int(data: bytes) -> int:
     return int(_decode(data)[0])
index 025f6a337e954ec8cf38d3e050cc9f6aaa571774..d81b2c8651525df465696db54e36cd92b18de2f3 100644 (file)
@@ -92,7 +92,7 @@ _oids_table = [
 type_oid = {name: oid for name, oid, _, _ in _oids_table}
 
 
-def self_update():
+def self_update() -> None:
     import subprocess as sp
 
     # queries output should make black happy
index 3d1405deba22ee948ff461176c8e56c46172103c..2db4d1d300ecb771c3d36060ebcde573924bb75a 100644 (file)
@@ -5,30 +5,39 @@ Adapters of textual types.
 # Copyright (C) 2020 The Psycopg Team
 
 import codecs
+from typing import Optional, Union
 
-from ..adaptation import Adapter
-from ..adaptation import Typecaster
+from ..adaptation import Adapter, Typecaster
+from ..connection import BaseConnection
+from ..utils.typing import EncodeFunc, DecodeFunc
 from .oids import type_oid
 
 
 @Adapter.register(str)
 @Adapter.register_binary(str)
 class StringAdapter(Adapter):
-    def __init__(self, cls, conn):
+    def __init__(self, cls: type, conn: BaseConnection):
         super().__init__(cls, conn)
-        self.encode = (
-            conn.codec if conn is not None else codecs.lookup("utf8")
-        ).encode
 
-    def adapt(self, obj):
-        return self.encode(obj)[0]
+        self._encode: EncodeFunc
+        if conn is not None:
+            self._encode = conn.codec.encode
+        else:
+            self._encode = codecs.lookup("utf8").encode
+
+    def adapt(self, obj: str) -> bytes:
+        return self._encode(obj)[0]
 
 
 @Typecaster.register(type_oid["text"])
 @Typecaster.register_binary(type_oid["text"])
 class StringCaster(Typecaster):
-    def __init__(self, oid, conn):
+
+    decode: Optional[DecodeFunc]
+
+    def __init__(self, oid: int, conn: BaseConnection):
         super().__init__(oid, conn)
+
         if conn is not None:
             if conn.pgenc != b"SQL_ASCII":
                 self.decode = conn.codec.decode
@@ -37,7 +46,7 @@ class StringCaster(Typecaster):
         else:
             self.decode = codecs.lookup("utf8").decode
 
-    def cast(self, data):
+    def cast(self, data: bytes) -> Union[bytes, str]:
         if self.decode is not None:
             return self.decode(data)[0]
         else:
index 8cdd8cb3ee3952ae70dc02c10680c506497e2041..5ad976e5febd84e70407c32a4a1223370f2a7c48 100644 (file)
@@ -5,13 +5,28 @@ Utility module to manipulate queries
 # Copyright (C) 2020 The Psycopg Team
 
 import re
-from collections.abc import Sequence, Mapping
+from codecs import CodecInfo
+from typing import (
+    Any,
+    Dict,
+    List,
+    Mapping,
+    Match,
+    NamedTuple,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)
 
 from .. import exceptions as exc
 from ..pq import Format
+from .typing import Params
 
 
-def query2pg(query, vars, codec):
+def query2pg(
+    query: bytes, vars: Params, codec: CodecInfo
+) -> Tuple[bytes, List[Format], Optional[List[str]]]:
     """
     Convert Python query and params into something Postgres understands.
 
@@ -30,6 +45,9 @@ def query2pg(query, vars, codec):
         )
 
     parts = split_query(query, codec.name)
+    order: Optional[List[str]] = None
+    chunks: List[bytes] = []
+    formats = []
 
     if isinstance(vars, Sequence) and not isinstance(vars, (bytes, str)):
         if len(vars) != len(parts) - 1:
@@ -37,32 +55,40 @@ def query2pg(query, vars, codec):
                 f"the query has {len(parts) - 1} placeholders but"
                 f" {len(vars)} parameters were passed"
             )
-        if vars and not isinstance(parts[0][1], int):
+        if vars and not isinstance(parts[0].index, int):
             raise TypeError(
                 "named placeholders require a mapping of parameters"
             )
-        order = None
+
+        for part in parts[:-1]:
+            assert isinstance(part.index, int)
+            chunks.append(part.pre)
+            chunks.append(b"$%d" % (part.index + 1))
+            formats.append(part.format)
 
     elif isinstance(vars, Mapping):
-        if vars and len(parts) > 1 and not isinstance(parts[0][1], bytes):
+        if vars and len(parts) > 1 and not isinstance(parts[0][1], str):
             raise TypeError(
                 "positional placeholders (%s) require a sequence of parameters"
             )
-        seen = {}
+        seen: Dict[str, Tuple[bytes, Format]] = {}
         order = []
         for part in parts[:-1]:
-            name = codec.decode(part[1])[0]
-            if name not in seen:
-                n = len(seen)
-                part[1] = n
-                seen[name] = (n, part[2])
-                order.append(name)
+            assert isinstance(part.index, str)
+            formats.append(part.format)
+            chunks.append(part.pre)
+            if part.index not in seen:
+                ph = b"$%d" % (len(seen) + 1)
+                seen[part.index] = (ph, part.format)
+                order.append(part.index)
+                chunks.append(ph)
             else:
-                if seen[name][1] != part[2]:
+                if seen[part.index][1] != part.format:
                     raise exc.ProgrammingError(
-                        f"placeholder '{name}' cannot have different formats"
+                        f"placeholder '{part.index}' cannot have"
+                        f" different formats"
                     )
-                part[1] = seen[name][0]
+                chunks.append(seen[part.index][0])
 
     else:
         raise TypeError(
@@ -70,16 +96,10 @@ def query2pg(query, vars, codec):
             f" got {type(vars).__name__}"
         )
 
-    # Assemble query and parameters
-    rv = []
-    formats = []
-    for part in parts[:-1]:
-        rv.append(part[0])
-        rv.append(b"$%d" % (part[1] + 1))
-        formats.append(part[2])
-    rv.append(parts[-1][0])
+    # last part
+    chunks.append(parts[-1].pre)
 
-    return b"".join(rv), formats, order
+    return b"".join(chunks), formats, order
 
 
 _re_placeholder = re.compile(
@@ -97,35 +117,48 @@ _re_placeholder = re.compile(
 )
 
 
-def split_query(query, encoding="ascii"):
-    parts = []
+class QueryPart(NamedTuple):
+    pre: bytes
+    # TODO: mypy bug? https://github.com/python/mypy/issues/8599
+    index: Union[int, str]  # type: ignore
+    format: Format
+
+
+def split_query(query: bytes, encoding: str = "ascii") -> List[QueryPart]:
+    parts: List[Tuple[bytes, Optional[Match[bytes]]]] = []
     cur = 0
 
-    # pairs [(fragment, match)], with the last match None
+    # pairs [(fragment, match], with the last match None
     m = None
     for m in _re_placeholder.finditer(query):
         pre = query[cur : m.span(0)[0]]
-        parts.append([pre, m, None])
+        parts.append((pre, m))
         cur = m.span(0)[1]
     if m is None:
-        parts.append([query, None, None])
+        parts.append((query, None))
     else:
-        parts.append([query[cur:], None, None])
+        parts.append((query[cur:], None))
+
+    rv = []
 
     # drop the "%%", validate
     i = 0
     phtype = None
     while i < len(parts):
-        part = parts[i]
-        m = part[1]
+        pre, m = parts[i]
         if m is None:
-            break  # last part
+            # last part
+            rv.append(QueryPart(pre, 0, Format.TEXT))
+            break
+
         ph = m.group(0)
         if ph == b"%%":
             # unescape '%%' to '%' and merge the parts
-            parts[i + 1][0] = part[0] + b"%" + parts[i + 1][0]
+            pre1, m1 = parts[i + 1]
+            parts[i + 1] = (pre + b"%" + pre1, m1)
             del parts[i]
             continue
+
         if ph == b"%(":
             raise exc.ProgrammingError(
                 f"incomplete placeholder:"
@@ -144,28 +177,29 @@ def split_query(query, encoding="ascii"):
             )
 
         # Index or name
-        if m.group(1) is None:
-            part[1] = i
-        else:
-            part[1] = m.group(1)
-
-        # Binary format
-        part[2] = Format(ph[-1:] == b"b")
+        index: Union[int, str]
+        index = i if m.group(1) is None else m.group(1).decode(encoding)
 
         if phtype is None:
-            phtype = type(part[1])
+            phtype = type(index)
         else:
-            if phtype is not type(part[1]):  # noqa
+            if phtype is not type(index):  # noqa
                 raise exc.ProgrammingError(
                     "positional and named placeholders cannot be mixed"
                 )
 
+        # Binary format
+        format = Format(ph[-1:] == b"b")
+
+        rv.append(QueryPart(pre, index, format))
         i += 1
 
-    return parts
+    return rv
 
 
-def reorder_params(params, order):
+def reorder_params(
+    params: Mapping[str, Any], order: Sequence[str]
+) -> List[str]:
     """
     Convert a mapping of parameters into an array in a specified order
     """
diff --git a/psycopg3/utils/typing.py b/psycopg3/utils/typing.py
new file mode 100644 (file)
index 0000000..f96576c
--- /dev/null
@@ -0,0 +1,13 @@
+"""
+Additional types for checking
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Any, Callable, Mapping, Sequence, Tuple, Union
+
+EncodeFunc = Callable[[str], Tuple[bytes, int]]
+DecodeFunc = Callable[[bytes], Tuple[str, int]]
+
+Query = Union[str, bytes]
+Params = Union[Sequence[Any], Mapping[str, Any]]
index 7a6580f688f7731c650c6c880304deb9bd19be15..59e0e1f618cfa87874127d415bff594bad5b92ce 100644 (file)
@@ -7,6 +7,7 @@ Code concerned with waiting in different contexts (blocking, async, etc).
 
 from enum import Enum
 from select import select
+from typing import Generator, Tuple, TypeVar
 from asyncio import get_event_loop, Event
 
 from . import exceptions as exc
@@ -15,8 +16,10 @@ from . import exceptions as exc
 Wait = Enum("Wait", "R W RW")
 Ready = Enum("Ready", "R W")
 
+RV = TypeVar("RV")
 
-def wait_select(gen):
+
+def wait_select(gen: Generator[Tuple[int, Wait], Ready, RV]) -> RV:
     """
     Wait on the behalf of a generator using select().
 
@@ -48,10 +51,11 @@ def wait_select(gen):
             else:
                 raise exc.InternalError("bad poll status: %s")
     except StopIteration as e:
-        return e.args[0]
+        rv: RV = e.args[0]
+        return rv
 
 
-async def wait_async(gen):
+async def wait_async(gen: Generator[Tuple[int, Wait], Ready, RV]) -> RV:
     """
     Coroutine waiting for a generator to complete.
 
@@ -65,9 +69,9 @@ async def wait_async(gen):
     # Not sure this is the best implementation but it's a start.
     ev = Event()
     loop = get_event_loop()
-    ready = None
+    ready = Ready.R
 
-    def wakeup(state):
+    def wakeup(state: Ready) -> None:
         nonlocal ready
         ready = state
         ev.set()
@@ -96,4 +100,5 @@ async def wait_async(gen):
             else:
                 raise exc.InternalError("bad poll status: %s")
     except StopIteration as e:
-        return e.args[0]
+        rv: RV = e.args[0]
+        return rv
index 8aa6ec9529ea0c91132c140a1f18362a8a6f431d..0323017813ba26cacfa9efece97f7ae94856e189 100644 (file)
@@ -8,34 +8,22 @@ from psycopg3.utils.queries import split_query, query2pg, reorder_params
 @pytest.mark.parametrize(
     "input, want",
     [
-        (b"", [[b"", None, None]]),
-        (b"foo bar", [[b"foo bar", None, None]]),
-        (b"foo %% bar", [[b"foo % bar", None, None]]),
-        (b"%s", [[b"", 0, False], [b"", None, None]]),
-        (b"%s foo", [[b"", 0, False], [b" foo", None, None]]),
-        (b"%b foo", [[b"", 0, True], [b" foo", None, None]]),
-        (b"foo %s", [[b"foo ", 0, False], [b"", None, None]]),
-        (b"foo %%%s bar", [[b"foo %", 0, False], [b" bar", None, None]]),
-        (
-            b"foo %(name)s bar",
-            [[b"foo ", b"name", False], [b" bar", None, None]],
-        ),
+        (b"", [(b"", 0, 0)]),
+        (b"foo bar", [(b"foo bar", 0, 0)]),
+        (b"foo %% bar", [(b"foo % bar", 0, 0)]),
+        (b"%s", [(b"", 0, 0), (b"", 0, 0)]),
+        (b"%s foo", [(b"", 0, 0), (b" foo", 0, 0)]),
+        (b"%b foo", [(b"", 0, 1), (b" foo", 0, 0)]),
+        (b"foo %s", [(b"foo ", 0, 0), (b"", 0, 0)]),
+        (b"foo %%%s bar", [(b"foo %", 0, 0), (b" bar", 0, 0)]),
+        (b"foo %(name)s bar", [(b"foo ", "name", 0), (b" bar", 0, 0)]),
         (
             b"foo %(name)s %(name)b bar",
-            [
-                [b"foo ", b"name", False],
-                [b" ", b"name", True],
-                [b" bar", None, None],
-            ],
+            [(b"foo ", "name", 0), (b" ", "name", 1), (b" bar", 0, 0)],
         ),
         (
             b"foo %s%b bar %s baz",
-            [
-                [b"foo ", 0, False],
-                [b"", 1, True],
-                [b" bar ", 2, False],
-                [b" baz", None, None],
-            ],
+            [(b"foo ", 0, 0), (b"", 1, 1), (b" bar ", 2, 0), (b" baz", 0, 0)],
         ),
     ],
 )
diff --git a/tox.ini b/tox.ini
index 2f9f668aba71709d45dc778a780967c6c9bcfbdc..0a1b28a35b45bac275b9c5496f6ef5cd86380a70 100644 (file)
--- a/tox.ini
+++ b/tox.ini
@@ -24,7 +24,7 @@ exclude = env, .tox
 ignore = W503, E203
 
 [mypy]
-files = psycopg3, tests, setup.py
+files = psycopg3, setup.py
 warn_unused_ignores = True
 
 [mypy-pytest]