From: Daniele Varrazzo Date: Mon, 11 May 2020 04:33:29 +0000 (+1200) Subject: Added protocol for different implementations of Transform X-Git-Tag: 3.0.dev0~533 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=57c968f0c1f54de378efa31f8feeb85917ecd513;p=thirdparty%2Fpsycopg.git Added protocol for different implementations of Transform mypy passes all the checks. --- diff --git a/.gitignore b/.gitignore index e48743185..4d5e6f253 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ env /psycopg3.egg-info /.tox +/.eggs +/build diff --git a/.travis.yml b/.travis.yml index 91bdee98e..89e1e5642 100644 --- a/.travis.yml +++ b/.travis.yml @@ -81,9 +81,3 @@ install: script: - tox - - -# This branch is still far from passing tests branches: except: - cython diff --git a/psycopg3/.gitignore b/psycopg3/.gitignore index 5ce3bcd10..f14a44ed8 100644 --- a/psycopg3/.gitignore +++ b/psycopg3/.gitignore @@ -1,3 +1,3 @@ _psycopg3.c -_psycopg3.cpython-36m-x86_64-linux-gnu.so +_psycopg3.*.so *.html diff --git a/psycopg3/_psycopg3.pyi b/psycopg3/_psycopg3.pyi new file mode 100644 index 000000000..fd19c4d48 --- /dev/null +++ b/psycopg3/_psycopg3.pyi @@ -0,0 +1,53 @@ +""" +Stub representation of the public objects exposed by the _psycopg3 module. + +TODO: this should be generated by mypy's stubgen but it crashes with no +information. Will submit a bug. +""" + +# Copyright (C) 2020 The Psycopg Team + +import codecs +from typing import Any, Iterable, List, Optional, Sequence, Tuple + +from .connection import BaseConnection +from .utils.typing import AdaptContext, DumpFunc, DumpersMap, DumperType +from .utils.typing import LoadFunc, LoadersMap, LoaderType, MaybeOid +from . import pq + +Format = pq.Format + +class Transformer: + def __init__(self, context: AdaptContext = None): ... + @property + def connection(self) -> Optional[BaseConnection]: ... + @property + def codec(self) -> codecs.CodecInfo: ... + @property + def dumpers(self) -> DumpersMap: ... 
+ @property + def loaders(self) -> LoadersMap: ... + @property + def pgresult(self) -> Optional["pq.proto.PGresult"]: ... + @pgresult.setter + def pgresult(self, result: Optional["pq.proto.PGresult"]) -> None: ... + def set_row_types(self, types: Sequence[Tuple[int, Format]]) -> None: ... + def dump_sequence( + self, objs: Iterable[Any], formats: Iterable[Format] + ) -> Tuple[List[Optional[bytes]], List[int]]: ... + def dump(self, obj: None, format: Format = Format.TEXT) -> MaybeOid: ... + def get_dump_function(self, src: type, format: Format) -> DumpFunc: ... + def lookup_dumper(self, src: type, format: Format) -> DumperType: ... + def load_row(self, row: int) -> Optional[Tuple[Any, ...]]: ... + def load_sequence( + self, record: Sequence[Optional[bytes]] + ) -> Tuple[Any, ...]: ... + def load( + self, data: bytes, oid: int, format: Format = Format.TEXT + ) -> Any: ... + def get_load_function(self, oid: int, format: Format) -> LoadFunc: ... + def lookup_loader(self, oid: int, format: Format) -> LoaderType: ... + +def register_builtin_c_loaders() -> None: ... + +# vim: set syntax=python: diff --git a/psycopg3/adapt.py b/psycopg3/adapt.py index 836a7d79c..24500ee05 100644 --- a/psycopg3/adapt.py +++ b/psycopg3/adapt.py @@ -4,34 +4,18 @@ Entry point into the adaptation system. # Copyright (C) 2020 The Psycopg Team -import codecs -from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence -from typing import Tuple, Type, Union +from typing import Any, Callable, Optional, Tuple, Type, Union -from . import errors as e from . import pq +from . 
import proto from .cursor import BaseCursor -from .types.oids import builtins, INVALID_OID from .connection import BaseConnection - -TEXT_OID = builtins["text"].oid +from .utils.typing import AdaptContext, DumpersMap, DumperType +from .utils.typing import LoadersMap, LoaderType # Part of the module interface (just importing it makes mypy unhappy) Format = pq.Format -# Type system - -AdaptContext = Union[None, BaseConnection, BaseCursor, "Transformer"] - -MaybeOid = Union[Optional[bytes], Tuple[Optional[bytes], int]] -DumpFunc = Callable[[Any], MaybeOid] -DumperType = Union[Type["Dumper"], DumpFunc] -DumpersMap = Dict[Tuple[type, Format], DumperType] - -LoadFunc = Callable[[bytes], Any] -LoaderType = Union[Type["Loader"], LoadFunc] -LoadersMap = Dict[Tuple[int, Format], LoaderType] - class Dumper: globals: DumpersMap = {} @@ -155,221 +139,6 @@ class Loader: return binary_ -class Transformer: - """ - An object that can adapt efficiently between Python and PostgreSQL. - - The life cycle of the object is the query, so it is assumed that stuff like - the server version or connection encoding will not change. It can have its - state so adapting several values of the same type can use optimisations. 
- """ - - def __init__(self, context: AdaptContext = None): - self.connection: Optional[BaseConnection] - self.codec: codecs.CodecInfo - self.dumpers: DumpersMap - self.loaders: LoadersMap - self._dumpers_maps: List[DumpersMap] = [] - self._loaders_maps: List[LoadersMap] = [] - self._setup_context(context) - self.pgresult = None - - # mapping class, fmt -> dump function - self._dump_funcs: Dict[Tuple[type, Format], DumpFunc] = {} - - # mapping oid, fmt -> load function - self._load_funcs: Dict[Tuple[int, Format], LoadFunc] = {} - - # sequence of load functions from value to python - # the length of the result columns - self._row_loaders: List[LoadFunc] = [] - - def _setup_context(self, context: AdaptContext) -> None: - if context is None: - self.connection = None - self.codec = codecs.lookup("utf8") - self.dumpers = {} - self.loaders = {} - self._dumpers_maps = [self.dumpers] - self._loaders_maps = [self.loaders] - - elif isinstance(context, Transformer): - # A transformer created from a transformers: usually it happens - # for nested types: share the entire state of the parent - self.connection = context.connection - self.codec = context.codec - self.dumpers = context.dumpers - self.loaders = context.loaders - self._dumpers_maps.extend(context._dumpers_maps) - self._loaders_maps.extend(context._loaders_maps) - # the global maps are already in the lists - return - - elif isinstance(context, BaseCursor): - self.connection = context.connection - self.codec = context.connection.codec - self.dumpers = {} - self._dumpers_maps.extend( - (self.dumpers, context.dumpers, self.connection.dumpers) - ) - self.loaders = {} - self._loaders_maps.extend( - (self.loaders, context.loaders, self.connection.loaders) - ) - - elif isinstance(context, BaseConnection): - self.connection = context - self.codec = context.codec - self.dumpers = {} - self._dumpers_maps.extend((self.dumpers, context.dumpers)) - self.loaders = {} - self._loaders_maps.extend((self.loaders, context.loaders)) - - 
self._dumpers_maps.append(Dumper.globals) - self._loaders_maps.append(Loader.globals) - - @property - def pgresult(self) -> Optional[pq.proto.PGresult]: - return self._pgresult - - @pgresult.setter - def pgresult(self, result: Optional[pq.proto.PGresult]) -> None: - self._pgresult = result - rc = self._row_loaders = [] - - self._ntuples: int - self._nfields: int - if result is None: - self._nfields = self._ntuples = 0 - return - - nf = self._nfields = result.nfields - self._ntuples = result.ntuples - - for i in range(nf): - oid = result.ftype(i) - fmt = result.fformat(i) - rc.append(self.get_load_function(oid, fmt)) - - def set_row_types(self, types: Iterable[Tuple[int, Format]]) -> None: - rc = self._row_loaders = [] - for oid, fmt in types: - rc.append(self.get_load_function(oid, fmt)) - - def dump_sequence( - self, objs: Iterable[Any], formats: Iterable[Format] - ) -> Tuple[List[Optional[bytes]], List[int]]: - out = [] - types = [] - - for var, fmt in zip(objs, formats): - data = self.dump(var, fmt) - if isinstance(data, tuple): - oid = data[1] - data = data[0] - else: - oid = TEXT_OID - - out.append(data) - types.append(oid) - - return out, types - - def dump(self, obj: None, format: Format = Format.TEXT) -> MaybeOid: - if obj is None: - return None, TEXT_OID - - src = type(obj) - func = self.get_dump_function(src, format) - return func(obj) - - def get_dump_function(self, src: type, format: Format) -> DumpFunc: - key = (src, format) - try: - return self._dump_funcs[key] - except KeyError: - pass - - dumper = self.lookup_dumper(src, format) - func: DumpFunc - if isinstance(dumper, type): - func = dumper(src, self).dump - else: - func = dumper - - self._dump_funcs[key] = func - return func - - def lookup_dumper(self, src: type, format: Format) -> DumperType: - key = (src, format) - for amap in self._dumpers_maps: - if key in amap: - return amap[key] - - raise e.ProgrammingError( - f"cannot adapt type {src.__name__} to format {Format(format).name}" - ) - - def 
load_row(self, row: int) -> Optional[Tuple[Any, ...]]: - res = self.pgresult - if res is None: - return None - - if row >= self._ntuples: - return None - - rv: List[Any] = [] - for col in range(self._nfields): - val = res.get_value(row, col) - if val is None: - rv.append(None) - else: - rv.append(self._row_loaders[col](val)) - - return tuple(rv) - - def load_sequence( - self, record: Sequence[Optional[bytes]] - ) -> Tuple[Any, ...]: - return tuple( - (self._row_loaders[i](val) if val is not None else None) - for i, val in enumerate(record) - ) - - def load(self, data: bytes, oid: int, format: Format = Format.TEXT) -> Any: - if data is not None: - f = self.get_load_function(oid, format) - return f(data) - else: - return None - - def get_load_function(self, oid: int, format: Format) -> LoadFunc: - key = (oid, format) - try: - return self._load_funcs[key] - except KeyError: - pass - - loader = self.lookup_loader(oid, format) - func: LoadFunc - if isinstance(loader, type): - func = loader(oid, self).load - else: - func = loader - - self._load_funcs[key] = func - return func - - def lookup_loader(self, oid: int, format: Format) -> LoaderType: - key = (oid, format) - - for tcmap in self._loaders_maps: - if key in tcmap: - return tcmap[key] - - return Loader.globals[INVALID_OID, format] - - def _connection_from_context( context: AdaptContext, ) -> Optional[BaseConnection]: @@ -385,6 +154,14 @@ def _connection_from_context( raise TypeError(f"can't get a connection from {type(context)}") +Transformer: Type[proto.Transformer] + # Override it with fast object if available if pq.__impl__ == "c": - from ._psycopg3 import Transformer # noqa + from . import _psycopg3 + + Transformer = _psycopg3.Transformer +else: + from . import transform + + Transformer = transform.Transformer diff --git a/psycopg3/connection.py b/psycopg3/connection.py index e7570062e..f0556b36e 100644 --- a/psycopg3/connection.py +++ b/psycopg3/connection.py @@ -17,11 +17,11 @@ from . import cursor from . 
import generators from .conninfo import make_conninfo from .waiting import wait, wait_async +from .utils.typing import DumpersMap, LoadersMap logger = logging.getLogger(__name__) if TYPE_CHECKING: - from .adapt import DumpersMap, LoadersMap from .generators import PQGen, RV diff --git a/psycopg3/cursor.py b/psycopg3/cursor.py index b9dc38b41..a48171896 100644 --- a/psycopg3/cursor.py +++ b/psycopg3/cursor.py @@ -11,11 +11,11 @@ from typing import Any, List, Optional, Sequence, TYPE_CHECKING from . import errors as e from . import pq from . import generators +from . import proto from .utils.queries import PostgresQuery -from .utils.typing import Query, Params +from .utils.typing import Query, Params, DumpersMap, LoadersMap if TYPE_CHECKING: - from .adapt import DumpersMap, LoadersMap, Transformer from .connection import BaseConnection, Connection, AsyncConnection @@ -60,7 +60,7 @@ class Column(Sequence[Any]): class BaseCursor: ExecStatus = pq.ExecStatus - _transformer: "Transformer" + _transformer: proto.Transformer def __init__(self, connection: "BaseConnection", binary: bool = False): self.connection = connection diff --git a/psycopg3/pq/.gitignore b/psycopg3/pq/.gitignore index 8b6838605..7b4389339 100644 --- a/psycopg3/pq/.gitignore +++ b/psycopg3/pq/.gitignore @@ -1,3 +1,3 @@ pq_cython.c -pq_cython.cpython-36m-x86_64-linux-gnu.so +pq_cython.*.so pq_cython.html diff --git a/psycopg3/pq/pq_cython.pyx b/psycopg3/pq/pq_cython.pyx index 937a3cc64..8b090917b 100644 --- a/psycopg3/pq/pq_cython.pyx +++ b/psycopg3/pq/pq_cython.pyx @@ -6,6 +6,8 @@ libpq Python wrapper using cython bindings. 
from cpython.mem cimport PyMem_Malloc, PyMem_Free +from typing import List, Optional, Sequence + from psycopg3.pq cimport libpq as impl from psycopg3.pq.libpq cimport Oid from psycopg3.errors import OperationalError diff --git a/psycopg3/pq/proto.py b/psycopg3/pq/proto.py index 36ef88b76..23bf76615 100644 --- a/psycopg3/pq/proto.py +++ b/psycopg3/pq/proto.py @@ -1,3 +1,9 @@ +""" +Protocol objects to represent objects exposed by different pq implementations. +""" + +# Copyright (C) 2020 The Psycopg Team + from typing import Any, List, Optional, Sequence, TYPE_CHECKING from typing_extensions import Protocol diff --git a/psycopg3/proto.py b/psycopg3/proto.py new file mode 100644 index 000000000..9db962b52 --- /dev/null +++ b/psycopg3/proto.py @@ -0,0 +1,82 @@ +""" +Protocol objects representing different implementations of the same classes. +""" + +# Copyright (C) 2020 The Psycopg Team + +import codecs +from typing import Any, Iterable, List, Optional, Sequence, Tuple +from typing import TYPE_CHECKING +from typing_extensions import Protocol + +from .utils.typing import AdaptContext, DumpFunc, DumpersMap, DumperType +from .utils.typing import LoadFunc, LoadersMap, LoaderType, MaybeOid +from . import pq + +if TYPE_CHECKING: + from .connection import BaseConnection # noqa + +Format = pq.Format + + +class Transformer(Protocol): + def __init__(self, context: AdaptContext = None): + ... + + @property + def connection(self) -> Optional["BaseConnection"]: + ... + + @property + def codec(self) -> codecs.CodecInfo: + ... + + @property + def pgresult(self) -> Optional["pq.proto.PGresult"]: + ... + + @pgresult.setter + def pgresult(self, result: Optional["pq.proto.PGresult"]) -> None: + ... + + @property + def dumpers(self) -> DumpersMap: + ... + + @property + def loaders(self) -> LoadersMap: + ... + + def set_row_types(self, types: Sequence[Tuple[int, Format]]) -> None: + ... 
+ + def dump_sequence( + self, objs: Iterable[Any], formats: Iterable[Format] + ) -> Tuple[List[Optional[bytes]], List[int]]: + ... + + def dump(self, obj: None, format: Format = Format.TEXT) -> MaybeOid: + ... + + def get_dump_function(self, src: type, format: Format) -> DumpFunc: + ... + + def lookup_dumper(self, src: type, format: Format) -> DumperType: + ... + + def load_row(self, row: int) -> Optional[Tuple[Any, ...]]: + ... + + def load_sequence( + self, record: Sequence[Optional[bytes]] + ) -> Tuple[Any, ...]: + ... + + def load(self, data: bytes, oid: int, format: Format = Format.TEXT) -> Any: + ... + + def get_load_function(self, oid: int, format: Format) -> LoadFunc: + ... + + def lookup_loader(self, oid: int, format: Format) -> LoaderType: + ... diff --git a/psycopg3/transform.py b/psycopg3/transform.py new file mode 100644 index 000000000..d5bb41d49 --- /dev/null +++ b/psycopg3/transform.py @@ -0,0 +1,252 @@ +""" +Helper object to transform values between Python and PostgreSQL +""" + +# Copyright (C) 2020 The Psycopg Team + +import codecs +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple + +from . import errors as e +from . import pq +from .cursor import BaseCursor +from .types.oids import builtins, INVALID_OID +from .connection import BaseConnection +from .utils.typing import AdaptContext, DumpFunc, DumpersMap, DumperType +from .utils.typing import LoadFunc, LoadersMap, LoaderType, MaybeOid + +Format = pq.Format +TEXT_OID = builtins["text"].oid + + +class Transformer: + """ + An object that can adapt efficiently between Python and PostgreSQL. + + The life cycle of the object is the query, so it is assumed that stuff like + the server version or connection encoding will not change. It can have its + state so adapting several values of the same type can use optimisations. 
+ """ + + def __init__(self, context: AdaptContext = None): + self._dumpers: DumpersMap + self._loaders: LoadersMap + self._dumpers_maps: List[DumpersMap] = [] + self._loaders_maps: List[LoadersMap] = [] + self._setup_context(context) + self.pgresult = None + + # mapping class, fmt -> dump function + self._dump_funcs: Dict[Tuple[type, Format], DumpFunc] = {} + + # mapping oid, fmt -> load function + self._load_funcs: Dict[Tuple[int, Format], LoadFunc] = {} + + # sequence of load functions from value to python + # the length of the result columns + self._row_loaders: List[LoadFunc] = [] + + def _setup_context(self, context: AdaptContext) -> None: + if context is None: + self._connection = None + self._codec = codecs.lookup("utf8") + self._dumpers = {} + self._loaders = {} + self._dumpers_maps = [self._dumpers] + self._loaders_maps = [self._loaders] + + elif isinstance(context, Transformer): + # A transformer created from a transformers: usually it happens + # for nested types: share the entire state of the parent + self._connection = context.connection + self._codec = context.codec + self._dumpers = context.dumpers + self._loaders = context.loaders + self._dumpers_maps.extend(context._dumpers_maps) + self._loaders_maps.extend(context._loaders_maps) + # the global maps are already in the lists + return + + elif isinstance(context, BaseCursor): + self._connection = context.connection + self._codec = context.connection.codec + self._dumpers = {} + self._dumpers_maps.extend( + (self._dumpers, context.dumpers, context.connection.dumpers) + ) + self._loaders = {} + self._loaders_maps.extend( + (self._loaders, context.loaders, context.connection.loaders) + ) + + elif isinstance(context, BaseConnection): + self._connection = context + self._codec = context.codec + self._dumpers = {} + self._dumpers_maps.extend((self._dumpers, context.dumpers)) + self._loaders = {} + self._loaders_maps.extend((self._loaders, context.loaders)) + + from .adapt import Dumper, Loader + + 
self._dumpers_maps.append(Dumper.globals) + self._loaders_maps.append(Loader.globals) + + @property + def connection(self) -> Optional["BaseConnection"]: + return self._connection + + @property + def codec(self) -> codecs.CodecInfo: + return self._codec + + @property + def pgresult(self) -> Optional[pq.proto.PGresult]: + return self._pgresult + + @pgresult.setter + def pgresult(self, result: Optional[pq.proto.PGresult]) -> None: + self._pgresult = result + rc = self._row_loaders = [] + + self._ntuples: int + self._nfields: int + if result is None: + self._nfields = self._ntuples = 0 + return + + nf = self._nfields = result.nfields + self._ntuples = result.ntuples + + for i in range(nf): + oid = result.ftype(i) + fmt = result.fformat(i) + rc.append(self.get_load_function(oid, fmt)) + + @property + def dumpers(self) -> DumpersMap: + return self._dumpers + + @property + def loaders(self) -> LoadersMap: + return self._loaders + + def set_row_types(self, types: Iterable[Tuple[int, Format]]) -> None: + rc = self._row_loaders = [] + for oid, fmt in types: + rc.append(self.get_load_function(oid, fmt)) + + def dump_sequence( + self, objs: Iterable[Any], formats: Iterable[Format] + ) -> Tuple[List[Optional[bytes]], List[int]]: + out = [] + types = [] + + for var, fmt in zip(objs, formats): + data = self.dump(var, fmt) + if isinstance(data, tuple): + oid = data[1] + data = data[0] + else: + oid = TEXT_OID + + out.append(data) + types.append(oid) + + return out, types + + def dump(self, obj: None, format: Format = Format.TEXT) -> MaybeOid: + if obj is None: + return None, TEXT_OID + + src = type(obj) + func = self.get_dump_function(src, format) + return func(obj) + + def get_dump_function(self, src: type, format: Format) -> DumpFunc: + key = (src, format) + try: + return self._dump_funcs[key] + except KeyError: + pass + + dumper = self.lookup_dumper(src, format) + func: DumpFunc + if isinstance(dumper, type): + func = dumper(src, self).dump + else: + func = dumper + + 
self._dump_funcs[key] = func + return func + + def lookup_dumper(self, src: type, format: Format) -> DumperType: + key = (src, format) + for amap in self._dumpers_maps: + if key in amap: + return amap[key] + + raise e.ProgrammingError( + f"cannot adapt type {src.__name__} to format {Format(format).name}" + ) + + def load_row(self, row: int) -> Optional[Tuple[Any, ...]]: + res = self.pgresult + if res is None: + return None + + if row >= self._ntuples: + return None + + rv: List[Any] = [] + for col in range(self._nfields): + val = res.get_value(row, col) + if val is None: + rv.append(None) + else: + rv.append(self._row_loaders[col](val)) + + return tuple(rv) + + def load_sequence( + self, record: Sequence[Optional[bytes]] + ) -> Tuple[Any, ...]: + return tuple( + (self._row_loaders[i](val) if val is not None else None) + for i, val in enumerate(record) + ) + + def load(self, data: bytes, oid: int, format: Format = Format.TEXT) -> Any: + if data is not None: + f = self.get_load_function(oid, format) + return f(data) + else: + return None + + def get_load_function(self, oid: int, format: Format) -> LoadFunc: + key = (oid, format) + try: + return self._load_funcs[key] + except KeyError: + pass + + loader = self.lookup_loader(oid, format) + func: LoadFunc + if isinstance(loader, type): + func = loader(oid, self).load + else: + func = loader + + self._load_funcs[key] = func + return func + + def lookup_loader(self, oid: int, format: Format) -> LoaderType: + key = (oid, format) + + for tcmap in self._loaders_maps: + if key in tcmap: + return tcmap[key] + + from .adapt import Loader + + return Loader.globals[INVALID_OID, format] diff --git a/psycopg3/transform.pyx b/psycopg3/transform.pyx index 2c565fdcb..36fc99031 100644 --- a/psycopg3/transform.pyx +++ b/psycopg3/transform.pyx @@ -5,7 +5,7 @@ from cpython.mem cimport PyMem_Malloc, PyMem_Free from cpython.tuple cimport PyTuple_New, PyTuple_SET_ITEM import codecs -from typing import Any, Dict, Iterable, List, Optional, 
Tuple +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple from psycopg3.pq cimport libpq from psycopg3.pq.pq_cython cimport PGresult @@ -18,7 +18,6 @@ from psycopg3.pq.enums import Format TEXT_OID = 25 - cdef class Transformer: """ An object that can adapt efficiently between Python and PostgreSQL. @@ -248,7 +247,6 @@ cdef class Transformer: Py_INCREF(pyval) PyTuple_SET_ITEM(rv, col, pyval) - return rv def load_sequence( diff --git a/psycopg3/types/array.py b/psycopg3/types/array.py index bfc137b82..a7cf7b34c 100644 --- a/psycopg3/types/array.py +++ b/psycopg3/types/array.py @@ -10,8 +10,8 @@ from typing import Any, Generator, List, Optional, Tuple from .. import errors as e from ..adapt import Format, Dumper, Loader, Transformer -from ..adapt import AdaptContext from .oids import builtins +from ..utils.typing import AdaptContext TEXT_OID = builtins["text"].oid TEXT_ARRAY_OID = builtins["text"].array_oid diff --git a/psycopg3/types/composite.py b/psycopg3/types/composite.py index c801e575e..51cc2111b 100644 --- a/psycopg3/types/composite.py +++ b/psycopg3/types/composite.py @@ -9,8 +9,9 @@ from typing import Any, Callable, Generator, Sequence, Tuple from typing import Optional, TYPE_CHECKING from . 
import array -from ..adapt import Format, Dumper, Loader, Transformer, AdaptContext +from ..adapt import Format, Dumper, Loader, Transformer from .oids import builtins, TypeInfo +from ..utils.typing import AdaptContext if TYPE_CHECKING: from ..connection import Connection, AsyncConnection @@ -256,7 +257,9 @@ class CompositeLoader(RecordLoader): ) def _config_types(self, data: bytes) -> None: - self._tx.set_row_types((oid, Format.TEXT) for oid in self.fields_types) + self._tx.set_row_types( + [(oid, Format.TEXT) for oid in self.fields_types] + ) class BinaryCompositeLoader(BinaryRecordLoader): diff --git a/psycopg3/types/text.py b/psycopg3/types/text.py index 8600536c4..f770a812a 100644 --- a/psycopg3/types/text.py +++ b/psycopg3/types/text.py @@ -5,13 +5,16 @@ Adapters for textual types. # Copyright (C) 2020 The Psycopg Team import codecs -from typing import Optional, Tuple, Union +from typing import Optional, Tuple, Union, TYPE_CHECKING -from ..adapt import Dumper, Loader, AdaptContext +from ..adapt import Dumper, Loader from ..utils.typing import EncodeFunc, DecodeFunc from ..pq import Escaping from .oids import builtins, INVALID_OID +if TYPE_CHECKING: + from ..utils.typing import AdaptContext + TEXT_OID = builtins["text"].oid BYTEA_OID = builtins["bytea"].oid @@ -19,7 +22,7 @@ BYTEA_OID = builtins["bytea"].oid @Dumper.text(str) @Dumper.binary(str) class StringDumper(Dumper): - def __init__(self, src: type, context: AdaptContext): + def __init__(self, src: type, context: "AdaptContext"): super().__init__(src, context) self._encode: EncodeFunc @@ -44,7 +47,7 @@ class StringLoader(Loader): decode: Optional[DecodeFunc] - def __init__(self, oid: int, context: AdaptContext): + def __init__(self, oid: int, context: "AdaptContext"): super().__init__(oid, context) if self.connection is not None: @@ -68,7 +71,7 @@ class StringLoader(Loader): @Loader.text(builtins["bpchar"].oid) @Loader.binary(builtins["bpchar"].oid) class UnknownLoader(Loader): - def __init__(self, oid: 
int, context: AdaptContext): + def __init__(self, oid: int, context: "AdaptContext"): super().__init__(oid, context) self.decode: DecodeFunc @@ -83,7 +86,7 @@ class UnknownLoader(Loader): @Dumper.text(bytes) class BytesDumper(Dumper): - def __init__(self, src: type, context: AdaptContext = None): + def __init__(self, src: type, context: "AdaptContext" = None): super().__init__(src, context) self.esc = Escaping( self.connection.pgconn if self.connection is not None else None diff --git a/psycopg3/utils/queries.py b/psycopg3/utils/queries.py index 41c7fa88e..c3a9a2356 100644 --- a/psycopg3/utils/queries.py +++ b/psycopg3/utils/queries.py @@ -14,7 +14,7 @@ from ..pq import Format from .typing import Query, Params if TYPE_CHECKING: - from ..adapt import Transformer + from ..proto import Transformer class PostgresQuery: diff --git a/psycopg3/utils/typing.py b/psycopg3/utils/typing.py index f96576ce8..f8c41f13b 100644 --- a/psycopg3/utils/typing.py +++ b/psycopg3/utils/typing.py @@ -4,10 +4,34 @@ Additional types for checking # Copyright (C) 2020 The Psycopg Team -from typing import Any, Callable, Mapping, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple +from typing import Type, Union, TYPE_CHECKING + +from .. 
import pq + +if TYPE_CHECKING: + from ..connection import BaseConnection # noqa + from ..cursor import BaseCursor # noqa + from ..adapt import Dumper, Loader # noqa + from ..proto import Transformer # noqa + +# Part of the module interface (just importing it makes mypy unhappy) +Format = pq.Format + EncodeFunc = Callable[[str], Tuple[bytes, int]] DecodeFunc = Callable[[bytes], Tuple[str, int]] Query = Union[str, bytes] Params = Union[Sequence[Any], Mapping[str, Any]] + +AdaptContext = Union[None, "BaseConnection", "BaseCursor", "Transformer"] + +MaybeOid = Union[Optional[bytes], Tuple[Optional[bytes], int]] +DumpFunc = Callable[[Any], MaybeOid] +DumperType = Union[Type["Dumper"], DumpFunc] +DumpersMap = Dict[Tuple[type, Format], DumperType] + +LoadFunc = Callable[[bytes], Any] +LoaderType = Union[Type["Loader"], LoadFunc] +LoadersMap = Dict[Tuple[int, Format], LoaderType]