]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added types stub for ctypes functions
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 30 Mar 2020 15:21:36 +0000 (04:21 +1300)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 30 Mar 2020 15:21:36 +0000 (04:21 +1300)
psycopg3/pq/_pq_ctypes.py
psycopg3/pq/_pq_ctypes.pyi [new file with mode: 0644]
psycopg3/pq/pq_ctypes.py
tox.ini

index 86ed2a3232855ffd4c33b930de64c5bb16ce2172..f3868ed54dff27f298c0f80914f4c2588a8ce419 100644 (file)
@@ -146,7 +146,7 @@ if libpq_version >= 120000:
 
 def PQhostaddr(pgconn: type) -> bytes:
     if _PQhostaddr is not None:
-        return _PQhostaddr(pgconn)  # type: ignore
+        return _PQhostaddr(pgconn)
     else:
         raise NotSupportedError(
             f"PQhostaddr requires libpq from PostgreSQL 12,"
@@ -414,3 +414,78 @@ PQfreemem.restype = None
 PQmakeEmptyPGresult = pq.PQmakeEmptyPGresult
 PQmakeEmptyPGresult.argtypes = [PGconn_ptr, c_int]
 PQmakeEmptyPGresult.restype = PGresult_ptr
+
+
+def generate_stub() -> None:
+    import re
+    from ctypes import _CFuncPtr
+
+    def type2str(fname, narg, t):
+        if t is None:
+            return "None"
+        elif t is c_void_p:
+            return "Any"
+        elif t is c_int or t is c_uint:
+            return "int"
+        elif t is c_char_p:
+            return "bytes"
+
+        elif t.__name__ in ("LP_PGconn_struct", "LP_PGresult_struct",):
+            if narg is not None:
+                return f"Optional[{t.__name__[3:]}]"
+            else:
+                return t.__name__[3:]
+
+        elif t.__name__ in ("LP_PQconninfoOption_struct",):
+            return f"Sequence[{t.__name__[3:]}]"
+
+        elif t.__name__ in (
+            "LP_c_char",
+            "LP_c_char_p",
+            "LP_c_int",
+            "LP_c_uint",
+        ):
+            return f"pointer[{t.__name__[3:]}]"
+
+        else:
+            assert False, f"can't deal with {t} in {fname}"
+
+    fn = __file__ + "i"
+    with open(fn, "r") as f:
+        lines = f.read().splitlines()
+
+    istart, iend = [
+        i
+        for i, l in enumerate(lines)
+        if re.match(r"\s*#\s*autogenerated:\s+(start|end)", l)
+    ]
+
+    known = {
+        l[4:].split("(", 1)[0] for l in lines[:istart] if l.startswith("def ")
+    }
+
+    signatures = []
+
+    for name, obj in globals().items():
+        if name in known:
+            continue
+        if not isinstance(obj, _CFuncPtr):
+            continue
+
+        params = []
+        for i, t in enumerate(obj.argtypes):
+            params.append(f"arg{i + 1}: {type2str(name, i, t)}")
+
+        resname = type2str(name, None, obj.restype)
+
+        signatures.append(f"def {name}({', '.join(params)}) -> {resname}: ...")
+
+    lines[istart + 1 : iend] = signatures
+
+    with open(fn, "w") as f:
+        f.write("\n".join(lines))
+        f.write("\n")
+
+
+if __name__ == "__main__":
+    generate_stub()
diff --git a/psycopg3/pq/_pq_ctypes.pyi b/psycopg3/pq/_pq_ctypes.pyi
new file mode 100644 (file)
index 0000000..99b8acc
--- /dev/null
@@ -0,0 +1,115 @@
+"""
+types stub for ctypes functions
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Any, Optional, Sequence, NewType
+from ctypes import Array, c_char, c_char_p, c_int, c_uint, pointer
+
+Oid = c_uint
+
+class PGconn_struct: ...
+class PGresult_struct: ...
+
+class PQconninfoOption_struct:
+    keyword: bytes
+    envvar: bytes
+    compiled: bytes
+    val: bytes
+    label: bytes
+    dispatcher: bytes
+    dispsize: int
+
+def PQhostaddr(arg1: Optional[PGconn_struct]) -> bytes: ...
+def PQexecPrepared(
+    arg1: Optional[PGconn_struct],
+    arg2: bytes,
+    arg3: int,
+    arg4: Optional[Array[c_char_p]],
+    arg5: Optional[Array[c_int]],
+    arg6: Optional[Array[c_int]],
+    arg7: int,
+) -> PGresult_struct: ...
+def PQprepare(
+    arg1: Optional[PGconn_struct],
+    arg2: bytes,
+    arg3: bytes,
+    arg4: int,
+    arg5: Optional[Array[c_uint]],
+) -> PGresult_struct: ...
+
+# fmt: off
+# autogenerated: start
+def PQlibVersion() -> int: ...
+def PQconnectdb(arg1: bytes) -> PGconn_struct: ...
+def PQconnectStart(arg1: bytes) -> PGconn_struct: ...
+def PQconnectPoll(arg1: Optional[PGconn_struct]) -> int: ...
+def PQconndefaults() -> Sequence[PQconninfoOption_struct]: ...
+def PQconninfoFree(arg1: Sequence[PQconninfoOption_struct]) -> None: ...
+def PQconninfo(arg1: Optional[PGconn_struct]) -> Sequence[PQconninfoOption_struct]: ...
+def PQconninfoParse(arg1: bytes, arg2: pointer[c_char_p]) -> Sequence[PQconninfoOption_struct]: ...
+def PQfinish(arg1: Optional[PGconn_struct]) -> None: ...
+def PQreset(arg1: Optional[PGconn_struct]) -> None: ...
+def PQresetStart(arg1: Optional[PGconn_struct]) -> int: ...
+def PQresetPoll(arg1: Optional[PGconn_struct]) -> int: ...
+def PQping(arg1: bytes) -> int: ...
+def PQdb(arg1: Optional[PGconn_struct]) -> bytes: ...
+def PQuser(arg1: Optional[PGconn_struct]) -> bytes: ...
+def PQpass(arg1: Optional[PGconn_struct]) -> bytes: ...
+def PQhost(arg1: Optional[PGconn_struct]) -> bytes: ...
+def _PQhostaddr(arg1: Optional[PGconn_struct]) -> bytes: ...
+def PQport(arg1: Optional[PGconn_struct]) -> bytes: ...
+def PQtty(arg1: Optional[PGconn_struct]) -> bytes: ...
+def PQoptions(arg1: Optional[PGconn_struct]) -> bytes: ...
+def PQstatus(arg1: Optional[PGconn_struct]) -> int: ...
+def PQtransactionStatus(arg1: Optional[PGconn_struct]) -> int: ...
+def PQparameterStatus(arg1: Optional[PGconn_struct], arg2: bytes) -> bytes: ...
+def PQprotocolVersion(arg1: Optional[PGconn_struct]) -> int: ...
+def PQserverVersion(arg1: Optional[PGconn_struct]) -> int: ...
+def PQerrorMessage(arg1: Optional[PGconn_struct]) -> bytes: ...
+def PQsocket(arg1: Optional[PGconn_struct]) -> int: ...
+def PQbackendPID(arg1: Optional[PGconn_struct]) -> int: ...
+def PQconnectionNeedsPassword(arg1: Optional[PGconn_struct]) -> int: ...
+def PQconnectionUsedPassword(arg1: Optional[PGconn_struct]) -> int: ...
+def PQsslInUse(arg1: Optional[PGconn_struct]) -> int: ...
+def PQexec(arg1: Optional[PGconn_struct], arg2: bytes) -> PGresult_struct: ...
+def PQexecParams(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int, arg4: pointer[c_uint], arg5: pointer[c_char_p], arg6: pointer[c_int], arg7: pointer[c_int], arg8: int) -> PGresult_struct: ...
+def PQdescribePrepared(arg1: Optional[PGconn_struct], arg2: bytes) -> PGresult_struct: ...
+def PQdescribePortal(arg1: Optional[PGconn_struct], arg2: bytes) -> PGresult_struct: ...
+def PQresultStatus(arg1: Optional[PGresult_struct]) -> int: ...
+def PQresultErrorMessage(arg1: Optional[PGresult_struct]) -> bytes: ...
+def PQresultErrorField(arg1: Optional[PGresult_struct], arg2: int) -> bytes: ...
+def PQclear(arg1: Optional[PGresult_struct]) -> None: ...
+def PQntuples(arg1: Optional[PGresult_struct]) -> int: ...
+def PQnfields(arg1: Optional[PGresult_struct]) -> int: ...
+def PQfname(arg1: Optional[PGresult_struct], arg2: int) -> bytes: ...
+def PQftable(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
+def PQftablecol(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
+def PQfformat(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
+def PQftype(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
+def PQfmod(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
+def PQfsize(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
+def PQbinaryTuples(arg1: Optional[PGresult_struct]) -> int: ...
+def PQgetvalue(arg1: Optional[PGresult_struct], arg2: int, arg3: int) -> pointer[c_char]: ...
+def PQgetisnull(arg1: Optional[PGresult_struct], arg2: int, arg3: int) -> int: ...
+def PQgetlength(arg1: Optional[PGresult_struct], arg2: int, arg3: int) -> int: ...
+def PQnparams(arg1: Optional[PGresult_struct]) -> int: ...
+def PQparamtype(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
+def PQcmdStatus(arg1: Optional[PGresult_struct]) -> bytes: ...
+def PQcmdTuples(arg1: Optional[PGresult_struct]) -> bytes: ...
+def PQoidValue(arg1: Optional[PGresult_struct]) -> int: ...
+def PQsendQuery(arg1: Optional[PGconn_struct], arg2: bytes) -> int: ...
+def PQsendQueryParams(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int, arg4: pointer[c_uint], arg5: pointer[c_char_p], arg6: pointer[c_int], arg7: pointer[c_int], arg8: int) -> int: ...
+def PQgetResult(arg1: Optional[PGconn_struct]) -> PGresult_struct: ...
+def PQconsumeInput(arg1: Optional[PGconn_struct]) -> int: ...
+def PQisBusy(arg1: Optional[PGconn_struct]) -> int: ...
+def PQsetnonblocking(arg1: Optional[PGconn_struct], arg2: int) -> int: ...
+def PQisnonblocking(arg1: Optional[PGconn_struct]) -> int: ...
+def PQflush(arg1: Optional[PGconn_struct]) -> int: ...
+def PQfreemem(arg1: Any) -> None: ...
+def PQmakeEmptyPGresult(arg1: Optional[PGconn_struct], arg2: int) -> PGresult_struct: ...
+# autogenerated: end
+# fmt: on
+
+# vim: set syntax=python:
index 15ca32778516399ee5545e3df1bd3054cdc19f74..d11414b837d7ebc5d282112def0882691b32edb8 100644 (file)
@@ -9,7 +9,7 @@ implementation.
 # Copyright (C) 2020 The Psycopg Team
 
 from ctypes import string_at
-from ctypes import c_char_p, c_int, pointer
+from ctypes import Array, c_char_p, c_int, pointer
 from typing import Any, List, Optional, Sequence
 
 from .enums import (
@@ -28,7 +28,7 @@ from ..utils.typing import Oid
 
 
 def version() -> int:
-    return impl.PQlibVersion()  # type: ignore
+    return impl.PQlibVersion()
 
 
 class PQerror(OperationalError):
@@ -104,35 +104,35 @@ class PGconn:
 
     @property
     def db(self) -> bytes:
-        return impl.PQdb(self.pgconn_ptr)  # type: ignore
+        return impl.PQdb(self.pgconn_ptr)
 
     @property
     def user(self) -> bytes:
-        return impl.PQuser(self.pgconn_ptr)  # type: ignore
+        return impl.PQuser(self.pgconn_ptr)
 
     @property
     def password(self) -> bytes:
-        return impl.PQpass(self.pgconn_ptr)  # type: ignore
+        return impl.PQpass(self.pgconn_ptr)
 
     @property
     def host(self) -> bytes:
-        return impl.PQhost(self.pgconn_ptr)  # type: ignore
+        return impl.PQhost(self.pgconn_ptr)
 
     @property
     def hostaddr(self) -> bytes:
-        return impl.PQhostaddr(self.pgconn_ptr)  # type: ignore
+        return impl.PQhostaddr(self.pgconn_ptr)
 
     @property
     def port(self) -> bytes:
-        return impl.PQport(self.pgconn_ptr)  # type: ignore
+        return impl.PQport(self.pgconn_ptr)
 
     @property
     def tty(self) -> bytes:
-        return impl.PQtty(self.pgconn_ptr)  # type: ignore
+        return impl.PQtty(self.pgconn_ptr)
 
     @property
     def options(self) -> bytes:
-        return impl.PQoptions(self.pgconn_ptr)  # type: ignore
+        return impl.PQoptions(self.pgconn_ptr)
 
     @property
     def status(self) -> ConnStatus:
@@ -145,27 +145,27 @@ class PGconn:
         return TransactionStatus(rv)
 
     def parameter_status(self, name: bytes) -> bytes:
-        return impl.PQparameterStatus(self.pgconn_ptr, name)  # type: ignore
+        return impl.PQparameterStatus(self.pgconn_ptr, name)
 
     @property
     def protocol_version(self) -> int:
-        return impl.PQprotocolVersion(self.pgconn_ptr)  # type: ignore
+        return impl.PQprotocolVersion(self.pgconn_ptr)
 
     @property
     def server_version(self) -> int:
-        return impl.PQserverVersion(self.pgconn_ptr)  # type: ignore
+        return impl.PQserverVersion(self.pgconn_ptr)
 
     @property
     def error_message(self) -> bytes:
-        return impl.PQerrorMessage(self.pgconn_ptr)  # type: ignore
+        return impl.PQerrorMessage(self.pgconn_ptr)
 
     @property
     def socket(self) -> int:
-        return impl.PQsocket(self.pgconn_ptr)  # type: ignore
+        return impl.PQsocket(self.pgconn_ptr)
 
     @property
     def backend_pid(self) -> int:
-        return impl.PQbackendPID(self.pgconn_ptr)  # type: ignore
+        return impl.PQbackendPID(self.pgconn_ptr)
 
     @property
     def needs_password(self) -> bool:
@@ -237,14 +237,15 @@ class PGconn:
             raise TypeError(f"bytes expected, got {type(command)} instead")
 
         nparams = len(param_values)
+        aparams: Optional[Array[c_char_p]] = None
+        alenghts: Optional[Array[c_int]] = None
         if nparams:
             aparams = (c_char_p * nparams)(*param_values)
             alenghts = (c_int * nparams)(
                 *(len(p) if p is not None else 0 for p in param_values)
             )
-        else:
-            aparams = alenghts = None  # type: ignore
 
+        atypes: Optional[Array[impl.Oid]]
         if param_types is None:
             atypes = None
         else:
@@ -313,13 +314,13 @@ class PGconn:
             raise TypeError(f"'name' must be bytes, got {type(name)} instead")
 
         nparams = len(param_values)
+        aparams: Optional[Array[c_char_p]] = None
+        alenghts: Optional[Array[c_int]] = None
         if nparams:
             aparams = (c_char_p * nparams)(*param_values)
             alenghts = (c_int * nparams)(
                 *(len(p) if p is not None else 0 for p in param_values)
             )
-        else:
-            aparams = alenghts = None  # type: ignore
 
         if param_formats is None:
             aformats = None
@@ -369,11 +370,11 @@ class PGconn:
             raise PQerror(f"consuming input failed: {error_message(self)}")
 
     def is_busy(self) -> int:
-        return impl.PQisBusy(self.pgconn_ptr)  # type: ignore
+        return impl.PQisBusy(self.pgconn_ptr)
 
     @property
     def nonblocking(self) -> int:
-        return impl.PQisnonblocking(self.pgconn_ptr)  # type: ignore
+        return impl.PQisnonblocking(self.pgconn_ptr)
 
     @nonblocking.setter
     def nonblocking(self, arg: int) -> None:
@@ -396,8 +397,8 @@ class PGconn:
 class PGresult:
     __slots__ = ("pgresult_ptr",)
 
-    def __init__(self, pgresult_ptr: type):
-        self.pgresult_ptr: Optional[type] = pgresult_ptr
+    def __init__(self, pgresult_ptr: impl.PGresult_struct):
+        self.pgresult_ptr: Optional[impl.PGresult_struct] = pgresult_ptr
 
     def __del__(self) -> None:
         self.clear()
@@ -414,43 +415,39 @@ class PGresult:
 
     @property
     def error_message(self) -> bytes:
-        return impl.PQresultErrorMessage(self.pgresult_ptr)  # type: ignore
+        return impl.PQresultErrorMessage(self.pgresult_ptr)
 
     def error_field(self, fieldcode: DiagnosticField) -> bytes:
-        return impl.PQresultErrorField(  # type: ignore
-            self.pgresult_ptr, fieldcode
-        )
+        return impl.PQresultErrorField(self.pgresult_ptr, fieldcode)
 
     @property
     def ntuples(self) -> int:
-        return impl.PQntuples(self.pgresult_ptr)  # type: ignore
+        return impl.PQntuples(self.pgresult_ptr)
 
     @property
     def nfields(self) -> int:
-        return impl.PQnfields(self.pgresult_ptr)  # type: ignore
+        return impl.PQnfields(self.pgresult_ptr)
 
-    def fname(self, column_number: int) -> int:
-        return impl.PQfname(self.pgresult_ptr, column_number)  # type: ignore
+    def fname(self, column_number: int) -> bytes:
+        return impl.PQfname(self.pgresult_ptr, column_number)
 
     def ftable(self, column_number: int) -> Oid:
-        return impl.PQftable(self.pgresult_ptr, column_number)  # type: ignore
+        return Oid(impl.PQftable(self.pgresult_ptr, column_number))
 
     def ftablecol(self, column_number: int) -> int:
-        return impl.PQftablecol(  # type: ignore
-            self.pgresult_ptr, column_number
-        )
+        return impl.PQftablecol(self.pgresult_ptr, column_number)
 
     def fformat(self, column_number: int) -> Format:
-        return impl.PQfformat(self.pgresult_ptr, column_number)  # type: ignore
+        return Format(impl.PQfformat(self.pgresult_ptr, column_number))
 
     def ftype(self, column_number: int) -> Oid:
-        return impl.PQftype(self.pgresult_ptr, column_number)  # type: ignore
+        return Oid(impl.PQftype(self.pgresult_ptr, column_number))
 
     def fmod(self, column_number: int) -> int:
-        return impl.PQfmod(self.pgresult_ptr, column_number)  # type: ignore
+        return impl.PQfmod(self.pgresult_ptr, column_number)
 
     def fsize(self, column_number: int) -> int:
-        return impl.PQfsize(self.pgresult_ptr, column_number)  # type: ignore
+        return impl.PQfsize(self.pgresult_ptr, column_number)
 
     @property
     def binary_tuples(self) -> Format:
@@ -473,16 +470,14 @@ class PGresult:
 
     @property
     def nparams(self) -> int:
-        return impl.PQnparams(self.pgresult_ptr)  # type: ignore
+        return impl.PQnparams(self.pgresult_ptr)
 
     def param_type(self, param_number: int) -> Oid:
-        return impl.PQparamtype(  # type: ignore
-            self.pgresult_ptr, param_number
-        )
+        return Oid(impl.PQparamtype(self.pgresult_ptr, param_number))
 
     @property
     def command_status(self) -> bytes:
-        return impl.PQcmdStatus(self.pgresult_ptr)  # type: ignore
+        return impl.PQcmdStatus(self.pgresult_ptr)
 
     @property
     def command_tuples(self) -> Optional[int]:
@@ -491,7 +486,7 @@ class PGresult:
 
     @property
     def oid_value(self) -> Oid:
-        return impl.PQoidValue(self.pgresult_ptr)  # type: ignore
+        return Oid(impl.PQoidValue(self.pgresult_ptr))
 
 
 class Conninfo:
diff --git a/tox.ini b/tox.ini
index 0a1b28a35b45bac275b9c5496f6ef5cd86380a70..3d39bdc779bfea7e0b540d46f5e56fa2bcc1b491 100644 (file)
--- a/tox.ini
+++ b/tox.ini
@@ -26,6 +26,7 @@ ignore = W503, E203
 [mypy]
 files = psycopg3, setup.py
 warn_unused_ignores = True
+strict = True
 
 [mypy-pytest]
 ignore_missing_imports = True