# Copyright (C) 2020 The Psycopg Team
import codecs
-from typing import Dict, Tuple
+from typing import (
+ Any,
+ Callable,
+ cast,
+ Dict,
+ Generator,
+ List,
+ Optional,
+ Sequence,
+ Tuple,
+ Union,
+)
from functools import partial
from . import exceptions as exc
-from .pq import Format
+from .pq import Format, PGresult
from .cursor import BaseCursor
from .types.oids import type_oid, INVALID_OID
from .connection import BaseConnection
+from .utils.typing import DecodeFunc
+
+
+# Type system
+
+AdaptContext = Union[BaseConnection, BaseCursor]
+
+MaybeOid = Union[Optional[bytes], Tuple[Optional[bytes], int]]
+AdapterFunc = Callable[[Any], MaybeOid]
+AdapterType = Union["Adapter", AdapterFunc]
+AdaptersMap = Dict[Tuple[type, Format], AdapterType]
+
+TypecasterFunc = Callable[[bytes], Any]
+TypecasterType = Union["Typecaster", TypecasterFunc]
+TypecastersMap = Dict[Tuple[int, Format], TypecasterType]
class Adapter:
- globals: Dict[Tuple[type, Format], "Adapter"] = {} # TODO: incomplete type
+ globals: AdaptersMap = {}
- def __init__(self, cls, conn):
+ def __init__(self, cls: type, conn: BaseConnection):
self.cls = cls
self.conn = conn
- def adapt(self, obj):
+ def adapt(self, obj: Any) -> Union[bytes, Tuple[bytes, int]]:
raise NotImplementedError()
@staticmethod
- def register(cls, adapter=None, context=None, format=Format.TEXT):
+ def register(
+ cls: type,
+ adapter: Optional[AdapterType] = None,
+ context: Optional[AdaptContext] = None,
+ format: Format = Format.TEXT,
+ ) -> AdapterType:
if adapter is None:
# used as decorator
return partial(Adapter.register, cls, format=format)
return adapter
@staticmethod
- def register_binary(cls, adapter=None, context=None):
+ def register_binary(
+ cls: type,
+ adapter: Optional[AdapterType] = None,
+ context: Optional[AdaptContext] = None,
+ ) -> AdapterType:
return Adapter.register(cls, adapter, context, format=Format.BINARY)
class Typecaster:
- # TODO: incomplete type
- globals: Dict[Tuple[type, Format], "Typecaster"] = {}
+ globals: TypecastersMap = {}
- def __init__(self, oid, conn):
+ def __init__(self, oid: int, conn: Optional[BaseConnection]):
self.oid = oid
self.conn = conn
- def cast(self, data):
+ def cast(self, data: bytes) -> Any:
raise NotImplementedError()
@staticmethod
- def register(oid, caster=None, context=None, format=Format.TEXT):
+ def register(
+ oid: int,
+ caster: Optional[TypecasterType] = None,
+ context: Optional[AdaptContext] = None,
+ format: Format = Format.TEXT,
+ ) -> TypecasterType:
if caster is None:
# used as decorator
return partial(Typecaster.register, oid, format=format)
f" got {caster} instead"
)
- where = context.adapters if context is not None else Typecaster.globals
+ where = context.casters if context is not None else Typecaster.globals
where[oid, format] = caster
return caster
@staticmethod
- def register_binary(oid, caster=None, context=None):
+ def register_binary(
+ oid: int,
+ caster: Optional[TypecasterType] = None,
+ context: Optional[AdaptContext] = None,
+ ) -> TypecasterType:
return Typecaster.register(oid, caster, context, format=Format.BINARY)
state so adapting several values of the same type can use optimisations.
"""
- def __init__(self, context):
+ connection: Optional[BaseConnection]
+ cursor: Optional[BaseCursor]
+
+ def __init__(self, context: Optional[AdaptContext]):
if context is None:
self.connection = None
self.cursor = None
)
# mapping class, fmt -> adaptation function
- self._adapt_funcs = {}
+ self._adapt_funcs: Dict[Tuple[type, Format], AdapterFunc] = {}
# mapping oid, fmt -> cast function
- self._cast_funcs = {}
+ self._cast_funcs: Dict[Tuple[int, Format], TypecasterFunc] = {}
# The result to return values from
- self._result = None
+ self._result: Optional[PGresult] = None
# sequence of cast function from value to python
# the length of the result columns
- self._row_casters = None
+ self._row_casters: List[TypecasterFunc] = []
@property
- def result(self):
+ def result(self) -> Optional[PGresult]:
return self._result
@result.setter
- def result(self, result):
+ def result(self, result: PGresult) -> None:
if self._result is result:
return
func = self.get_cast_function(oid, fmt)
rc.append(func)
- def adapt_sequence(self, objs, fmts):
+ def adapt_sequence(
+ self, objs: Sequence[Any], fmts: Sequence[Format]
+ ) -> Tuple[List[Optional[bytes]], List[int]]:
out = []
types = []
return out, types
- def adapt(self, obj, fmt):
+ def adapt(self, obj: None, fmt: Format) -> MaybeOid:
if obj is None:
return None, type_oid["text"]
func = self.get_adapt_function(cls, fmt)
return func(obj)
- def get_adapt_function(self, cls, fmt):
+ def get_adapt_function(self, cls: type, fmt: Format) -> AdapterFunc:
try:
return self._adapt_funcs[cls, fmt]
except KeyError:
adapter = self.lookup_adapter(cls, fmt)
if isinstance(adapter, type):
- adapter = adapter(cls, self.connection).adapt
-
- return adapter
+ return adapter(cls, self.connection).adapt
+ else:
+ return cast(AdapterFunc, adapter)
- def lookup_adapter(self, cls, fmt):
+ def lookup_adapter(self, cls: type, fmt: Format) -> AdapterType:
key = (cls, fmt)
cur = self.cursor
f"cannot adapt type {cls.__name__} to format {Format(fmt).name}"
)
- def cast_row(self, result, n):
+ def cast_row(self, result: PGresult, n: int) -> Generator[Any, None, None]:
self.result = result
for col, func in enumerate(self._row_casters):
v = func(v)
yield v
- def get_cast_function(self, oid, fmt):
+ def get_cast_function(self, oid: int, fmt: Format) -> TypecasterFunc:
try:
return self._cast_funcs[oid, fmt]
except KeyError:
caster = self.lookup_caster(oid, fmt)
if isinstance(caster, type):
- caster = caster(oid, self.connection).cast
-
- return caster
+ return caster(oid, self.connection).cast
+ else:
+ return cast(TypecasterFunc, caster)
- def lookup_caster(self, oid, fmt):
+ def lookup_caster(self, oid: int, fmt: Format) -> TypecasterType:
key = (oid, fmt)
cur = self.cursor
Fallback object to convert unknown types to Python
"""
- def __init__(self, oid, conn):
+ def __init__(self, oid: int, conn: Optional[BaseConnection]):
super().__init__(oid, conn)
+ self.decode: DecodeFunc
if conn is not None:
self.decode = conn.codec.decode
else:
self.decode = codecs.lookup("utf8").decode
- def cast(self, data):
+ def cast(self, data: bytes) -> str:
return self.decode(data)[0]
@Typecaster.register_binary(INVALID_OID)
-def cast_unknown(data):
+def cast_unknown(data: bytes) -> bytes:
return data
import logging
import asyncio
import threading
+from typing import (
+ cast,
+ Any,
+ Generator,
+ List,
+ Optional,
+ Tuple,
+ Type,
+ TypeVar,
+ TYPE_CHECKING,
+)
from . import pq
from . import exceptions as exc
logger = logging.getLogger(__name__)
+ConnectGen = Generator[Tuple[int, Wait], Ready, pq.PGconn]
+QueryGen = Generator[Tuple[int, Wait], Ready, List[pq.PGresult]]
+RV = TypeVar("RV")
+
+if TYPE_CHECKING:
+ from .adaptation import AdaptersMap, TypecastersMap
+
class BaseConnection:
"""
allow different interfaces (sync/async).
"""
- def __init__(self, pgconn):
+ def __init__(self, pgconn: pq.PGconn):
self.pgconn = pgconn
- self.cursor_factory = None
- self.adapters = {}
- self.casters = {}
+ self.cursor_factory = cursor.BaseCursor
+ self.adapters: AdaptersMap = {}
+ self.casters: TypecastersMap = {}
# name of the postgres encoding (in bytes)
- self.pgenc = None
+ self.pgenc = b""
- def cursor(self, name=None, binary=False):
+ def cursor(
+ self, name: Optional[str] = None, binary: bool = False
+ ) -> cursor.BaseCursor:
if name is not None:
raise NotImplementedError
return self.cursor_factory(self, binary=binary)
@property
- def codec(self):
+ def codec(self) -> codecs.CodecInfo:
# TODO: utf8 fastpath?
pgenc = self.pgconn.parameter_status(b"client_encoding")
if self.pgenc != pgenc:
# for unknown encodings and SQL_ASCII be strict and use ascii
- pyenc = pq.py_codecs.get(pgenc.decode("ascii"), "ascii")
+ pyenc = pq.py_codecs.get(pgenc.decode("ascii")) or "ascii"
self._codec = codecs.lookup(pyenc)
self.pgenc = pgenc
return self._codec
- def encode(self, s):
+ def encode(self, s: str) -> bytes:
return self.codec.encode(s)[0]
- def decode(self, b):
+ def decode(self, b: bytes) -> str:
return self.codec.decode(b)[0]
@classmethod
- def _connect_gen(cls, conninfo):
+ def _connect_gen(cls, conninfo: str) -> ConnectGen:
"""
Generator to create a database connection without blocking.
generator can be restarted sending the appropriate `Ready` state when
the file descriptor is ready.
"""
- conninfo = conninfo.encode("utf8")
-
- conn = pq.PGconn.connect_start(conninfo)
+ conn = pq.PGconn.connect_start(conninfo.encode("utf8"))
logger.debug("connection started, status %s", conn.status.name)
while 1:
if conn.status == pq.ConnStatus.BAD:
return conn
@classmethod
- def _exec_gen(cls, pgconn):
+ def _exec_gen(cls, pgconn: pq.PGconn) -> QueryGen:
"""
Generator returning query results without blocking.
Return the list of results returned by the database (whether success
or error).
"""
- results = []
+ results: List[pq.PGresult] = []
while 1:
f = pgconn.flush()
This class implements a DBAPI-compliant interface.
"""
- def __init__(self, pgconn):
+ cursor_factory: Type[cursor.Cursor]
+
+ def __init__(self, pgconn: pq.PGconn):
super().__init__(pgconn)
self.lock = threading.Lock()
self.cursor_factory = cursor.Cursor
@classmethod
- def connect(cls, conninfo, connection_factory=None, **kwargs):
+ def connect(
+ cls, conninfo: str, connection_factory: Any = None, **kwargs: Any
+ ) -> "Connection":
if connection_factory is not None:
raise NotImplementedError()
conninfo = make_conninfo(conninfo, **kwargs)
pgconn = cls.wait(gen)
return cls(pgconn)
- def commit(self):
+ def cursor(
+ self, name: Optional[str] = None, binary: bool = False
+ ) -> cursor.Cursor:
+ cur = super().cursor(name, binary)
+ return cast(cursor.Cursor, cur)
+
+ def commit(self) -> None:
self._exec_commit_rollback(b"commit")
- def rollback(self):
+ def rollback(self) -> None:
self._exec_commit_rollback(b"rollback")
- def _exec_commit_rollback(self, command):
+ def _exec_commit_rollback(self, command: bytes) -> None:
with self.lock:
status = self.pgconn.transaction_status
if status == pq.TransactionStatus.IDLE:
)
@classmethod
- def wait(cls, gen):
+ def wait(cls, gen: Generator[Tuple[int, Wait], Ready, RV]) -> RV:
return wait_select(gen)
methods implemented as coroutines.
"""
- def __init__(self, pgconn):
+ cursor_factory: Type[cursor.AsyncCursor]
+
+ def __init__(self, pgconn: pq.PGconn):
super().__init__(pgconn)
self.lock = asyncio.Lock()
self.cursor_factory = cursor.AsyncCursor
@classmethod
- async def connect(cls, conninfo, **kwargs):
+ async def connect(cls, conninfo: str, **kwargs: Any) -> "AsyncConnection":
conninfo = make_conninfo(conninfo, **kwargs)
gen = cls._connect_gen(conninfo)
pgconn = await cls.wait(gen)
return cls(pgconn)
- async def commit(self):
+ def cursor(
+ self, name: Optional[str] = None, binary: bool = False
+ ) -> cursor.AsyncCursor:
+ cur = super().cursor(name, binary)
+ return cast(cursor.AsyncCursor, cur)
+
+ async def commit(self) -> None:
await self._exec_commit_rollback(b"commit")
- async def rollback(self):
+ async def rollback(self) -> None:
await self._exec_commit_rollback(b"rollback")
- async def _exec_commit_rollback(self, command):
+ async def _exec_commit_rollback(self, command: bytes) -> None:
async with self.lock:
status = self.pgconn.transaction_status
if status == pq.TransactionStatus.IDLE:
)
@classmethod
- async def wait(cls, gen):
+ async def wait(cls, gen: Generator[Tuple[int, Wait], Ready, RV]) -> RV:
return await wait_async(gen)
import re
+from typing import Any, Dict, List
from . import pq
from . import exceptions as exc
-def make_conninfo(conninfo=None, **kwargs):
+def make_conninfo(conninfo: str = "", **kwargs: Any) -> str:
"""
Merge a string and keyword params into a single conninfo string.
Raise ProgrammingError if the input don't make a valid conninfo.
"""
- if conninfo is None and not kwargs:
+ if not conninfo and not kwargs:
return ""
# If no kwarg is specified don't mung the conninfo but check if it's correct
return conninfo
-def conninfo_to_dict(conninfo):
+def conninfo_to_dict(conninfo: str) -> Dict[str, str]:
"""
Convert the *conninfo* string into a dictionary of parameters.
}
-def _parse_conninfo(conninfo):
+def _parse_conninfo(conninfo: str) -> List[pq.ConninfoOption]:
"""
Verify that *conninfo* is a valid connection string.
raise exc.ProgrammingError(str(e))
-def _param_escape(
- s, re_escape=re.compile(r"([\\'])"), re_space=re.compile(r"\s")
-):
+re_escape = re.compile(r"([\\'])")
+re_space = re.compile(r"\s")
+
+
+def _param_escape(s: str) -> str:
"""
Apply the escaping rule required by PQconnectdb
"""
# Copyright (C) 2020 The Psycopg Team
+from typing import Any, List, Mapping, Optional, Sequence, Tuple, TYPE_CHECKING
+
from . import exceptions as exc
-from .pq import error_message, DiagnosticField, ExecStatus
+from .pq import error_message, DiagnosticField, ExecStatus, PGresult, Format
from .utils.queries import query2pg, reorder_params
+from .utils.typing import Query, Params
+
+if TYPE_CHECKING:
+ from .connection import (
+ BaseConnection,
+ Connection,
+ AsyncConnection,
+ QueryGen,
+ )
+ from .adaptation import AdaptersMap, TypecastersMap
class BaseCursor:
- def __init__(self, conn, binary=False):
+ def __init__(self, conn: "BaseConnection", binary: bool = False):
self.conn = conn
self.binary = binary
- self.adapters = {}
- self.casters = {}
+ self.adapters: AdaptersMap = {}
+ self.casters: TypecastersMap = {}
self._reset()
- def _reset(self):
+ def _reset(self) -> None:
from .adaptation import Transformer
- self._results = []
- self._result = None
+ self._results: List[PGresult] = []
+ self._result: Optional[PGresult] = None
self._pos = 0
self._iresult = 0
self._transformer = Transformer(self)
- def _execute_send(self, query, vars):
+ def _execute_send(
+ self, query: Query, vars: Optional[Params]
+ ) -> "QueryGen":
# Implement part of execute() before waiting common to sync and async
self._reset()
query, formats, order = query2pg(query, vars, codec)
if vars:
if order is not None:
+ assert isinstance(vars, Mapping)
vars = reorder_params(vars, order)
+ assert isinstance(vars, Sequence)
params, types = self._transformer.adapt_sequence(vars, formats)
self.conn.pgconn.send_query_params(
query,
params,
param_formats=formats,
param_types=types,
- result_format=int(self.binary),
+ result_format=Format(self.binary),
)
else:
self.conn.pgconn.send_query(query)
return self.conn._exec_gen(self.conn.pgconn)
- def _execute_results(self, results):
+ def _execute_results(self, results: List[PGresult]) -> None:
# Implement part of execute() after waiting common to sync and async
if not results:
raise exc.InternalError("got no result from the query")
f" {', '.join(sorted(s.name for s in sorted(badstats)))}"
)
- def nextset(self):
+ def nextset(self) -> Optional[bool]:
self._iresult += 1
if self._iresult < len(self._results):
self._result = self._results[self._iresult]
self._pos = 0
return True
+ else:
+ return None
- def fetchone(self):
+ def fetchone(self) -> Optional[Sequence[Any]]:
rv = self._cast_row(self._pos)
if rv is not None:
self._pos += 1
return rv
- def _cast_row(self, n):
+ def _cast_row(self, n: int) -> Optional[Tuple[Any, ...]]:
if self._result is None:
return None
if n >= self._result.ntuples:
class Cursor(BaseCursor):
- def execute(self, query, vars=None):
+ conn: "Connection"
+
+ def __init__(self, conn: "Connection", binary: bool = False):
+ super().__init__(conn, binary)
+
+ def execute(self, query: Query, vars: Optional[Params] = None) -> "Cursor":
with self.conn.lock:
gen = self._execute_send(query, vars)
results = self.conn.wait(gen)
class AsyncCursor(BaseCursor):
- async def execute(self, query, vars=None):
+ conn: "AsyncConnection"
+
+ def __init__(self, conn: "AsyncConnection", binary: bool = False):
+ super().__init__(conn, binary)
+
+ async def execute(
+ self, query: Query, vars: Optional[Params] = None
+ ) -> "AsyncCursor":
async with self.conn.lock:
gen = self._execute_send(query, vars)
results = await self.conn.wait(gen)
# Copyright (C) 2020 The Psycopg Team
+from typing import Any, Optional, Sequence, TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from psycopg3.pq import PGresult # noqa
+
class Warning(Exception):
"""
Base exception for all the errors psycopg3 will raise.
"""
- def __init__(self, *args, pgresult=None):
+ def __init__(
+ self, *args: Sequence[Any], pgresult: Optional["PGresult"] = None
+ ):
super().__init__(*args)
self.pgresult = pgresult
"""
-def class_for_state(sqlstate):
+def class_for_state(sqlstate: bytes) -> type:
# TODO: stub
return DatabaseError
Format,
)
from .encodings import py_codecs
-from .misc import error_message
+from .misc import error_message, ConninfoOption
from . import pq_ctypes as pq_module
"Conninfo",
"PQerror",
"error_message",
+ "ConninfoOption",
"py_codecs",
"version",
)
_PQhostaddr.restype = c_char_p
-def PQhostaddr(pgconn):
+def PQhostaddr(pgconn: type) -> bytes:
if _PQhostaddr is not None:
- return _PQhostaddr(pgconn)
+ return _PQhostaddr(pgconn) # type: ignore
else:
raise NotSupportedError(
f"PQhostaddr requires libpq from PostgreSQL 12,"
# Copyright (C) 2020 The Psycopg Team
+from collections import namedtuple
+from typing import TYPE_CHECKING, Union
-def error_message(obj):
+if TYPE_CHECKING:
+ from psycopg3.pq import PGconn, PGresult # noqa
+
+
+ConninfoOption = namedtuple(
+ "ConninfoOption", "keyword envvar compiled val label dispatcher dispsize"
+)
+
+
+def error_message(obj: Union["PGconn", "PGresult"]) -> str:
"""
Return an error message from a PGconn or PGresult.
"""
from psycopg3 import pq
+ bmsg: bytes
+
if isinstance(obj, pq.PGconn):
- msg = obj.error_message
+ bmsg = obj.error_message
# strip severity and whitespaces
- if msg:
- msg = msg.splitlines()[0].split(b":", 1)[-1].strip()
+ if bmsg:
+ bmsg = bmsg.splitlines()[0].split(b":", 1)[-1].strip()
elif isinstance(obj, pq.PGresult):
- msg = obj.error_field(pq.DiagnosticField.MESSAGE_PRIMARY)
- if not msg:
- msg = obj.error_message
+ bmsg = obj.error_field(pq.DiagnosticField.MESSAGE_PRIMARY)
+ if not bmsg:
+ bmsg = obj.error_message
# strip severity and whitespaces
- if msg:
- msg = msg.splitlines()[0].split(b":", 1)[-1].strip()
+ if bmsg:
+ bmsg = bmsg.splitlines()[0].split(b":", 1)[-1].strip()
else:
raise TypeError(
f"PGconn or PGresult expected, got {type(obj).__name__}"
)
- if msg:
- msg = msg.decode("utf8", "replace") # TODO: or in connection encoding?
+ if bmsg:
+ msg = bmsg.decode(
+ "utf8", "replace"
+ ) # TODO: or in connection encoding?
else:
msg = "no details available"
# Copyright (C) 2020 The Psycopg Team
-from collections import namedtuple
from ctypes import string_at
from ctypes import c_char_p, c_int, pointer
+from typing import Any, List, Optional, Sequence
from .enums import (
ConnStatus,
ExecStatus,
TransactionStatus,
Ping,
+ DiagnosticField,
+ Format,
)
-from .misc import error_message
+from .misc import error_message, ConninfoOption
from . import _pq_ctypes as impl
from ..exceptions import OperationalError
-def version():
- return impl.PQlibVersion()
+def version() -> int:
+ return impl.PQlibVersion() # type: ignore
class PQerror(OperationalError):
class PGconn:
__slots__ = ("pgconn_ptr",)
- def __init__(self, pgconn_ptr):
- self.pgconn_ptr = pgconn_ptr
+ def __init__(self, pgconn_ptr: impl.PGconn_struct):
+ self.pgconn_ptr: Optional[impl.PGconn_struct] = pgconn_ptr
- def __del__(self):
+ def __del__(self) -> None:
self.finish()
@classmethod
- def connect(cls, conninfo):
+ def connect(cls, conninfo: bytes) -> "PGconn":
if not isinstance(conninfo, bytes):
- raise TypeError(
- "bytes expected, got %s instead" % type(conninfo).__name__
- )
+ raise TypeError(f"bytes expected, got {type(conninfo)} instead")
pgconn_ptr = impl.PQconnectdb(conninfo)
if not pgconn_ptr:
return cls(pgconn_ptr)
@classmethod
- def connect_start(cls, conninfo):
+ def connect_start(cls, conninfo: bytes) -> "PGconn":
if not isinstance(conninfo, bytes):
- raise TypeError(
- "bytes expected, got %s instead" % type(conninfo).__name__
- )
+ raise TypeError(f"bytes expected, got {type(conninfo)} instead")
pgconn_ptr = impl.PQconnectStart(conninfo)
if not pgconn_ptr:
raise MemoryError("couldn't allocate PGconn")
return cls(pgconn_ptr)
- def connect_poll(self):
+ def connect_poll(self) -> PollingStatus:
rv = impl.PQconnectPoll(self.pgconn_ptr)
return PollingStatus(rv)
- def finish(self):
+ def finish(self) -> None:
self.pgconn_ptr, p = None, self.pgconn_ptr
if p is not None:
impl.PQfinish(p)
@property
- def info(self):
+ def info(self) -> List["ConninfoOption"]:
opts = impl.PQconninfo(self.pgconn_ptr)
if not opts:
raise MemoryError("couldn't allocate connection info")
finally:
impl.PQconninfoFree(opts)
- def reset(self):
+ def reset(self) -> None:
impl.PQreset(self.pgconn_ptr)
- def reset_start(self):
+ def reset_start(self) -> None:
if not impl.PQresetStart(self.pgconn_ptr):
raise PQerror("couldn't reset connection")
- def reset_poll(self):
+ def reset_poll(self) -> PollingStatus:
rv = impl.PQresetPoll(self.pgconn_ptr)
return PollingStatus(rv)
@classmethod
- def ping(self, conninfo):
+ def ping(self, conninfo: bytes) -> Ping:
if not isinstance(conninfo, bytes):
- raise TypeError(
- "bytes expected, got %s instead" % type(conninfo).__name__
- )
+ raise TypeError(f"bytes expected, got {type(conninfo)} instead")
rv = impl.PQping(conninfo)
return Ping(rv)
@property
- def db(self):
- return impl.PQdb(self.pgconn_ptr)
+ def db(self) -> bytes:
+ return impl.PQdb(self.pgconn_ptr) # type: ignore
@property
- def user(self):
- return impl.PQuser(self.pgconn_ptr)
+ def user(self) -> bytes:
+ return impl.PQuser(self.pgconn_ptr) # type: ignore
@property
- def password(self):
- return impl.PQpass(self.pgconn_ptr)
+ def password(self) -> bytes:
+ return impl.PQpass(self.pgconn_ptr) # type: ignore
@property
- def host(self):
- return impl.PQhost(self.pgconn_ptr)
+ def host(self) -> bytes:
+ return impl.PQhost(self.pgconn_ptr) # type: ignore
@property
- def hostaddr(self):
- return impl.PQhostaddr(self.pgconn_ptr)
+ def hostaddr(self) -> bytes:
+ return impl.PQhostaddr(self.pgconn_ptr) # type: ignore
@property
- def port(self):
- return impl.PQport(self.pgconn_ptr)
+ def port(self) -> bytes:
+ return impl.PQport(self.pgconn_ptr) # type: ignore
@property
- def tty(self):
- return impl.PQtty(self.pgconn_ptr)
+ def tty(self) -> bytes:
+ return impl.PQtty(self.pgconn_ptr) # type: ignore
@property
- def options(self):
- return impl.PQoptions(self.pgconn_ptr)
+ def options(self) -> bytes:
+ return impl.PQoptions(self.pgconn_ptr) # type: ignore
@property
- def status(self):
+ def status(self) -> ConnStatus:
rv = impl.PQstatus(self.pgconn_ptr)
return ConnStatus(rv)
@property
- def transaction_status(self):
+ def transaction_status(self) -> TransactionStatus:
rv = impl.PQtransactionStatus(self.pgconn_ptr)
return TransactionStatus(rv)
- def parameter_status(self, name):
- return impl.PQparameterStatus(self.pgconn_ptr, name)
+ def parameter_status(self, name: bytes) -> bytes:
+ return impl.PQparameterStatus(self.pgconn_ptr, name) # type: ignore
@property
- def protocol_version(self):
- return impl.PQprotocolVersion(self.pgconn_ptr)
+ def protocol_version(self) -> int:
+ return impl.PQprotocolVersion(self.pgconn_ptr) # type: ignore
@property
- def server_version(self):
- return impl.PQserverVersion(self.pgconn_ptr)
+ def server_version(self) -> int:
+ return impl.PQserverVersion(self.pgconn_ptr) # type: ignore
@property
- def error_message(self):
- return impl.PQerrorMessage(self.pgconn_ptr)
+ def error_message(self) -> bytes:
+ return impl.PQerrorMessage(self.pgconn_ptr) # type: ignore
@property
- def socket(self):
- return impl.PQsocket(self.pgconn_ptr)
+ def socket(self) -> int:
+ return impl.PQsocket(self.pgconn_ptr) # type: ignore
@property
- def backend_pid(self):
- return impl.PQbackendPID(self.pgconn_ptr)
+ def backend_pid(self) -> int:
+ return impl.PQbackendPID(self.pgconn_ptr) # type: ignore
@property
- def needs_password(self):
+ def needs_password(self) -> bool:
return bool(impl.PQconnectionNeedsPassword(self.pgconn_ptr))
@property
- def used_password(self):
+ def used_password(self) -> bool:
return bool(impl.PQconnectionUsedPassword(self.pgconn_ptr))
@property
- def ssl_in_use(self):
+ def ssl_in_use(self) -> bool:
return bool(impl.PQsslInUse(self.pgconn_ptr))
- def exec_(self, command):
+ def exec_(self, command: bytes) -> "PGresult":
if not isinstance(command, bytes):
- raise TypeError(
- "bytes expected, got %s instead" % type(command).__name__
- )
+ raise TypeError(f"bytes expected, got {type(command)} instead")
rv = impl.PQexec(self.pgconn_ptr, command)
if not rv:
raise MemoryError("couldn't allocate PGresult")
return PGresult(rv)
- def send_query(self, command):
+ def send_query(self, command: bytes) -> None:
if not isinstance(command, bytes):
- raise TypeError(
- "bytes expected, got %s instead" % type(command).__name__
- )
+ raise TypeError(f"bytes expected, got {type(command)} instead")
if not impl.PQsendQuery(self.pgconn_ptr, command):
raise PQerror(f"sending query failed: {error_message(self)}")
def exec_params(
self,
- command,
- param_values,
- param_types=None,
- param_formats=None,
- result_format=0,
- ):
+ command: bytes,
+ param_values: List[Optional[bytes]],
+ param_types: Optional[List[int]] = None,
+ param_formats: Optional[List[Format]] = None,
+ result_format: Format = Format.TEXT,
+ ) -> "PGresult":
args = self._query_params_args(
command, param_values, param_types, param_formats, result_format
)
def send_query_params(
self,
- command,
- param_values,
- param_types=None,
- param_formats=None,
- result_format=0,
- ):
+ command: bytes,
+ param_values: List[Optional[bytes]],
+ param_types: Optional[List[int]] = None,
+ param_formats: Optional[List[Format]] = None,
+ result_format: Format = Format.TEXT,
+ ) -> None:
args = self._query_params_args(
command, param_values, param_types, param_formats, result_format
)
)
def _query_params_args(
- self, command, param_values, param_types, param_formats, result_format,
- ):
+ self,
+ command: bytes,
+ param_values: List[Optional[bytes]],
+ param_types: Optional[List[int]] = None,
+ param_formats: Optional[List[Format]] = None,
+ result_format: Format = Format.TEXT,
+ ) -> Any:
if not isinstance(command, bytes):
- raise TypeError(
- "bytes expected, got %s instead" % type(command).__name__
- )
+ raise TypeError(f"bytes expected, got {type(command)} instead")
nparams = len(param_values)
if nparams:
*(len(p) if p is not None else 0 for p in param_values)
)
else:
- aparams = alenghts = None
+ aparams = alenghts = None # type: ignore
if param_types is None:
atypes = None
result_format,
)
- def prepare(self, name, command, param_types=None):
+ def prepare(
+ self,
+ name: bytes,
+ command: bytes,
+ param_types: Optional[List[int]] = None,
+ ) -> "PGresult":
if not isinstance(name, bytes):
- raise TypeError(
- "'name' must be bytes, got %s instead" % type(name).__name__
- )
+ raise TypeError(f"'name' must be bytes, got {type(name)} instead")
if not isinstance(command, bytes):
raise TypeError(
- "'command' must be bytes, got %s instead"
- % type(command).__name__
+ f"'command' must be bytes, got {type(command)} instead"
)
if param_types is None:
return PGresult(rv)
def exec_prepared(
- self, name, param_values, param_formats=None, result_format=0
- ):
+ self,
+ name: bytes,
+ param_values: List[bytes],
+ param_formats: Optional[List[int]] = None,
+ result_format: int = 0,
+ ) -> "PGresult":
if not isinstance(name, bytes):
- raise TypeError(
- "'name' must be bytes, got %s instead" % type(name).__name__
- )
+ raise TypeError(f"'name' must be bytes, got {type(name)} instead")
nparams = len(param_values)
if nparams:
*(len(p) if p is not None else 0 for p in param_values)
)
else:
- aparams = alenghts = None
+ aparams = alenghts = None # type: ignore
if param_formats is None:
aformats = None
raise MemoryError("couldn't allocate PGresult")
return PGresult(rv)
- def describe_prepared(self, name):
+ def describe_prepared(self, name: bytes) -> "PGresult":
if not isinstance(name, bytes):
- raise TypeError(
- "'name' must be bytes, got %s instead" % type(name).__name__
- )
+ raise TypeError(f"'name' must be bytes, got {type(name)} instead")
rv = impl.PQdescribePrepared(self.pgconn_ptr, name)
if not rv:
raise MemoryError("couldn't allocate PGresult")
return PGresult(rv)
- def describe_portal(self, name):
+ def describe_portal(self, name: bytes) -> "PGresult":
if not isinstance(name, bytes):
- raise TypeError(
- "'name' must be bytes, got %s instead" % type(name).__name__
- )
+ raise TypeError(f"'name' must be bytes, got {type(name)} instead")
rv = impl.PQdescribePortal(self.pgconn_ptr, name)
if not rv:
raise MemoryError("couldn't allocate PGresult")
return PGresult(rv)
- def get_result(self):
+ def get_result(self) -> Optional["PGresult"]:
rv = impl.PQgetResult(self.pgconn_ptr)
return PGresult(rv) if rv else None
- def consume_input(self):
+ def consume_input(self) -> None:
if 1 != impl.PQconsumeInput(self.pgconn_ptr):
raise PQerror(f"consuming input failed: {error_message(self)}")
- def is_busy(self):
- return impl.PQisBusy(self.pgconn_ptr)
+ def is_busy(self) -> int:
+ return impl.PQisBusy(self.pgconn_ptr) # type: ignore
@property
- def nonblocking(self):
- return impl.PQisnonblocking(self.pgconn_ptr)
+ def nonblocking(self) -> int:
+ return impl.PQisnonblocking(self.pgconn_ptr) # type: ignore
@nonblocking.setter
- def nonblocking(self, arg):
+ def nonblocking(self, arg: int) -> None:
if 0 > impl.PQsetnonblocking(self.pgconn_ptr, arg):
raise PQerror(f"setting nonblocking failed: {error_message(self)}")
- def flush(self):
- rv = impl.PQflush(self.pgconn_ptr)
+ def flush(self) -> int:
+ rv: int = impl.PQflush(self.pgconn_ptr)
if rv < 0:
raise PQerror(f"flushing failed: {error_message(self)}")
return rv
- def make_empty_result(self, exec_status):
+ def make_empty_result(self, exec_status: ExecStatus) -> "PGresult":
rv = impl.PQmakeEmptyPGresult(self.pgconn_ptr, exec_status)
if not rv:
raise MemoryError("couldn't allocate empty PGresult")
class PGresult:
__slots__ = ("pgresult_ptr",)
- def __init__(self, pgresult_ptr):
- self.pgresult_ptr = pgresult_ptr
+ def __init__(self, pgresult_ptr: type):
+ self.pgresult_ptr: Optional[type] = pgresult_ptr
- def __del__(self):
+ def __del__(self) -> None:
self.clear()
- def clear(self):
+ def clear(self) -> None:
self.pgresult_ptr, p = None, self.pgresult_ptr
if p is not None:
impl.PQclear(p)
@property
- def status(self):
+ def status(self) -> ExecStatus:
rv = impl.PQresultStatus(self.pgresult_ptr)
return ExecStatus(rv)
@property
- def error_message(self):
- return impl.PQresultErrorMessage(self.pgresult_ptr)
+ def error_message(self) -> bytes:
+ return impl.PQresultErrorMessage(self.pgresult_ptr) # type: ignore
- def error_field(self, fieldcode):
- return impl.PQresultErrorField(self.pgresult_ptr, fieldcode)
+ def error_field(self, fieldcode: DiagnosticField) -> bytes:
+ return impl.PQresultErrorField( # type: ignore
+ self.pgresult_ptr, fieldcode
+ )
@property
- def ntuples(self):
- return impl.PQntuples(self.pgresult_ptr)
+ def ntuples(self) -> int:
+ return impl.PQntuples(self.pgresult_ptr) # type: ignore
@property
- def nfields(self):
- return impl.PQnfields(self.pgresult_ptr)
+ def nfields(self) -> int:
+ return impl.PQnfields(self.pgresult_ptr) # type: ignore
- def fname(self, column_number):
- return impl.PQfname(self.pgresult_ptr, column_number)
+ def fname(self, column_number: int) -> int:
+ return impl.PQfname(self.pgresult_ptr, column_number) # type: ignore
- def ftable(self, column_number):
- return impl.PQftable(self.pgresult_ptr, column_number)
+ def ftable(self, column_number: int) -> int:
+ return impl.PQftable(self.pgresult_ptr, column_number) # type: ignore
- def ftablecol(self, column_number):
- return impl.PQftablecol(self.pgresult_ptr, column_number)
+ def ftablecol(self, column_number: int) -> int:
+ return impl.PQftablecol( # type: ignore
+ self.pgresult_ptr, column_number
+ )
- def fformat(self, column_number):
- return impl.PQfformat(self.pgresult_ptr, column_number)
+ def fformat(self, column_number: int) -> Format:
+ return impl.PQfformat(self.pgresult_ptr, column_number) # type: ignore
- def ftype(self, column_number):
- return impl.PQftype(self.pgresult_ptr, column_number)
+ def ftype(self, column_number: int) -> int:
+ return impl.PQftype(self.pgresult_ptr, column_number) # type: ignore
- def fmod(self, column_number):
- return impl.PQfmod(self.pgresult_ptr, column_number)
+ def fmod(self, column_number: int) -> int:
+ return impl.PQfmod(self.pgresult_ptr, column_number) # type: ignore
- def fsize(self, column_number):
- return impl.PQfsize(self.pgresult_ptr, column_number)
+ def fsize(self, column_number: int) -> int:
+ return impl.PQfsize(self.pgresult_ptr, column_number) # type: ignore
@property
- def binary_tuples(self):
- return impl.PQbinaryTuples(self.pgresult_ptr)
-
- def get_value(self, row_number, column_number):
- length = impl.PQgetlength(self.pgresult_ptr, row_number, column_number)
+ def binary_tuples(self) -> int:
+ return impl.PQbinaryTuples(self.pgresult_ptr) # type: ignore
+
+ def get_value(
+ self, row_number: int, column_number: int
+ ) -> Optional[bytes]:
+ length: int = impl.PQgetlength(
+ self.pgresult_ptr, row_number, column_number
+ )
if length:
v = impl.PQgetvalue(self.pgresult_ptr, row_number, column_number)
return string_at(v, length)
return b""
@property
- def nparams(self):
- return impl.PQnparams(self.pgresult_ptr)
+ def nparams(self) -> int:
+ return impl.PQnparams(self.pgresult_ptr) # type: ignore
- def param_type(self, param_number):
- return impl.PQparamtype(self.pgresult_ptr, param_number)
+ def param_type(self, param_number: int) -> int:
+ return impl.PQparamtype( # type: ignore
+ self.pgresult_ptr, param_number
+ )
@property
- def command_status(self):
- return impl.PQcmdStatus(self.pgresult_ptr)
+ def command_status(self) -> bytes:
+ return impl.PQcmdStatus(self.pgresult_ptr) # type: ignore
@property
- def command_tuples(self):
+ def command_tuples(self) -> Optional[int]:
rv = impl.PQcmdTuples(self.pgresult_ptr)
- if rv:
- return int(rv)
+ return int(rv) if rv else None
@property
- def oid_value(self):
- return impl.PQoidValue(self.pgresult_ptr)
-
-
-ConninfoOption = namedtuple(
- "ConninfoOption", "keyword envvar compiled val label dispatcher dispsize"
-)
+ def oid_value(self) -> int:
+ return impl.PQoidValue(self.pgresult_ptr) # type: ignore
class Conninfo:
@classmethod
- def get_defaults(cls):
+ def get_defaults(cls) -> List[ConninfoOption]:
opts = impl.PQconndefaults()
if not opts:
raise MemoryError("couldn't allocate connection defaults")
impl.PQconninfoFree(opts)
@classmethod
- def parse(cls, conninfo):
+ def parse(cls, conninfo: bytes) -> List[ConninfoOption]:
if not isinstance(conninfo, bytes):
- raise TypeError(
- "bytes expected, got %s instead" % type(conninfo).__name__
- )
+ raise TypeError(f"bytes expected, got {type(conninfo)} instead")
errmsg = c_char_p()
rv = impl.PQconninfoParse(conninfo, pointer(errmsg))
if not errmsg:
raise MemoryError("couldn't allocate on conninfo parse")
else:
- exc = PQerror(errmsg.value.decode("utf8", "replace"))
+ exc = PQerror((errmsg.value or b"").decode("utf8", "replace"))
impl.PQfreemem(errmsg)
raise exc
impl.PQconninfoFree(rv)
@classmethod
- def _options_from_array(cls, opts):
+ def _options_from_array(
+ cls, opts: Sequence[impl.PQconninfoOption_struct]
+ ) -> List[ConninfoOption]:
rv = []
skws = "keyword envvar compiled val label dispatcher".split()
for opt in opts:
# Copyright (C) 2020 The Psycopg Team
import codecs
+from typing import Tuple
from ..adaptation import Adapter, Typecaster
from .oids import type_oid
@Adapter.register(int)
-def adapt_int(obj):
+def adapt_int(obj: int) -> Tuple[bytes, int]:
return _encode(str(obj))[0], type_oid["numeric"]
@Typecaster.register(type_oid["numeric"])
-def cast_int(data):
+def cast_int(data: bytes) -> int:
return int(_decode(data)[0])
type_oid = {name: oid for name, oid, _, _ in _oids_table}
-def self_update():
+def self_update() -> None:
import subprocess as sp
# queries output should make black happy
# Copyright (C) 2020 The Psycopg Team
import codecs
+from typing import Optional, Union
-from ..adaptation import Adapter
-from ..adaptation import Typecaster
+from ..adaptation import Adapter, Typecaster
+from ..connection import BaseConnection
+from ..utils.typing import EncodeFunc, DecodeFunc
from .oids import type_oid
@Adapter.register(str)
@Adapter.register_binary(str)
class StringAdapter(Adapter):
- def __init__(self, cls, conn):
+ def __init__(self, cls: type, conn: BaseConnection):
super().__init__(cls, conn)
- self.encode = (
- conn.codec if conn is not None else codecs.lookup("utf8")
- ).encode
- def adapt(self, obj):
- return self.encode(obj)[0]
+ self._encode: EncodeFunc
+ if conn is not None:
+ self._encode = conn.codec.encode
+ else:
+ self._encode = codecs.lookup("utf8").encode
+
+ def adapt(self, obj: str) -> bytes:
+ return self._encode(obj)[0]
@Typecaster.register(type_oid["text"])
@Typecaster.register_binary(type_oid["text"])
class StringCaster(Typecaster):
- def __init__(self, oid, conn):
+
+ decode: Optional[DecodeFunc]
+
+ def __init__(self, oid: int, conn: BaseConnection):
super().__init__(oid, conn)
+
if conn is not None:
if conn.pgenc != b"SQL_ASCII":
self.decode = conn.codec.decode
else:
self.decode = codecs.lookup("utf8").decode
- def cast(self, data):
+ def cast(self, data: bytes) -> Union[bytes, str]:
if self.decode is not None:
return self.decode(data)[0]
else:
# Copyright (C) 2020 The Psycopg Team
import re
-from collections.abc import Sequence, Mapping
+from codecs import CodecInfo
+from typing import (
+ Any,
+ Dict,
+ List,
+ Mapping,
+ Match,
+ NamedTuple,
+ Optional,
+ Sequence,
+ Tuple,
+ Union,
+)
from .. import exceptions as exc
from ..pq import Format
+from .typing import Params
-def query2pg(query, vars, codec):
+def query2pg(
+ query: bytes, vars: Params, codec: CodecInfo
+) -> Tuple[bytes, List[Format], Optional[List[str]]]:
"""
Convert Python query and params into something Postgres understands.
)
parts = split_query(query, codec.name)
+ order: Optional[List[str]] = None
+ chunks: List[bytes] = []
+ formats = []
if isinstance(vars, Sequence) and not isinstance(vars, (bytes, str)):
if len(vars) != len(parts) - 1:
f"the query has {len(parts) - 1} placeholders but"
f" {len(vars)} parameters were passed"
)
- if vars and not isinstance(parts[0][1], int):
+ if vars and not isinstance(parts[0].index, int):
raise TypeError(
"named placeholders require a mapping of parameters"
)
- order = None
+
+ for part in parts[:-1]:
+ assert isinstance(part.index, int)
+ chunks.append(part.pre)
+ chunks.append(b"$%d" % (part.index + 1))
+ formats.append(part.format)
elif isinstance(vars, Mapping):
- if vars and len(parts) > 1 and not isinstance(parts[0][1], bytes):
+ if vars and len(parts) > 1 and not isinstance(parts[0][1], str):
raise TypeError(
"positional placeholders (%s) require a sequence of parameters"
)
- seen = {}
+ seen: Dict[str, Tuple[bytes, Format]] = {}
order = []
for part in parts[:-1]:
- name = codec.decode(part[1])[0]
- if name not in seen:
- n = len(seen)
- part[1] = n
- seen[name] = (n, part[2])
- order.append(name)
+ assert isinstance(part.index, str)
+ formats.append(part.format)
+ chunks.append(part.pre)
+ if part.index not in seen:
+ ph = b"$%d" % (len(seen) + 1)
+ seen[part.index] = (ph, part.format)
+ order.append(part.index)
+ chunks.append(ph)
else:
- if seen[name][1] != part[2]:
+ if seen[part.index][1] != part.format:
raise exc.ProgrammingError(
- f"placeholder '{name}' cannot have different formats"
+ f"placeholder '{part.index}' cannot have"
+ f" different formats"
)
- part[1] = seen[name][0]
+ chunks.append(seen[part.index][0])
else:
raise TypeError(
f" got {type(vars).__name__}"
)
- # Assemble query and parameters
- rv = []
- formats = []
- for part in parts[:-1]:
- rv.append(part[0])
- rv.append(b"$%d" % (part[1] + 1))
- formats.append(part[2])
- rv.append(parts[-1][0])
+ # last part
+ chunks.append(parts[-1].pre)
- return b"".join(rv), formats, order
+ return b"".join(chunks), formats, order
_re_placeholder = re.compile(
)
-def split_query(query, encoding="ascii"):
- parts = []
+class QueryPart(NamedTuple):
+ pre: bytes
+ # TODO: mypy bug? https://github.com/python/mypy/issues/8599
+ index: Union[int, str] # type: ignore
+ format: Format
+
+
+def split_query(query: bytes, encoding: str = "ascii") -> List[QueryPart]:
+ parts: List[Tuple[bytes, Optional[Match[bytes]]]] = []
cur = 0
- # pairs [(fragment, match)], with the last match None
+ # pairs [(fragment, match], with the last match None
m = None
for m in _re_placeholder.finditer(query):
pre = query[cur : m.span(0)[0]]
- parts.append([pre, m, None])
+ parts.append((pre, m))
cur = m.span(0)[1]
if m is None:
- parts.append([query, None, None])
+ parts.append((query, None))
else:
- parts.append([query[cur:], None, None])
+ parts.append((query[cur:], None))
+
+ rv = []
# drop the "%%", validate
i = 0
phtype = None
while i < len(parts):
- part = parts[i]
- m = part[1]
+ pre, m = parts[i]
if m is None:
- break # last part
+ # last part
+ rv.append(QueryPart(pre, 0, Format.TEXT))
+ break
+
ph = m.group(0)
if ph == b"%%":
# unescape '%%' to '%' and merge the parts
- parts[i + 1][0] = part[0] + b"%" + parts[i + 1][0]
+ pre1, m1 = parts[i + 1]
+ parts[i + 1] = (pre + b"%" + pre1, m1)
del parts[i]
continue
+
if ph == b"%(":
raise exc.ProgrammingError(
f"incomplete placeholder:"
)
# Index or name
- if m.group(1) is None:
- part[1] = i
- else:
- part[1] = m.group(1)
-
- # Binary format
- part[2] = Format(ph[-1:] == b"b")
+ index: Union[int, str]
+ index = i if m.group(1) is None else m.group(1).decode(encoding)
if phtype is None:
- phtype = type(part[1])
+ phtype = type(index)
else:
- if phtype is not type(part[1]): # noqa
+ if phtype is not type(index): # noqa
raise exc.ProgrammingError(
"positional and named placeholders cannot be mixed"
)
+ # Binary format
+ format = Format(ph[-1:] == b"b")
+
+ rv.append(QueryPart(pre, index, format))
i += 1
- return parts
+ return rv
-def reorder_params(params, order):
+def reorder_params(
+ params: Mapping[str, Any], order: Sequence[str]
+) -> List[str]:
"""
Convert a mapping of parameters into an array in a specified order
"""
--- /dev/null
+"""
+Additional types for checking
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Any, Callable, Mapping, Sequence, Tuple, Union
+
+EncodeFunc = Callable[[str], Tuple[bytes, int]]
+DecodeFunc = Callable[[bytes], Tuple[str, int]]
+
+Query = Union[str, bytes]
+Params = Union[Sequence[Any], Mapping[str, Any]]
from enum import Enum
from select import select
+from typing import Generator, Tuple, TypeVar
from asyncio import get_event_loop, Event
from . import exceptions as exc
Wait = Enum("Wait", "R W RW")
Ready = Enum("Ready", "R W")
+RV = TypeVar("RV")
-def wait_select(gen):
+
+def wait_select(gen: Generator[Tuple[int, Wait], Ready, RV]) -> RV:
"""
Wait on the behalf of a generator using select().
else:
raise exc.InternalError("bad poll status: %s")
except StopIteration as e:
- return e.args[0]
+ rv: RV = e.args[0]
+ return rv
-async def wait_async(gen):
+async def wait_async(gen: Generator[Tuple[int, Wait], Ready, RV]) -> RV:
"""
Coroutine waiting for a generator to complete.
# Not sure this is the best implementation but it's a start.
ev = Event()
loop = get_event_loop()
- ready = None
+ ready = Ready.R
- def wakeup(state):
+ def wakeup(state: Ready) -> None:
nonlocal ready
ready = state
ev.set()
else:
raise exc.InternalError("bad poll status: %s")
except StopIteration as e:
- return e.args[0]
+ rv: RV = e.args[0]
+ return rv
@pytest.mark.parametrize(
"input, want",
[
- (b"", [[b"", None, None]]),
- (b"foo bar", [[b"foo bar", None, None]]),
- (b"foo %% bar", [[b"foo % bar", None, None]]),
- (b"%s", [[b"", 0, False], [b"", None, None]]),
- (b"%s foo", [[b"", 0, False], [b" foo", None, None]]),
- (b"%b foo", [[b"", 0, True], [b" foo", None, None]]),
- (b"foo %s", [[b"foo ", 0, False], [b"", None, None]]),
- (b"foo %%%s bar", [[b"foo %", 0, False], [b" bar", None, None]]),
- (
- b"foo %(name)s bar",
- [[b"foo ", b"name", False], [b" bar", None, None]],
- ),
+ (b"", [(b"", 0, 0)]),
+ (b"foo bar", [(b"foo bar", 0, 0)]),
+ (b"foo %% bar", [(b"foo % bar", 0, 0)]),
+ (b"%s", [(b"", 0, 0), (b"", 0, 0)]),
+ (b"%s foo", [(b"", 0, 0), (b" foo", 0, 0)]),
+ (b"%b foo", [(b"", 0, 1), (b" foo", 0, 0)]),
+ (b"foo %s", [(b"foo ", 0, 0), (b"", 0, 0)]),
+ (b"foo %%%s bar", [(b"foo %", 0, 0), (b" bar", 0, 0)]),
+ (b"foo %(name)s bar", [(b"foo ", "name", 0), (b" bar", 0, 0)]),
(
b"foo %(name)s %(name)b bar",
- [
- [b"foo ", b"name", False],
- [b" ", b"name", True],
- [b" bar", None, None],
- ],
+ [(b"foo ", "name", 0), (b" ", "name", 1), (b" bar", 0, 0)],
),
(
b"foo %s%b bar %s baz",
- [
- [b"foo ", 0, False],
- [b"", 1, True],
- [b" bar ", 2, False],
- [b" baz", None, None],
- ],
+ [(b"foo ", 0, 0), (b"", 1, 1), (b" bar ", 2, 0), (b" baz", 0, 0)],
),
],
)
ignore = W503, E203
[mypy]
-files = psycopg3, tests, setup.py
+files = psycopg3, setup.py
warn_unused_ignores = True
[mypy-pytest]