From: Daniele Varrazzo Date: Sun, 29 Mar 2020 14:39:21 +0000 (+1300) Subject: Added most type annotations X-Git-Tag: 3.0.dev0~648 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=2149d627b23a3f464a2d1a44119c0f8655cdaf5f;p=thirdparty%2Fpsycopg.git Added most type annotations mypy almost passes in --strict mode --- diff --git a/psycopg3/adaptation.py b/psycopg3/adaptation.py index 710c30f2f..bf7ab98d6 100644 --- a/psycopg3/adaptation.py +++ b/psycopg3/adaptation.py @@ -5,28 +5,59 @@ Entry point into the adaptation system. # Copyright (C) 2020 The Psycopg Team import codecs -from typing import Dict, Tuple +from typing import ( + Any, + Callable, + cast, + Dict, + Generator, + List, + Optional, + Sequence, + Tuple, + Union, +) from functools import partial from . import exceptions as exc -from .pq import Format +from .pq import Format, PGresult from .cursor import BaseCursor from .types.oids import type_oid, INVALID_OID from .connection import BaseConnection +from .utils.typing import DecodeFunc + + +# Type system + +AdaptContext = Union[BaseConnection, BaseCursor] + +MaybeOid = Union[Optional[bytes], Tuple[Optional[bytes], int]] +AdapterFunc = Callable[[Any], MaybeOid] +AdapterType = Union["Adapter", AdapterFunc] +AdaptersMap = Dict[Tuple[type, Format], AdapterType] + +TypecasterFunc = Callable[[bytes], Any] +TypecasterType = Union["Typecaster", TypecasterFunc] +TypecastersMap = Dict[Tuple[int, Format], TypecasterType] class Adapter: - globals: Dict[Tuple[type, Format], "Adapter"] = {} # TODO: incomplete type + globals: AdaptersMap = {} - def __init__(self, cls, conn): + def __init__(self, cls: type, conn: BaseConnection): self.cls = cls self.conn = conn - def adapt(self, obj): + def adapt(self, obj: Any) -> Union[bytes, Tuple[bytes, int]]: raise NotImplementedError() @staticmethod - def register(cls, adapter=None, context=None, format=Format.TEXT): + def register( + cls: type, + adapter: Optional[AdapterType] = None, + context: Optional[AdaptContext] = None, + format: Format = Format.TEXT, + ) -> AdapterType: if adapter is None: # used as decorator return partial(Adapter.register, cls, format=format) @@ -58,23 +89,31 @@ class Adapter: return adapter @staticmethod - def register_binary(cls, adapter=None, context=None): + def register_binary( + cls: type, + adapter: Optional[AdapterType] = None, + context: Optional[AdaptContext] = None, + ) -> AdapterType: return Adapter.register(cls, adapter, context, format=Format.BINARY) class Typecaster: - # TODO: incomplete type - globals: Dict[Tuple[type, Format], "Typecaster"] = {} + globals: TypecastersMap = {} - def __init__(self, oid, conn): + def __init__(self, oid: int, conn: Optional[BaseConnection]): self.oid = oid self.conn = conn - def cast(self, data): + def cast(self, data: bytes) -> Any: raise NotImplementedError() @staticmethod - def register(oid, caster=None, context=None, format=Format.TEXT): + def register( + oid: int, + caster: Optional[TypecasterType] = None, + context: Optional[AdaptContext] = None, + format: Format = Format.TEXT, + ) -> TypecasterType: if caster is None: # used as decorator return partial(Typecaster.register, oid, format=format) @@ -101,12 +140,16 @@ class Typecaster: f" got {caster} instead" ) - where = context.adapters if context is not None else Typecaster.globals + where = context.casters if context is not None else Typecaster.globals where[oid, format] = caster return caster @staticmethod - def register_binary(oid, caster=None, context=None): + def register_binary( + oid: int, + caster: Optional[TypecasterType] = None, + context: Optional[AdaptContext] = None, + ) -> TypecasterType: return Typecaster.register(oid, caster, context, format=Format.BINARY) @@ -119,7 +162,10 @@ class Transformer: state so adapting several values of the same type can use optimisations. """ - def __init__(self, context): + connection: Optional[BaseConnection] + cursor: Optional[BaseCursor] + + def __init__(self, context: Optional[AdaptContext]): if context is None: self.connection = None self.cursor = None @@ -136,24 +182,24 @@ class Transformer: ) # mapping class, fmt -> adaptation function - self._adapt_funcs = {} + self._adapt_funcs: Dict[Tuple[type, Format], AdapterFunc] = {} # mapping oid, fmt -> cast function - self._cast_funcs = {} + self._cast_funcs: Dict[Tuple[int, Format], TypecasterFunc] = {} # The result to return values from - self._result = None + self._result: Optional[PGresult] = None # sequence of cast function from value to python # the length of the result columns - self._row_casters = None + self._row_casters: List[TypecasterFunc] = [] @property - def result(self): + def result(self) -> Optional[PGresult]: return self._result @result.setter - def result(self, result): + def result(self, result: PGresult) -> None: if self._result is result: return @@ -164,7 +210,9 @@ class Transformer: func = self.get_cast_function(oid, fmt) rc.append(func) - def adapt_sequence(self, objs, fmts): + def adapt_sequence( + self, objs: Sequence[Any], fmts: Sequence[Format] + ) -> Tuple[List[Optional[bytes]], List[int]]: out = [] types = [] @@ -181,7 +229,7 @@ class Transformer: return out, types - def adapt(self, obj, fmt): + def adapt(self, obj: None, fmt: Format) -> MaybeOid: if obj is None: return None, type_oid["text"] @@ -189,7 +237,7 @@ class Transformer: func = self.get_adapt_function(cls, fmt) return func(obj) - def get_adapt_function(self, cls, fmt): + def get_adapt_function(self, cls: type, fmt: Format) -> AdapterFunc: try: return self._adapt_funcs[cls, fmt] except KeyError: @@ -197,11 +245,11 @@ class Transformer: adapter = self.lookup_adapter(cls, fmt) if isinstance(adapter, type): - adapter = adapter(cls, self.connection).adapt - - return adapter + return adapter(cls, self.connection).adapt + else: + return cast(AdapterFunc, adapter) - def lookup_adapter(self, cls, fmt): + def lookup_adapter(self, cls: type, fmt: Format) -> AdapterType: key = (cls, fmt) cur = self.cursor @@ -219,7 +267,7 @@ class Transformer: f"cannot adapt type {cls.__name__} to format {Format(fmt).name}" ) - def cast_row(self, result, n): + def cast_row(self, result: PGresult, n: int) -> Generator[Any, None, None]: self.result = result for col, func in enumerate(self._row_casters): @@ -228,7 +276,7 @@ class Transformer: v = func(v) yield v - def get_cast_function(self, oid, fmt): + def get_cast_function(self, oid: int, fmt: Format) -> TypecasterFunc: try: return self._cast_funcs[oid, fmt] except KeyError: @@ -236,11 +284,11 @@ class Transformer: caster = self.lookup_caster(oid, fmt) if isinstance(caster, type): - caster = caster(oid, self.connection).cast - - return caster + return caster(oid, self.connection).cast + else: + return cast(TypecasterFunc, caster) - def lookup_caster(self, oid, fmt): + def lookup_caster(self, oid: int, fmt: Format) -> TypecasterType: key = (oid, fmt) cur = self.cursor @@ -263,17 +311,18 @@ class UnknownCaster(Typecaster): Fallback object to convert unknown types to Python """ - def __init__(self, oid, conn): + def __init__(self, oid: int, conn: Optional[BaseConnection]): super().__init__(oid, conn) + self.decode: DecodeFunc if conn is not None: self.decode = conn.codec.decode else: self.decode = codecs.lookup("utf8").decode - def cast(self, data): + def cast(self, data: bytes) -> str: return self.decode(data)[0] @Typecaster.register_binary(INVALID_OID) -def cast_unknown(data): +def cast_unknown(data: bytes) -> bytes: return data diff --git a/psycopg3/connection.py b/psycopg3/connection.py index 21e714ddf..7d89c1b7b 100644 --- a/psycopg3/connection.py +++ b/psycopg3/connection.py @@ -8,6 +8,17 @@ import codecs import logging import asyncio import threading +from typing import ( + cast, + Any, + Generator, + List, + Optional, + Tuple, + Type, + TypeVar, + TYPE_CHECKING, +) from . import pq from . import exceptions as exc @@ -17,6 +28,13 @@ from .waiting import wait_select, wait_async, Wait, Ready logger = logging.getLogger(__name__) +ConnectGen = Generator[Tuple[int, Wait], Ready, pq.PGconn] +QueryGen = Generator[Tuple[int, Wait], Ready, List[pq.PGresult]] +RV = TypeVar("RV") + +if TYPE_CHECKING: + from .adaptation import AdaptersMap, TypecastersMap + class BaseConnection: """ @@ -26,38 +44,40 @@ class BaseConnection: allow different interfaces (sync/async). """ - def __init__(self, pgconn): + def __init__(self, pgconn: pq.PGconn): self.pgconn = pgconn - self.cursor_factory = None - self.adapters = {} - self.casters = {} + self.cursor_factory = cursor.BaseCursor + self.adapters: AdaptersMap = {} + self.casters: TypecastersMap = {} # name of the postgres encoding (in bytes) - self.pgenc = None + self.pgenc = b"" - def cursor(self, name=None, binary=False): + def cursor( + self, name: Optional[str] = None, binary: bool = False + ) -> cursor.BaseCursor: if name is not None: raise NotImplementedError return self.cursor_factory(self, binary=binary) @property - def codec(self): + def codec(self) -> codecs.CodecInfo: # TODO: utf8 fastpath? pgenc = self.pgconn.parameter_status(b"client_encoding") if self.pgenc != pgenc: # for unknown encodings and SQL_ASCII be strict and use ascii - pyenc = pq.py_codecs.get(pgenc.decode("ascii"), "ascii") + pyenc = pq.py_codecs.get(pgenc.decode("ascii")) or "ascii" self._codec = codecs.lookup(pyenc) self.pgenc = pgenc return self._codec - def encode(self, s): + def encode(self, s: str) -> bytes: return self.codec.encode(s)[0] - def decode(self, b): + def decode(self, b: bytes) -> str: return self.codec.decode(b)[0] @classmethod - def _connect_gen(cls, conninfo): + def _connect_gen(cls, conninfo: str) -> ConnectGen: """ Generator to create a database connection without blocking. @@ -65,9 +85,7 @@ class BaseConnection: generator can be restarted sending the appropriate `Ready` state when the file descriptor is ready. """ - conninfo = conninfo.encode("utf8") - - conn = pq.PGconn.connect_start(conninfo) + conn = pq.PGconn.connect_start(conninfo.encode("utf8")) logger.debug("connection started, status %s", conn.status.name) while 1: if conn.status == pq.ConnStatus.BAD: @@ -94,7 +112,7 @@ class BaseConnection: return conn @classmethod - def _exec_gen(cls, pgconn): + def _exec_gen(cls, pgconn: pq.PGconn) -> QueryGen: """ Generator returning query results without blocking. @@ -109,7 +127,7 @@ class BaseConnection: Return the list of results returned by the database (whether success or error). """ - results = [] + results: List[pq.PGresult] = [] while 1: f = pgconn.flush() @@ -148,13 +166,17 @@ class Connection(BaseConnection): This class implements a DBAPI-compliant interface. """ - def __init__(self, pgconn): + cursor_factory: Type[cursor.Cursor] + + def __init__(self, pgconn: pq.PGconn): super().__init__(pgconn) self.lock = threading.Lock() self.cursor_factory = cursor.Cursor @classmethod - def connect(cls, conninfo, connection_factory=None, **kwargs): + def connect( + cls, conninfo: str, connection_factory: Any = None, **kwargs: Any + ) -> "Connection": if connection_factory is not None: raise NotImplementedError() conninfo = make_conninfo(conninfo, **kwargs) @@ -162,13 +184,19 @@ class Connection(BaseConnection): pgconn = cls.wait(gen) return cls(pgconn) - def commit(self): + def cursor( + self, name: Optional[str] = None, binary: bool = False + ) -> cursor.Cursor: + cur = super().cursor(name, binary) + return cast(cursor.Cursor, cur) + + def commit(self) -> None: self._exec_commit_rollback(b"commit") - def rollback(self): + def rollback(self) -> None: self._exec_commit_rollback(b"rollback") - def _exec_commit_rollback(self, command): + def _exec_commit_rollback(self, command: bytes) -> None: with self.lock: status = self.pgconn.transaction_status if status == pq.TransactionStatus.IDLE: @@ -183,7 +211,7 @@ class Connection(BaseConnection): ) @classmethod - def wait(cls, gen): + def wait(cls, gen: Generator[Tuple[int, Wait], Ready, RV]) -> RV: return wait_select(gen) @@ -195,25 +223,33 @@ class AsyncConnection(BaseConnection): methods implemented as coroutines. """ - def __init__(self, pgconn): + cursor_factory: Type[cursor.AsyncCursor] + + def __init__(self, pgconn: pq.PGconn): super().__init__(pgconn) self.lock = asyncio.Lock() self.cursor_factory = cursor.AsyncCursor @classmethod - async def connect(cls, conninfo, **kwargs): + async def connect(cls, conninfo: str, **kwargs: Any) -> "AsyncConnection": conninfo = make_conninfo(conninfo, **kwargs) gen = cls._connect_gen(conninfo) pgconn = await cls.wait(gen) return cls(pgconn) - async def commit(self): + def cursor( + self, name: Optional[str] = None, binary: bool = False + ) -> cursor.AsyncCursor: + cur = super().cursor(name, binary) + return cast(cursor.AsyncCursor, cur) + + async def commit(self) -> None: await self._exec_commit_rollback(b"commit") - async def rollback(self): + async def rollback(self) -> None: await self._exec_commit_rollback(b"rollback") - async def _exec_commit_rollback(self, command): + async def _exec_commit_rollback(self, command: bytes) -> None: async with self.lock: status = self.pgconn.transaction_status if status == pq.TransactionStatus.IDLE: @@ -228,5 +264,5 @@ class AsyncConnection(BaseConnection): ) @classmethod - async def wait(cls, gen): + async def wait(cls, gen: Generator[Tuple[int, Wait], Ready, RV]) -> RV: return await wait_async(gen) diff --git a/psycopg3/conninfo.py b/psycopg3/conninfo.py index 365c3e3d9..81308a2a2 100644 --- a/psycopg3/conninfo.py +++ b/psycopg3/conninfo.py @@ -1,16 +1,17 @@ import re +from typing import Any, Dict, List from . import pq from . import exceptions as exc -def make_conninfo(conninfo=None, **kwargs): +def make_conninfo(conninfo: str = "", **kwargs: Any) -> str: """ Merge a string and keyword params into a single conninfo string. Raise ProgrammingError if the input don't make a valid conninfo. """ - if conninfo is None and not kwargs: + if not conninfo and not kwargs: return "" # If no kwarg is specified don't mung the conninfo but check if it's correct @@ -37,7 +38,7 @@ def make_conninfo(conninfo=None, **kwargs): return conninfo -def conninfo_to_dict(conninfo): +def conninfo_to_dict(conninfo: str) -> Dict[str, str]: """ Convert the *conninfo* string into a dictionary of parameters. @@ -51,7 +52,7 @@ def conninfo_to_dict(conninfo): } -def _parse_conninfo(conninfo): +def _parse_conninfo(conninfo: str) -> List[pq.ConninfoOption]: """ Verify that *conninfo* is a valid connection string. @@ -65,9 +66,11 @@ def _parse_conninfo(conninfo): raise exc.ProgrammingError(str(e)) -def _param_escape( - s, re_escape=re.compile(r"([\\'])"), re_space=re.compile(r"\s") -): +re_escape = re.compile(r"([\\'])") +re_space = re.compile(r"\s") + + +def _param_escape(s: str) -> str: """ Apply the escaping rule required by PQconnectdb """ diff --git a/psycopg3/cursor.py b/psycopg3/cursor.py index e5d63f1f9..ec6416820 100644 --- a/psycopg3/cursor.py +++ b/psycopg3/cursor.py @@ -4,29 +4,43 @@ psycopg3 cursor objects # Copyright (C) 2020 The Psycopg Team +from typing import Any, List, Mapping, Optional, Sequence, Tuple, TYPE_CHECKING + from . import exceptions as exc -from .pq import error_message, DiagnosticField, ExecStatus +from .pq import error_message, DiagnosticField, ExecStatus, PGresult, Format from .utils.queries import query2pg, reorder_params +from .utils.typing import Query, Params + +if TYPE_CHECKING: + from .connection import ( + BaseConnection, + Connection, + AsyncConnection, + QueryGen, + ) + from .adaptation import AdaptersMap, TypecastersMap class BaseCursor: - def __init__(self, conn, binary=False): + def __init__(self, conn: "BaseConnection", binary: bool = False): self.conn = conn self.binary = binary - self.adapters = {} - self.casters = {} + self.adapters: AdaptersMap = {} + self.casters: TypecastersMap = {} self._reset() - def _reset(self): + def _reset(self) -> None: from .adaptation import Transformer - self._results = [] - self._result = None + self._results: List[PGresult] = [] + self._result: Optional[PGresult] = None self._pos = 0 self._iresult = 0 self._transformer = Transformer(self) - def _execute_send(self, query, vars): + def _execute_send( + self, query: Query, vars: Optional[Params] + ) -> "QueryGen": # Implement part of execute() before waiting common to sync and async self._reset() @@ -40,21 +54,23 @@ class BaseCursor: query, formats, order = query2pg(query, vars, codec) if vars: if order is not None: + assert isinstance(vars, Mapping) vars = reorder_params(vars, order) + assert isinstance(vars, Sequence) params, types = self._transformer.adapt_sequence(vars, formats) self.conn.pgconn.send_query_params( query, params, param_formats=formats, param_types=types, - result_format=int(self.binary), + result_format=Format(self.binary), ) else: self.conn.pgconn.send_query(query) return self.conn._exec_gen(self.conn.pgconn) - def _execute_results(self, results): + def _execute_results(self, results: List[PGresult]) -> None: # Implement part of execute() after waiting common to sync and async if not results: raise exc.InternalError("got no result from the query") @@ -89,20 +105,22 @@ class BaseCursor: f" {', '.join(sorted(s.name for s in sorted(badstats)))}" ) - def nextset(self): + def nextset(self) -> Optional[bool]: self._iresult += 1 if self._iresult < len(self._results): self._result = self._results[self._iresult] self._pos = 0 return True + else: + return None - def fetchone(self): + def fetchone(self) -> Optional[Sequence[Any]]: rv = self._cast_row(self._pos) if rv is not None: self._pos += 1 return rv - def _cast_row(self, n): + def _cast_row(self, n: int) -> Optional[Tuple[Any, ...]]: if self._result is None: return None if n >= self._result.ntuples: @@ -112,7 +130,12 @@ class BaseCursor: class Cursor(BaseCursor): - def execute(self, query, vars=None): + conn: "Connection" + + def __init__(self, conn: "Connection", binary: bool = False): + super().__init__(conn, binary) + + def execute(self, query: Query, vars: Optional[Params] = None) -> "Cursor": with self.conn.lock: gen = self._execute_send(query, vars) results = self.conn.wait(gen) @@ -121,7 +144,14 @@ class Cursor(BaseCursor): class AsyncCursor(BaseCursor): - async def execute(self, query, vars=None): + conn: "AsyncConnection" + + def __init__(self, conn: "AsyncConnection", binary: bool = False): + super().__init__(conn, binary) + + async def execute( + self, query: Query, vars: Optional[Params] = None + ) -> "AsyncCursor": async with self.conn.lock: gen = self._execute_send(query, vars) results = await self.conn.wait(gen) diff --git a/psycopg3/exceptions.py b/psycopg3/exceptions.py index 7baba55b1..c8ce84438 100644 --- a/psycopg3/exceptions.py +++ b/psycopg3/exceptions.py @@ -18,6 +18,11 @@ DBAPI-defined Exceptions are defined in the following hierarchy:: # Copyright (C) 2020 The Psycopg Team +from typing import Any, Optional, Sequence, TYPE_CHECKING + +if TYPE_CHECKING: + from psycopg3.pq import PGresult # noqa + class Warning(Exception): """ @@ -32,7 +37,9 @@ class Error(Exception): Base exception for all the errors psycopg3 will raise. """ - def __init__(self, *args, pgresult=None): + def __init__( + self, *args: Sequence[Any], pgresult: Optional["PGresult"] = None + ): super().__init__(*args) self.pgresult = pgresult @@ -100,6 +107,6 @@ class NotSupportedError(DatabaseError): """ -def class_for_state(sqlstate): +def class_for_state(sqlstate: bytes) -> type: # TODO: stub return DatabaseError diff --git a/psycopg3/pq/__init__.py b/psycopg3/pq/__init__.py index 98575ae60..70a0fd0d3 100644 --- a/psycopg3/pq/__init__.py +++ b/psycopg3/pq/__init__.py @@ -19,7 +19,7 @@ from .enums import ( Format, ) from .encodings import py_codecs -from .misc import error_message +from .misc import error_message, ConninfoOption from . import pq_ctypes as pq_module @@ -41,6 +41,7 @@ __all__ = ( "Conninfo", "PQerror", "error_message", + "ConninfoOption", "py_codecs", "version", ) diff --git a/psycopg3/pq/_pq_ctypes.py b/psycopg3/pq/_pq_ctypes.py index 31887b241..86ed2a323 100644 --- a/psycopg3/pq/_pq_ctypes.py +++ b/psycopg3/pq/_pq_ctypes.py @@ -144,9 +144,9 @@ if libpq_version >= 120000: _PQhostaddr.restype = c_char_p -def PQhostaddr(pgconn): +def PQhostaddr(pgconn: type) -> bytes: if _PQhostaddr is not None: - return _PQhostaddr(pgconn) + return _PQhostaddr(pgconn) # type: ignore else: raise NotSupportedError( f"PQhostaddr requires libpq from PostgreSQL 12," diff --git a/psycopg3/pq/misc.py b/psycopg3/pq/misc.py index ed1e6cca4..c6b020500 100644 --- a/psycopg3/pq/misc.py +++ b/psycopg3/pq/misc.py @@ -4,8 +4,19 @@ Various functionalities to make easier to work with the libpq. # Copyright (C) 2020 The Psycopg Team +from collections import namedtuple +from typing import TYPE_CHECKING, Union -def error_message(obj): +if TYPE_CHECKING: + from psycopg3.pq import PGconn, PGresult # noqa + + +ConninfoOption = namedtuple( + "ConninfoOption", "keyword envvar compiled val label dispatcher dispsize" +) + + +def error_message(obj: Union["PGconn", "PGresult"]) -> str: """ Return an error message from a PGconn or PGresult. @@ -13,29 +24,33 @@ def error_message(obj): """ from psycopg3 import pq + bmsg: bytes + if isinstance(obj, pq.PGconn): - msg = obj.error_message + bmsg = obj.error_message # strip severity and whitespaces - if msg: - msg = msg.splitlines()[0].split(b":", 1)[-1].strip() + if bmsg: + bmsg = bmsg.splitlines()[0].split(b":", 1)[-1].strip() elif isinstance(obj, pq.PGresult): - msg = obj.error_field(pq.DiagnosticField.MESSAGE_PRIMARY) - if not msg: - msg = obj.error_message + bmsg = obj.error_field(pq.DiagnosticField.MESSAGE_PRIMARY) + if not bmsg: + bmsg = obj.error_message # strip severity and whitespaces - if msg: - msg = msg.splitlines()[0].split(b":", 1)[-1].strip() + if bmsg: + bmsg = bmsg.splitlines()[0].split(b":", 1)[-1].strip() else: raise TypeError( f"PGconn or PGresult expected, got {type(obj).__name__}" ) - if msg: - msg = msg.decode("utf8", "replace") # TODO: or in connection encoding? + if bmsg: + msg = bmsg.decode( + "utf8", "replace" + ) # TODO: or in connection encoding? else: msg = "no details available" diff --git a/psycopg3/pq/pq_ctypes.py b/psycopg3/pq/pq_ctypes.py index a79bb4298..0f1b47816 100644 --- a/psycopg3/pq/pq_ctypes.py +++ b/psycopg3/pq/pq_ctypes.py @@ -8,9 +8,9 @@ implementation. # Copyright (C) 2020 The Psycopg Team -from collections import namedtuple from ctypes import string_at from ctypes import c_char_p, c_int, pointer +from typing import Any, List, Optional, Sequence from .enums import ( ConnStatus, @@ -18,14 +18,16 @@ from .enums import ( ExecStatus, TransactionStatus, Ping, + DiagnosticField, + Format, ) -from .misc import error_message +from .misc import error_message, ConninfoOption from . import _pq_ctypes as impl from ..exceptions import OperationalError -def version(): - return impl.PQlibVersion() +def version() -> int: + return impl.PQlibVersion() # type: ignore class PQerror(OperationalError): @@ -35,18 +37,16 @@ class PQerror(OperationalError): class PGconn: __slots__ = ("pgconn_ptr",) - def __init__(self, pgconn_ptr): - self.pgconn_ptr = pgconn_ptr + def __init__(self, pgconn_ptr: impl.PGconn_struct): + self.pgconn_ptr: Optional[impl.PGconn_struct] = pgconn_ptr - def __del__(self): + def __del__(self) -> None: self.finish() @classmethod - def connect(cls, conninfo): + def connect(cls, conninfo: bytes) -> "PGconn": if not isinstance(conninfo, bytes): - raise TypeError( - "bytes expected, got %s instead" % type(conninfo).__name__ - ) + raise TypeError(f"bytes expected, got {type(conninfo)} instead") pgconn_ptr = impl.PQconnectdb(conninfo) if not pgconn_ptr: @@ -54,28 +54,26 @@ class PGconn: return cls(pgconn_ptr) @classmethod - def connect_start(cls, conninfo): + def connect_start(cls, conninfo: bytes) -> "PGconn": if not isinstance(conninfo, bytes): - raise TypeError( - "bytes expected, got %s instead" % type(conninfo).__name__ - ) + raise TypeError(f"bytes expected, got {type(conninfo)} instead") pgconn_ptr = impl.PQconnectStart(conninfo) if not pgconn_ptr: raise MemoryError("couldn't allocate PGconn") return cls(pgconn_ptr) - def connect_poll(self): + def connect_poll(self) -> PollingStatus: rv = impl.PQconnectPoll(self.pgconn_ptr) return PollingStatus(rv) - def finish(self): + def finish(self) -> None: self.pgconn_ptr, p = None, self.pgconn_ptr if p is not None: impl.PQfinish(p) @property - def info(self): + def info(self) -> List["ConninfoOption"]: opts = impl.PQconninfo(self.pgconn_ptr) if not opts: raise MemoryError("couldn't allocate connection info") @@ -84,130 +82,124 @@ class PGconn: finally: impl.PQconninfoFree(opts) - def reset(self): + def reset(self) -> None: impl.PQreset(self.pgconn_ptr) - def reset_start(self): + def reset_start(self) -> None: if not impl.PQresetStart(self.pgconn_ptr): raise PQerror("couldn't reset connection") - def reset_poll(self): + def reset_poll(self) -> PollingStatus: rv = impl.PQresetPoll(self.pgconn_ptr) return PollingStatus(rv) @classmethod - def ping(self, conninfo): + def ping(self, conninfo: bytes) -> Ping: if not isinstance(conninfo, bytes): - raise TypeError( - "bytes expected, got %s instead" % type(conninfo).__name__ - ) + raise TypeError(f"bytes expected, got {type(conninfo)} instead") rv = impl.PQping(conninfo) return Ping(rv) @property - def db(self): - return impl.PQdb(self.pgconn_ptr) + def db(self) -> bytes: + return impl.PQdb(self.pgconn_ptr) # type: ignore @property - def user(self): - return impl.PQuser(self.pgconn_ptr) + def user(self) -> bytes: + return impl.PQuser(self.pgconn_ptr) # type: ignore @property - def password(self): - return impl.PQpass(self.pgconn_ptr) + def password(self) -> bytes: + return impl.PQpass(self.pgconn_ptr) # type: ignore @property - def host(self): - return impl.PQhost(self.pgconn_ptr) + def host(self) -> bytes: + return impl.PQhost(self.pgconn_ptr) # type: ignore @property - def hostaddr(self): - return impl.PQhostaddr(self.pgconn_ptr) + def hostaddr(self) -> bytes: + return impl.PQhostaddr(self.pgconn_ptr) # type: ignore @property - def port(self): - return impl.PQport(self.pgconn_ptr) + def port(self) -> bytes: + return impl.PQport(self.pgconn_ptr) # type: ignore @property - def tty(self): - return impl.PQtty(self.pgconn_ptr) + def tty(self) -> bytes: + return impl.PQtty(self.pgconn_ptr) # type: ignore @property - def options(self): - return impl.PQoptions(self.pgconn_ptr) + def options(self) -> bytes: + return impl.PQoptions(self.pgconn_ptr) # type: ignore @property - def status(self): + def status(self) -> ConnStatus: rv = impl.PQstatus(self.pgconn_ptr) return ConnStatus(rv) @property - def transaction_status(self): + def transaction_status(self) -> TransactionStatus: rv = impl.PQtransactionStatus(self.pgconn_ptr) return TransactionStatus(rv) - def parameter_status(self, name): - return impl.PQparameterStatus(self.pgconn_ptr, name) + def parameter_status(self, name: bytes) -> bytes: + return impl.PQparameterStatus(self.pgconn_ptr, name) # type: ignore @property - def protocol_version(self): - return impl.PQprotocolVersion(self.pgconn_ptr) + def protocol_version(self) -> int: + return impl.PQprotocolVersion(self.pgconn_ptr) # type: ignore @property - def server_version(self): - return impl.PQserverVersion(self.pgconn_ptr) + def server_version(self) -> int: + return impl.PQserverVersion(self.pgconn_ptr) # type: ignore @property - def error_message(self): - return impl.PQerrorMessage(self.pgconn_ptr) + def error_message(self) -> bytes: + return impl.PQerrorMessage(self.pgconn_ptr) # type: ignore @property - def socket(self): - return impl.PQsocket(self.pgconn_ptr) + def socket(self) -> int: + return impl.PQsocket(self.pgconn_ptr) # type: ignore @property - def backend_pid(self): - return impl.PQbackendPID(self.pgconn_ptr) + def backend_pid(self) -> int: + return impl.PQbackendPID(self.pgconn_ptr) # type: ignore @property - def needs_password(self): + def needs_password(self) -> bool: return bool(impl.PQconnectionNeedsPassword(self.pgconn_ptr)) @property - def used_password(self): + def used_password(self) -> bool: return bool(impl.PQconnectionUsedPassword(self.pgconn_ptr)) @property - def ssl_in_use(self): + def ssl_in_use(self) -> bool: return bool(impl.PQsslInUse(self.pgconn_ptr)) - def exec_(self, command): + def exec_(self, command: bytes) -> "PGresult": if not isinstance(command, bytes): - raise TypeError( - "bytes expected, got %s instead" % type(command).__name__ - ) + raise TypeError(f"bytes expected, got {type(command)} instead") rv = impl.PQexec(self.pgconn_ptr, command) if not rv: raise MemoryError("couldn't allocate PGresult") return PGresult(rv) - def send_query(self, command): + def send_query(self, command: bytes) -> None: if not isinstance(command, bytes): - raise TypeError( - "bytes expected, got %s instead" % type(command).__name__ - ) + raise TypeError(f"bytes expected, got {type(command)} instead") if not impl.PQsendQuery(self.pgconn_ptr, command): raise PQerror(f"sending query failed: {error_message(self)}") def exec_params( self, - command, - param_values, - param_types=None, - param_formats=None, - result_format=0, - ): + command: bytes, + param_values: List[Optional[bytes]], + param_types: Optional[List[int]] = None, + param_formats: Optional[List[Format]] = None, + result_format: Format = Format.TEXT, + ) -> "PGresult": args = self._query_params_args( command, param_values, param_types, param_formats, result_format ) @@ -218,12 +210,12 @@ class PGconn: def send_query_params( self, - command, - param_values, - param_types=None, - param_formats=None, - result_format=0, - ): + command: bytes, + param_values: List[Optional[bytes]], + param_types: Optional[List[int]] = None, + param_formats: Optional[List[Format]] = None, + result_format: Format = Format.TEXT, + ) -> None: args = self._query_params_args( command, param_values, param_types, param_formats, result_format ) @@ -233,12 +225,15 @@ class PGconn: ) def _query_params_args( - self, command, param_values, param_types, param_formats, result_format, - ): + self, + command: bytes, + param_values: List[Optional[bytes]], + param_types: Optional[List[int]] = None, + param_formats: Optional[List[Format]] = None, + result_format: Format = Format.TEXT, + ) -> Any: if not isinstance(command, bytes): - raise TypeError( - "bytes expected, got %s instead" % type(command).__name__ - ) + raise TypeError(f"bytes expected, got {type(command)} instead") nparams = len(param_values) if nparams: @@ -247,7 +242,7 @@ class PGconn: *(len(p) if p is not None else 0 for p in param_values) ) else: - aparams = alenghts = None + aparams = alenghts = None # type: ignore if param_types is None: atypes = None @@ -280,16 +275,18 @@ class PGconn: result_format, ) - def prepare(self, name, command, param_types=None): + def prepare( + self, + name: bytes, + command: bytes, + param_types: Optional[List[int]] = None, + ) -> "PGresult": if not isinstance(name, bytes): - raise TypeError( - "'name' must be bytes, got %s instead" % type(name).__name__ - ) + raise TypeError(f"'name' must be bytes, got {type(name)} instead") if not isinstance(command, bytes): raise TypeError( - "'command' must be bytes, got %s instead" - % type(command).__name__ + f"'command' must be bytes, got {type(command)} instead" ) if param_types is None: @@ -305,12 +302,14 @@ class PGconn: return PGresult(rv) def exec_prepared( - self, name, param_values, param_formats=None, result_format=0 - ): + self, + name: bytes, + param_values: List[bytes], + param_formats: Optional[List[int]] = None, + result_format: int = 0, + ) -> "PGresult": if not isinstance(name, bytes): - raise TypeError( - "'name' must be bytes, got %s instead" % type(name).__name__ - ) + raise TypeError(f"'name' must be bytes, got {type(name)} instead") nparams = len(param_values) if nparams: @@ -319,7 +318,7 @@ class PGconn: *(len(p) if p is not None else 0 for p in param_values) ) else: - aparams = alenghts = None + aparams = alenghts = None # type: ignore if param_formats is None: aformats = None @@ -344,53 +343,49 @@ class PGconn: raise MemoryError("couldn't allocate PGresult") return PGresult(rv) - def describe_prepared(self, name): + def describe_prepared(self, name: bytes) -> "PGresult": if not isinstance(name, bytes): - raise TypeError( - "'name' must be bytes, got %s instead" % type(name).__name__ - ) + raise TypeError(f"'name' must be bytes, got {type(name)} instead") rv = impl.PQdescribePrepared(self.pgconn_ptr, name) if not rv: raise MemoryError("couldn't allocate PGresult") return PGresult(rv) - def describe_portal(self, name): + def describe_portal(self, name: bytes) -> "PGresult": if not isinstance(name, bytes): - raise TypeError( - "'name' must be bytes, got %s instead" % type(name).__name__ - ) + raise TypeError(f"'name' must be bytes, got {type(name)} instead") rv = impl.PQdescribePortal(self.pgconn_ptr, name) if not rv: raise MemoryError("couldn't allocate PGresult") return PGresult(rv) - def get_result(self): + def get_result(self) -> Optional["PGresult"]: rv = impl.PQgetResult(self.pgconn_ptr) return PGresult(rv) if rv else None - def consume_input(self): + def consume_input(self) -> None: if 1 != impl.PQconsumeInput(self.pgconn_ptr): raise PQerror(f"consuming input failed: {error_message(self)}") - def is_busy(self): - return impl.PQisBusy(self.pgconn_ptr) + def is_busy(self) -> int: + return impl.PQisBusy(self.pgconn_ptr) # type: ignore @property - def nonblocking(self): - return impl.PQisnonblocking(self.pgconn_ptr) + def nonblocking(self) -> int: + return impl.PQisnonblocking(self.pgconn_ptr) # type: ignore @nonblocking.setter - def nonblocking(self, arg): + def nonblocking(self, arg: int) -> None: if 0 > impl.PQsetnonblocking(self.pgconn_ptr, arg): raise PQerror(f"setting nonblocking failed: {error_message(self)}") - def flush(self): - rv = impl.PQflush(self.pgconn_ptr) + def flush(self) -> int: + rv: int = impl.PQflush(self.pgconn_ptr) if rv < 0: raise PQerror(f"flushing failed: {error_message(self)}") return rv - def make_empty_result(self, exec_status): + def make_empty_result(self, exec_status: ExecStatus) -> "PGresult": rv = impl.PQmakeEmptyPGresult(self.pgconn_ptr, exec_status) if not rv: raise MemoryError("couldn't allocate empty PGresult") @@ -400,64 +395,72 @@ class PGconn: class PGresult: __slots__ = ("pgresult_ptr",) - def __init__(self, pgresult_ptr): - self.pgresult_ptr = pgresult_ptr + def __init__(self, pgresult_ptr: type): + self.pgresult_ptr: Optional[type] = pgresult_ptr - def __del__(self): + def __del__(self) -> None: self.clear() - def clear(self): + def clear(self) -> None: self.pgresult_ptr, p = None, self.pgresult_ptr if p is not None: impl.PQclear(p) @property - def status(self): + def status(self) -> ExecStatus: rv = impl.PQresultStatus(self.pgresult_ptr) return ExecStatus(rv) @property - def error_message(self): - return impl.PQresultErrorMessage(self.pgresult_ptr) + def error_message(self) -> bytes: + return impl.PQresultErrorMessage(self.pgresult_ptr) # type: ignore - def error_field(self, fieldcode): - return impl.PQresultErrorField(self.pgresult_ptr, fieldcode) + def error_field(self, fieldcode: DiagnosticField) -> bytes: + return impl.PQresultErrorField( # type: ignore + self.pgresult_ptr, fieldcode + ) @property - def ntuples(self): - return impl.PQntuples(self.pgresult_ptr) + def ntuples(self) -> int: + return impl.PQntuples(self.pgresult_ptr) # type: ignore @property - def nfields(self): - return impl.PQnfields(self.pgresult_ptr) + def nfields(self) -> int: + return impl.PQnfields(self.pgresult_ptr) # type: ignore - def fname(self, column_number): - return impl.PQfname(self.pgresult_ptr, column_number) + def fname(self, column_number: int) -> int: + return impl.PQfname(self.pgresult_ptr, column_number) # type: ignore - def ftable(self, column_number): - return impl.PQftable(self.pgresult_ptr, column_number) + def ftable(self, column_number: int) -> int: + return impl.PQftable(self.pgresult_ptr, column_number) # type: ignore - def ftablecol(self, column_number): - return impl.PQftablecol(self.pgresult_ptr, column_number) + def ftablecol(self, column_number: int) -> int: + return impl.PQftablecol( # type: ignore + self.pgresult_ptr, column_number + ) - def fformat(self, column_number): - return impl.PQfformat(self.pgresult_ptr, column_number) + def fformat(self, column_number: int) -> Format: + return impl.PQfformat(self.pgresult_ptr, column_number) # type: ignore - def ftype(self, column_number): - return impl.PQftype(self.pgresult_ptr, column_number) + def ftype(self, column_number: int) -> int: + return impl.PQftype(self.pgresult_ptr, column_number) # type: ignore - def fmod(self, column_number): - return impl.PQfmod(self.pgresult_ptr, column_number) + def fmod(self, column_number: int) -> int: + return impl.PQfmod(self.pgresult_ptr, column_number) # type: ignore - def fsize(self, column_number): - return impl.PQfsize(self.pgresult_ptr, column_number) + def fsize(self, column_number: int) -> int: + return impl.PQfsize(self.pgresult_ptr, column_number) # type: ignore @property - def binary_tuples(self): - return impl.PQbinaryTuples(self.pgresult_ptr) - - def get_value(self, row_number, column_number): - length = impl.PQgetlength(self.pgresult_ptr, row_number, column_number) + def binary_tuples(self) -> int: + return impl.PQbinaryTuples(self.pgresult_ptr) # type: ignore + + def get_value( + self, row_number: int, column_number: int + ) -> Optional[bytes]: + length: int = impl.PQgetlength( + self.pgresult_ptr, row_number, column_number + ) if length: v = impl.PQgetvalue(self.pgresult_ptr, row_number, column_number) return string_at(v, length) @@ -468,35 +471,31 @@ class PGresult: return b"" @property - def nparams(self): - return impl.PQnparams(self.pgresult_ptr) + def nparams(self) -> int: + return impl.PQnparams(self.pgresult_ptr) # type: ignore - def param_type(self, param_number): - return impl.PQparamtype(self.pgresult_ptr, param_number) + def param_type(self, param_number: int) -> int: + return impl.PQparamtype( # type: ignore + self.pgresult_ptr, param_number + ) @property - def command_status(self): - return impl.PQcmdStatus(self.pgresult_ptr) + def command_status(self) -> bytes: + return impl.PQcmdStatus(self.pgresult_ptr) # type: ignore @property - def command_tuples(self): + def command_tuples(self) -> Optional[int]: rv = impl.PQcmdTuples(self.pgresult_ptr) - if rv: - return int(rv) + return int(rv) if rv else None @property - def oid_value(self): - return impl.PQoidValue(self.pgresult_ptr) - - -ConninfoOption = namedtuple( - "ConninfoOption", "keyword envvar compiled val label dispatcher dispsize" -) + def oid_value(self) -> int: + return impl.PQoidValue(self.pgresult_ptr) # type: ignore class Conninfo: @classmethod - def get_defaults(cls): + def get_defaults(cls) -> List[ConninfoOption]: opts = impl.PQconndefaults() if not opts: raise MemoryError("couldn't allocate connection defaults") @@ -506,11 +505,9 @@ class Conninfo: impl.PQconninfoFree(opts) @classmethod - def parse(cls, conninfo): + def parse(cls, conninfo: bytes) -> List[ConninfoOption]: if not isinstance(conninfo, bytes): - raise TypeError( - "bytes expected, got %s instead" % type(conninfo).__name__ - ) + raise TypeError(f"bytes expected, got {type(conninfo)} instead") errmsg = c_char_p() rv = impl.PQconninfoParse(conninfo, pointer(errmsg)) @@ -518,7 +515,7 @@ class Conninfo: if not errmsg: raise MemoryError("couldn't allocate on conninfo parse") else: - exc = PQerror(errmsg.value.decode("utf8", "replace")) + exc = PQerror((errmsg.value or b"").decode("utf8", "replace")) impl.PQfreemem(errmsg) raise exc @@ -528,7 +525,9 @@ class Conninfo: impl.PQconninfoFree(rv) @classmethod - def _options_from_array(cls, opts): + def _options_from_array( + cls, opts: Sequence[impl.PQconninfoOption_struct] + ) -> List[ConninfoOption]: rv = [] skws = "keyword envvar compiled val label dispatcher".split() for opt in opts: diff --git a/psycopg3/types/numeric.py b/psycopg3/types/numeric.py index 0dd0213b8..b633de924 100644 --- a/psycopg3/types/numeric.py +++ b/psycopg3/types/numeric.py @@ -5,6 +5,7 @@ Adapters of numeric types. # Copyright (C) 2020 The Psycopg Team import codecs +from typing import Tuple from ..adaptation import Adapter, Typecaster from .oids import type_oid @@ -14,10 +15,10 @@ _decode = codecs.lookup("ascii").decode @Adapter.register(int) -def adapt_int(obj): +def adapt_int(obj: int) -> Tuple[bytes, int]: return _encode(str(obj))[0], type_oid["numeric"] @Typecaster.register(type_oid["numeric"]) -def cast_int(data): +def cast_int(data: bytes) -> int: return int(_decode(data)[0]) diff --git a/psycopg3/types/oids.py b/psycopg3/types/oids.py index 025f6a337..d81b2c865 100644 --- a/psycopg3/types/oids.py +++ b/psycopg3/types/oids.py @@ -92,7 +92,7 @@ _oids_table = [ type_oid = {name: oid for name, oid, _, _ in _oids_table} -def self_update(): +def self_update() -> None: import subprocess as sp # queries output should make black happy diff --git a/psycopg3/types/text.py b/psycopg3/types/text.py index 3d1405deb..2db4d1d30 100644 --- a/psycopg3/types/text.py +++ b/psycopg3/types/text.py @@ -5,30 +5,39 @@ Adapters of textual types. # Copyright (C) 2020 The Psycopg Team import codecs +from typing import Optional, Union -from ..adaptation import Adapter -from ..adaptation import Typecaster +from ..adaptation import Adapter, Typecaster +from ..connection import BaseConnection +from ..utils.typing import EncodeFunc, DecodeFunc from .oids import type_oid @Adapter.register(str) @Adapter.register_binary(str) class StringAdapter(Adapter): - def __init__(self, cls, conn): + def __init__(self, cls: type, conn: BaseConnection): super().__init__(cls, conn) - self.encode = ( - conn.codec if conn is not None else codecs.lookup("utf8") - ).encode - def adapt(self, obj): - return self.encode(obj)[0] + self._encode: EncodeFunc + if conn is not None: + self._encode = conn.codec.encode + else: + self._encode = codecs.lookup("utf8").encode + + def adapt(self, obj: str) -> bytes: + return self._encode(obj)[0] @Typecaster.register(type_oid["text"]) @Typecaster.register_binary(type_oid["text"]) class StringCaster(Typecaster): - def __init__(self, oid, conn): + + decode: Optional[DecodeFunc] + + def __init__(self, oid: int, conn: BaseConnection): super().__init__(oid, conn) + if conn is not None: if conn.pgenc != b"SQL_ASCII": self.decode = conn.codec.decode @@ -37,7 +46,7 @@ class StringCaster(Typecaster): else: self.decode = codecs.lookup("utf8").decode - def cast(self, data): + def cast(self, data: bytes) -> Union[bytes, str]: if self.decode is not None: return self.decode(data)[0] else: diff --git a/psycopg3/utils/queries.py b/psycopg3/utils/queries.py index 8cdd8cb3e..5ad976e5f 100644 --- a/psycopg3/utils/queries.py +++ b/psycopg3/utils/queries.py @@ -5,13 +5,28 @@ Utility module to manipulate queries # Copyright (C) 2020 The Psycopg Team import re -from collections.abc import Sequence, Mapping +from codecs import CodecInfo +from typing import ( + Any, + Dict, + List, + Mapping, + Match, + NamedTuple, + Optional, + Sequence, + Tuple, + Union, +) from .. import exceptions as exc from ..pq import Format +from .typing import Params -def query2pg(query, vars, codec): +def query2pg( + query: bytes, vars: Params, codec: CodecInfo +) -> Tuple[bytes, List[Format], Optional[List[str]]]: """ Convert Python query and params into something Postgres understands. @@ -30,6 +45,9 @@ def query2pg(query, vars, codec): ) parts = split_query(query, codec.name) + order: Optional[List[str]] = None + chunks: List[bytes] = [] + formats = [] if isinstance(vars, Sequence) and not isinstance(vars, (bytes, str)): if len(vars) != len(parts) - 1: @@ -37,32 +55,40 @@ def query2pg(query, vars, codec): f"the query has {len(parts) - 1} placeholders but" f" {len(vars)} parameters were passed" ) - if vars and not isinstance(parts[0][1], int): + if vars and not isinstance(parts[0].index, int): raise TypeError( "named placeholders require a mapping of parameters" ) - order = None + + for part in parts[:-1]: + assert isinstance(part.index, int) + chunks.append(part.pre) + chunks.append(b"$%d" % (part.index + 1)) + formats.append(part.format) elif isinstance(vars, Mapping): - if vars and len(parts) > 1 and not isinstance(parts[0][1], bytes): + if vars and len(parts) > 1 and not isinstance(parts[0][1], str): raise TypeError( "positional placeholders (%s) require a sequence of parameters" ) - seen = {} + seen: Dict[str, Tuple[bytes, Format]] = {} order = [] for part in parts[:-1]: - name = codec.decode(part[1])[0] - if name not in seen: - n = len(seen) - part[1] = n - seen[name] = (n, part[2]) - order.append(name) + assert isinstance(part.index, str) + formats.append(part.format) + chunks.append(part.pre) + if part.index not in seen: + ph = b"$%d" % (len(seen) + 1) + seen[part.index] = (ph, part.format) + order.append(part.index) + chunks.append(ph) else: - if seen[name][1] != part[2]: + if seen[part.index][1] != part.format: raise exc.ProgrammingError( - f"placeholder '{name}' cannot have different formats" + f"placeholder '{part.index}' cannot have" + f" different formats" ) - part[1] = seen[name][0] + chunks.append(seen[part.index][0]) else: raise TypeError( @@ -70,16 +96,10 @@ def query2pg(query, vars, codec): f" got {type(vars).__name__}" ) - # Assemble query and parameters - rv = [] - formats = [] - for part in parts[:-1]: - rv.append(part[0]) - rv.append(b"$%d" % (part[1] + 1)) - formats.append(part[2]) - rv.append(parts[-1][0]) + # last part + chunks.append(parts[-1].pre) - return b"".join(rv), formats, order + return b"".join(chunks), formats, order _re_placeholder = re.compile( @@ -97,35 +117,48 @@ _re_placeholder = re.compile( ) -def split_query(query, encoding="ascii"): - parts = [] +class QueryPart(NamedTuple): + pre: bytes + # TODO: mypy bug? https://github.com/python/mypy/issues/8599 + index: Union[int, str] # type: ignore + format: Format + + +def split_query(query: bytes, encoding: str = "ascii") -> List[QueryPart]: + parts: List[Tuple[bytes, Optional[Match[bytes]]]] = [] cur = 0 - # pairs [(fragment, match)], with the last match None + # pairs [(fragment, match], with the last match None m = None for m in _re_placeholder.finditer(query): pre = query[cur : m.span(0)[0]] - parts.append([pre, m, None]) + parts.append((pre, m)) cur = m.span(0)[1] if m is None: - parts.append([query, None, None]) + parts.append((query, None)) else: - parts.append([query[cur:], None, None]) + parts.append((query[cur:], None)) + + rv = [] # drop the "%%", validate i = 0 phtype = None while i < len(parts): - part = parts[i] - m = part[1] + pre, m = parts[i] if m is None: - break # last part + # last part + rv.append(QueryPart(pre, 0, Format.TEXT)) + break + ph = m.group(0) if ph == b"%%": # unescape '%%' to '%' and merge the parts - parts[i + 1][0] = part[0] + b"%" + parts[i + 1][0] + pre1, m1 = parts[i + 1] + parts[i + 1] = (pre + b"%" + pre1, m1) del parts[i] continue + if ph == b"%(": raise exc.ProgrammingError( f"incomplete placeholder:" @@ -144,28 +177,29 @@ def split_query(query, encoding="ascii"): ) # Index or name - if m.group(1) is None: - part[1] = i - else: - part[1] = m.group(1) - - # Binary format - part[2] = Format(ph[-1:] == b"b") + index: Union[int, str] + index = i if m.group(1) is None else m.group(1).decode(encoding) if phtype is None: - phtype = type(part[1]) + phtype = type(index) else: - if phtype is not type(part[1]): # noqa + if phtype is not type(index): # noqa raise exc.ProgrammingError( "positional and named placeholders cannot be mixed" ) + # Binary format + format = Format(ph[-1:] == b"b") + + rv.append(QueryPart(pre, index, format)) i += 1 - return parts + return rv -def reorder_params(params, order): +def reorder_params( + params: Mapping[str, Any], order: Sequence[str] +) -> List[str]: """ Convert a mapping of parameters into an array in a specified order """ diff --git a/psycopg3/utils/typing.py b/psycopg3/utils/typing.py new file mode 100644 index 000000000..f96576ce8 --- /dev/null +++ b/psycopg3/utils/typing.py @@ -0,0 +1,13 @@ +""" +Additional types for checking +""" + +# Copyright (C) 2020 The Psycopg Team + +from typing import Any, Callable, Mapping, Sequence, Tuple, Union + +EncodeFunc = Callable[[str], Tuple[bytes, int]] +DecodeFunc = Callable[[bytes], Tuple[str, int]] + +Query = Union[str, bytes] +Params = Union[Sequence[Any], Mapping[str, Any]] diff --git a/psycopg3/waiting.py b/psycopg3/waiting.py index 7a6580f68..59e0e1f61 100644 --- a/psycopg3/waiting.py +++ b/psycopg3/waiting.py @@ -7,6 +7,7 @@ Code concerned with waiting in different contexts (blocking, async, etc). from enum import Enum from select import select +from typing import Generator, Tuple, TypeVar from asyncio import get_event_loop, Event from . import exceptions as exc @@ -15,8 +16,10 @@ from . import exceptions as exc Wait = Enum("Wait", "R W RW") Ready = Enum("Ready", "R W") +RV = TypeVar("RV") -def wait_select(gen): + +def wait_select(gen: Generator[Tuple[int, Wait], Ready, RV]) -> RV: """ Wait on the behalf of a generator using select(). @@ -48,10 +51,11 @@ def wait_select(gen): else: raise exc.InternalError("bad poll status: %s") except StopIteration as e: - return e.args[0] + rv: RV = e.args[0] + return rv -async def wait_async(gen): +async def wait_async(gen: Generator[Tuple[int, Wait], Ready, RV]) -> RV: """ Coroutine waiting for a generator to complete. @@ -65,9 +69,9 @@ async def wait_async(gen): # Not sure this is the best implementation but it's a start. ev = Event() loop = get_event_loop() - ready = None + ready = Ready.R - def wakeup(state): + def wakeup(state: Ready) -> None: nonlocal ready ready = state ev.set() @@ -96,4 +100,5 @@ async def wait_async(gen): else: raise exc.InternalError("bad poll status: %s") except StopIteration as e: - return e.args[0] + rv: RV = e.args[0] + return rv diff --git a/tests/test_query.py b/tests/test_query.py index 8aa6ec952..032301781 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -8,34 +8,22 @@ from psycopg3.utils.queries import split_query, query2pg, reorder_params @pytest.mark.parametrize( "input, want", [ - (b"", [[b"", None, None]]), - (b"foo bar", [[b"foo bar", None, None]]), - (b"foo %% bar", [[b"foo % bar", None, None]]), - (b"%s", [[b"", 0, False], [b"", None, None]]), - (b"%s foo", [[b"", 0, False], [b" foo", None, None]]), - (b"%b foo", [[b"", 0, True], [b" foo", None, None]]), - (b"foo %s", [[b"foo ", 0, False], [b"", None, None]]), - (b"foo %%%s bar", [[b"foo %", 0, False], [b" bar", None, None]]), - ( - b"foo %(name)s bar", - [[b"foo ", b"name", False], [b" bar", None, None]], - ), + (b"", [(b"", 0, 0)]), + (b"foo bar", [(b"foo bar", 0, 0)]), + (b"foo %% bar", [(b"foo % bar", 0, 0)]), + (b"%s", [(b"", 0, 0), (b"", 0, 0)]), + (b"%s foo", [(b"", 0, 0), (b" foo", 0, 0)]), + (b"%b foo", [(b"", 0, 1), (b" foo", 0, 0)]), + (b"foo %s", [(b"foo ", 0, 0), (b"", 0, 0)]), + (b"foo %%%s bar", [(b"foo %", 0, 0), (b" bar", 0, 0)]), + (b"foo %(name)s bar", [(b"foo ", "name", 0), (b" bar", 0, 0)]), ( b"foo %(name)s %(name)b bar", - [ - [b"foo ", b"name", False], - [b" ", b"name", True], - [b" bar", None, None], - ], + [(b"foo ", "name", 0), (b" ", "name", 1), (b" bar", 0, 0)], ), ( b"foo %s%b bar %s baz", - [ - [b"foo ", 0, False], - [b"", 1, True], - [b" bar ", 2, False], - [b" baz", None, None], - ], + [(b"foo ", 0, 0), (b"", 1, 1), (b" bar ", 2, 0), (b" baz", 0, 0)], ), ], ) diff --git a/tox.ini b/tox.ini index 2f9f668ab..0a1b28a35 100644 --- a/tox.ini +++ b/tox.ini @@ -24,7 +24,7 @@ exclude = env, .tox ignore = W503, E203 [mypy] -files = psycopg3, tests, setup.py +files = psycopg3, setup.py warn_unused_ignores = True [mypy-pytest]