]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added protocol for different implementations of Transform
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 11 May 2020 04:33:29 +0000 (16:33 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 11 May 2020 04:45:44 +0000 (16:45 +1200)
mypy passes all the checks.

18 files changed:
.gitignore
.travis.yml
psycopg3/.gitignore
psycopg3/_psycopg3.pyi [new file with mode: 0644]
psycopg3/adapt.py
psycopg3/connection.py
psycopg3/cursor.py
psycopg3/pq/.gitignore
psycopg3/pq/pq_cython.pyx
psycopg3/pq/proto.py
psycopg3/proto.py [new file with mode: 0644]
psycopg3/transform.py [new file with mode: 0644]
psycopg3/transform.pyx
psycopg3/types/array.py
psycopg3/types/composite.py
psycopg3/types/text.py
psycopg3/utils/queries.py
psycopg3/utils/typing.py

index e48743185617dbd85b304d7934798a69a39f53a7..4d5e6f253c6d352148189bdbc296015619e73a66 100644 (file)
@@ -1,3 +1,5 @@
 env
 /psycopg3.egg-info
 /.tox
+/.eggs
+/build
index 91bdee98ee5be4cfb419e918c566bc47c748e6b4..89e1e56424acbd5ac7718b548886f76f75a6e6a1 100644 (file)
@@ -81,9 +81,3 @@ install:
 
 script:
   - tox
-
-
-# This branch is still far from passing tests
-branches:
-  except:
-  - cython
index 5ce3bcd1040dc4f1870fb0655e7b43a7ea7cba9d..f14a44ed80626d180743fa35275dfdde152f5ee9 100644 (file)
@@ -1,3 +1,3 @@
 _psycopg3.c
-_psycopg3.cpython-36m-x86_64-linux-gnu.so
+_psycopg3.*.so
 *.html
diff --git a/psycopg3/_psycopg3.pyi b/psycopg3/_psycopg3.pyi
new file mode 100644 (file)
index 0000000..fd19c4d
--- /dev/null
@@ -0,0 +1,53 @@
+"""
+Stub representation of the public objects exposed by the _psycopg3 module.
+
+TODO: this should be generated by mypy's stubgen but it crashes with no
+information. Will submit a bug.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import codecs
+from typing import Any, Iterable, List, Optional, Sequence, Tuple
+
+from .connection import BaseConnection
+from .utils.typing import AdaptContext, DumpFunc, DumpersMap, DumperType
+from .utils.typing import LoadFunc, LoadersMap, LoaderType, MaybeOid
+from . import pq
+
+Format = pq.Format
+
+class Transformer:
+    def __init__(self, context: AdaptContext = None): ...
+    @property
+    def connection(self) -> Optional[BaseConnection]: ...
+    @property
+    def codec(self) -> codecs.CodecInfo: ...
+    @property
+    def dumpers(self) -> DumpersMap: ...
+    @property
+    def loaders(self) -> LoadersMap: ...
+    @property
+    def pgresult(self) -> Optional["pq.proto.PGresult"]: ...
+    @pgresult.setter
+    def pgresult(self, result: Optional["pq.proto.PGresult"]) -> None: ...
+    def set_row_types(self, types: Sequence[Tuple[int, Format]]) -> None: ...
+    def dump_sequence(
+        self, objs: Iterable[Any], formats: Iterable[Format]
+    ) -> Tuple[List[Optional[bytes]], List[int]]: ...
+    def dump(self, obj: None, format: Format = Format.TEXT) -> MaybeOid: ...
+    def get_dump_function(self, src: type, format: Format) -> DumpFunc: ...
+    def lookup_dumper(self, src: type, format: Format) -> DumperType: ...
+    def load_row(self, row: int) -> Optional[Tuple[Any, ...]]: ...
+    def load_sequence(
+        self, record: Sequence[Optional[bytes]]
+    ) -> Tuple[Any, ...]: ...
+    def load(
+        self, data: bytes, oid: int, format: Format = Format.TEXT
+    ) -> Any: ...
+    def get_load_function(self, oid: int, format: Format) -> LoadFunc: ...
+    def lookup_loader(self, oid: int, format: Format) -> LoaderType: ...
+
+def register_builtin_c_loaders() -> None: ...
+
+# vim: set syntax=python:
index 836a7d79c3caa547ab29139910fed28b22932614..24500ee0547de37b8197f82cae9e080595666142 100644 (file)
@@ -4,34 +4,18 @@ Entry point into the adaptation system.
 
 # Copyright (C) 2020 The Psycopg Team
 
-import codecs
-from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence
-from typing import Tuple, Type, Union
+from typing import Any, Callable, Optional, Tuple, Type, Union
 
-from . import errors as e
 from . import pq
+from . import proto
 from .cursor import BaseCursor
-from .types.oids import builtins, INVALID_OID
 from .connection import BaseConnection
-
-TEXT_OID = builtins["text"].oid
+from .utils.typing import AdaptContext, DumpersMap, DumperType
+from .utils.typing import LoadersMap, LoaderType
 
 # Part of the module interface (just importing it makes mypy unhappy)
 Format = pq.Format
 
-# Type system
-
-AdaptContext = Union[None, BaseConnection, BaseCursor, "Transformer"]
-
-MaybeOid = Union[Optional[bytes], Tuple[Optional[bytes], int]]
-DumpFunc = Callable[[Any], MaybeOid]
-DumperType = Union[Type["Dumper"], DumpFunc]
-DumpersMap = Dict[Tuple[type, Format], DumperType]
-
-LoadFunc = Callable[[bytes], Any]
-LoaderType = Union[Type["Loader"], LoadFunc]
-LoadersMap = Dict[Tuple[int, Format], LoaderType]
-
 
 class Dumper:
     globals: DumpersMap = {}
@@ -155,221 +139,6 @@ class Loader:
         return binary_
 
 
-class Transformer:
-    """
-    An object that can adapt efficiently between Python and PostgreSQL.
-
-    The life cycle of the object is the query, so it is assumed that stuff like
-    the server version or connection encoding will not change. It can have its
-    state so adapting several values of the same type can use optimisations.
-    """
-
-    def __init__(self, context: AdaptContext = None):
-        self.connection: Optional[BaseConnection]
-        self.codec: codecs.CodecInfo
-        self.dumpers: DumpersMap
-        self.loaders: LoadersMap
-        self._dumpers_maps: List[DumpersMap] = []
-        self._loaders_maps: List[LoadersMap] = []
-        self._setup_context(context)
-        self.pgresult = None
-
-        # mapping class, fmt -> dump function
-        self._dump_funcs: Dict[Tuple[type, Format], DumpFunc] = {}
-
-        # mapping oid, fmt -> load function
-        self._load_funcs: Dict[Tuple[int, Format], LoadFunc] = {}
-
-        # sequence of load functions from value to python
-        # the length of the result columns
-        self._row_loaders: List[LoadFunc] = []
-
-    def _setup_context(self, context: AdaptContext) -> None:
-        if context is None:
-            self.connection = None
-            self.codec = codecs.lookup("utf8")
-            self.dumpers = {}
-            self.loaders = {}
-            self._dumpers_maps = [self.dumpers]
-            self._loaders_maps = [self.loaders]
-
-        elif isinstance(context, Transformer):
-            # A transformer created from a transformers: usually it happens
-            # for nested types: share the entire state of the parent
-            self.connection = context.connection
-            self.codec = context.codec
-            self.dumpers = context.dumpers
-            self.loaders = context.loaders
-            self._dumpers_maps.extend(context._dumpers_maps)
-            self._loaders_maps.extend(context._loaders_maps)
-            # the global maps are already in the lists
-            return
-
-        elif isinstance(context, BaseCursor):
-            self.connection = context.connection
-            self.codec = context.connection.codec
-            self.dumpers = {}
-            self._dumpers_maps.extend(
-                (self.dumpers, context.dumpers, self.connection.dumpers)
-            )
-            self.loaders = {}
-            self._loaders_maps.extend(
-                (self.loaders, context.loaders, self.connection.loaders)
-            )
-
-        elif isinstance(context, BaseConnection):
-            self.connection = context
-            self.codec = context.codec
-            self.dumpers = {}
-            self._dumpers_maps.extend((self.dumpers, context.dumpers))
-            self.loaders = {}
-            self._loaders_maps.extend((self.loaders, context.loaders))
-
-        self._dumpers_maps.append(Dumper.globals)
-        self._loaders_maps.append(Loader.globals)
-
-    @property
-    def pgresult(self) -> Optional[pq.proto.PGresult]:
-        return self._pgresult
-
-    @pgresult.setter
-    def pgresult(self, result: Optional[pq.proto.PGresult]) -> None:
-        self._pgresult = result
-        rc = self._row_loaders = []
-
-        self._ntuples: int
-        self._nfields: int
-        if result is None:
-            self._nfields = self._ntuples = 0
-            return
-
-        nf = self._nfields = result.nfields
-        self._ntuples = result.ntuples
-
-        for i in range(nf):
-            oid = result.ftype(i)
-            fmt = result.fformat(i)
-            rc.append(self.get_load_function(oid, fmt))
-
-    def set_row_types(self, types: Iterable[Tuple[int, Format]]) -> None:
-        rc = self._row_loaders = []
-        for oid, fmt in types:
-            rc.append(self.get_load_function(oid, fmt))
-
-    def dump_sequence(
-        self, objs: Iterable[Any], formats: Iterable[Format]
-    ) -> Tuple[List[Optional[bytes]], List[int]]:
-        out = []
-        types = []
-
-        for var, fmt in zip(objs, formats):
-            data = self.dump(var, fmt)
-            if isinstance(data, tuple):
-                oid = data[1]
-                data = data[0]
-            else:
-                oid = TEXT_OID
-
-            out.append(data)
-            types.append(oid)
-
-        return out, types
-
-    def dump(self, obj: None, format: Format = Format.TEXT) -> MaybeOid:
-        if obj is None:
-            return None, TEXT_OID
-
-        src = type(obj)
-        func = self.get_dump_function(src, format)
-        return func(obj)
-
-    def get_dump_function(self, src: type, format: Format) -> DumpFunc:
-        key = (src, format)
-        try:
-            return self._dump_funcs[key]
-        except KeyError:
-            pass
-
-        dumper = self.lookup_dumper(src, format)
-        func: DumpFunc
-        if isinstance(dumper, type):
-            func = dumper(src, self).dump
-        else:
-            func = dumper
-
-        self._dump_funcs[key] = func
-        return func
-
-    def lookup_dumper(self, src: type, format: Format) -> DumperType:
-        key = (src, format)
-        for amap in self._dumpers_maps:
-            if key in amap:
-                return amap[key]
-
-        raise e.ProgrammingError(
-            f"cannot adapt type {src.__name__} to format {Format(format).name}"
-        )
-
-    def load_row(self, row: int) -> Optional[Tuple[Any, ...]]:
-        res = self.pgresult
-        if res is None:
-            return None
-
-        if row >= self._ntuples:
-            return None
-
-        rv: List[Any] = []
-        for col in range(self._nfields):
-            val = res.get_value(row, col)
-            if val is None:
-                rv.append(None)
-            else:
-                rv.append(self._row_loaders[col](val))
-
-        return tuple(rv)
-
-    def load_sequence(
-        self, record: Sequence[Optional[bytes]]
-    ) -> Tuple[Any, ...]:
-        return tuple(
-            (self._row_loaders[i](val) if val is not None else None)
-            for i, val in enumerate(record)
-        )
-
-    def load(self, data: bytes, oid: int, format: Format = Format.TEXT) -> Any:
-        if data is not None:
-            f = self.get_load_function(oid, format)
-            return f(data)
-        else:
-            return None
-
-    def get_load_function(self, oid: int, format: Format) -> LoadFunc:
-        key = (oid, format)
-        try:
-            return self._load_funcs[key]
-        except KeyError:
-            pass
-
-        loader = self.lookup_loader(oid, format)
-        func: LoadFunc
-        if isinstance(loader, type):
-            func = loader(oid, self).load
-        else:
-            func = loader
-
-        self._load_funcs[key] = func
-        return func
-
-    def lookup_loader(self, oid: int, format: Format) -> LoaderType:
-        key = (oid, format)
-
-        for tcmap in self._loaders_maps:
-            if key in tcmap:
-                return tcmap[key]
-
-        return Loader.globals[INVALID_OID, format]
-
-
 def _connection_from_context(
     context: AdaptContext,
 ) -> Optional[BaseConnection]:
@@ -385,6 +154,14 @@ def _connection_from_context(
         raise TypeError(f"can't get a connection from {type(context)}")
 
 
+Transformer: Type[proto.Transformer]
+
 # Override it with fast object if available
 if pq.__impl__ == "c":
-    from ._psycopg3 import Transformer  # noqa
+    from . import _psycopg3
+
+    Transformer = _psycopg3.Transformer
+else:
+    from . import transform
+
+    Transformer = transform.Transformer
index e7570062e770952e0f14aff756e9b9c3da597caa..f0556b36eb8d616ff858cf682873bc75ba8ff82b 100644 (file)
@@ -17,11 +17,11 @@ from . import cursor
 from . import generators
 from .conninfo import make_conninfo
 from .waiting import wait, wait_async
+from .utils.typing import DumpersMap, LoadersMap
 
 logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
-    from .adapt import DumpersMap, LoadersMap
     from .generators import PQGen, RV
 
 
index b9dc38b4197576ba296adc33f967f6035691289d..a48171896c3a27db00d13e42dc34627aba4c655b 100644 (file)
@@ -11,11 +11,11 @@ from typing import Any, List, Optional, Sequence, TYPE_CHECKING
 from . import errors as e
 from . import pq
 from . import generators
+from . import proto
 from .utils.queries import PostgresQuery
-from .utils.typing import Query, Params
+from .utils.typing import Query, Params, DumpersMap, LoadersMap
 
 if TYPE_CHECKING:
-    from .adapt import DumpersMap, LoadersMap, Transformer
     from .connection import BaseConnection, Connection, AsyncConnection
 
 
@@ -60,7 +60,7 @@ class Column(Sequence[Any]):
 class BaseCursor:
     ExecStatus = pq.ExecStatus
 
-    _transformer: "Transformer"
+    _transformer: proto.Transformer
 
     def __init__(self, connection: "BaseConnection", binary: bool = False):
         self.connection = connection
index 8b68386055c0bac6dfa1f1cc09492ec9e8817409..7b4389339d20df07fa21087c27a2d46e5f271926 100644 (file)
@@ -1,3 +1,3 @@
 pq_cython.c
-pq_cython.cpython-36m-x86_64-linux-gnu.so
+pq_cython.*.so
 pq_cython.html
index 937a3cc64f706f6cf2a0f2bd961c1bfcf9d1561d..8b090917b94132218ef7158ef2c9bbebb1dfed03 100644 (file)
@@ -6,6 +6,8 @@ libpq Python wrapper using cython bindings.
 
 from cpython.mem cimport PyMem_Malloc, PyMem_Free
 
+from typing import List, Optional, Sequence
+
 from psycopg3.pq cimport libpq as impl
 from psycopg3.pq.libpq cimport Oid
 from psycopg3.errors import OperationalError
index 36ef88b766a0831efaccd074586d85a5d80803de..23bf76615a5b58d1a3f9a33cfede67ac02458ee4 100644 (file)
@@ -1,3 +1,9 @@
+"""
+Protocol objects to represent objects exposed by different pq implementations.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
 from typing import Any, List, Optional, Sequence, TYPE_CHECKING
 from typing_extensions import Protocol
 
diff --git a/psycopg3/proto.py b/psycopg3/proto.py
new file mode 100644 (file)
index 0000000..9db962b
--- /dev/null
@@ -0,0 +1,82 @@
+"""
+Protocol objects representing different implementations of the same classes.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import codecs
+from typing import Any, Iterable, List, Optional, Sequence, Tuple
+from typing import TYPE_CHECKING
+from typing_extensions import Protocol
+
+from .utils.typing import AdaptContext, DumpFunc, DumpersMap, DumperType
+from .utils.typing import LoadFunc, LoadersMap, LoaderType, MaybeOid
+from . import pq
+
+if TYPE_CHECKING:
+    from .connection import BaseConnection  # noqa
+
+Format = pq.Format
+
+
+class Transformer(Protocol):
+    def __init__(self, context: AdaptContext = None):
+        ...
+
+    @property
+    def connection(self) -> Optional["BaseConnection"]:
+        ...
+
+    @property
+    def codec(self) -> codecs.CodecInfo:
+        ...
+
+    @property
+    def pgresult(self) -> Optional["pq.proto.PGresult"]:
+        ...
+
+    @pgresult.setter
+    def pgresult(self, result: Optional["pq.proto.PGresult"]) -> None:
+        ...
+
+    @property
+    def dumpers(self) -> DumpersMap:
+        ...
+
+    @property
+    def loaders(self) -> LoadersMap:
+        ...
+
+    def set_row_types(self, types: Sequence[Tuple[int, Format]]) -> None:
+        ...
+
+    def dump_sequence(
+        self, objs: Iterable[Any], formats: Iterable[Format]
+    ) -> Tuple[List[Optional[bytes]], List[int]]:
+        ...
+
+    def dump(self, obj: None, format: Format = Format.TEXT) -> MaybeOid:
+        ...
+
+    def get_dump_function(self, src: type, format: Format) -> DumpFunc:
+        ...
+
+    def lookup_dumper(self, src: type, format: Format) -> DumperType:
+        ...
+
+    def load_row(self, row: int) -> Optional[Tuple[Any, ...]]:
+        ...
+
+    def load_sequence(
+        self, record: Sequence[Optional[bytes]]
+    ) -> Tuple[Any, ...]:
+        ...
+
+    def load(self, data: bytes, oid: int, format: Format = Format.TEXT) -> Any:
+        ...
+
+    def get_load_function(self, oid: int, format: Format) -> LoadFunc:
+        ...
+
+    def lookup_loader(self, oid: int, format: Format) -> LoaderType:
+        ...
diff --git a/psycopg3/transform.py b/psycopg3/transform.py
new file mode 100644 (file)
index 0000000..d5bb41d
--- /dev/null
@@ -0,0 +1,252 @@
+"""
+Helper object to transform values between Python and PostgreSQL
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import codecs
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
+
+from . import errors as e
+from . import pq
+from .cursor import BaseCursor
+from .types.oids import builtins, INVALID_OID
+from .connection import BaseConnection
+from .utils.typing import AdaptContext, DumpFunc, DumpersMap, DumperType
+from .utils.typing import LoadFunc, LoadersMap, LoaderType, MaybeOid
+
+Format = pq.Format
+TEXT_OID = builtins["text"].oid
+
+
+class Transformer:
+    """
+    An object that can adapt efficiently between Python and PostgreSQL.
+
+    The life cycle of the object is the query, so it is assumed that stuff like
+    the server version or connection encoding will not change. It can have its
+    state so adapting several values of the same type can use optimisations.
+    """
+
+    def __init__(self, context: AdaptContext = None):
+        self._dumpers: DumpersMap
+        self._loaders: LoadersMap
+        self._dumpers_maps: List[DumpersMap] = []
+        self._loaders_maps: List[LoadersMap] = []
+        self._setup_context(context)
+        self.pgresult = None
+
+        # mapping class, fmt -> dump function
+        self._dump_funcs: Dict[Tuple[type, Format], DumpFunc] = {}
+
+        # mapping oid, fmt -> load function
+        self._load_funcs: Dict[Tuple[int, Format], LoadFunc] = {}
+
+        # sequence of load functions from value to python
+        # the length of the result columns
+        self._row_loaders: List[LoadFunc] = []
+
+    def _setup_context(self, context: AdaptContext) -> None:
+        if context is None:
+            self._connection = None
+            self._codec = codecs.lookup("utf8")
+            self._dumpers = {}
+            self._loaders = {}
+            self._dumpers_maps = [self._dumpers]
+            self._loaders_maps = [self._loaders]
+
+        elif isinstance(context, Transformer):
+            # A transformer created from another transformer: usually happens
+            # for nested types: share the entire state of the parent
+            self._connection = context.connection
+            self._codec = context.codec
+            self._dumpers = context.dumpers
+            self._loaders = context.loaders
+            self._dumpers_maps.extend(context._dumpers_maps)
+            self._loaders_maps.extend(context._loaders_maps)
+            # the global maps are already in the lists
+            return
+
+        elif isinstance(context, BaseCursor):
+            self._connection = context.connection
+            self._codec = context.connection.codec
+            self._dumpers = {}
+            self._dumpers_maps.extend(
+                (self._dumpers, context.dumpers, context.connection.dumpers)
+            )
+            self._loaders = {}
+            self._loaders_maps.extend(
+                (self._loaders, context.loaders, context.connection.loaders)
+            )
+
+        elif isinstance(context, BaseConnection):
+            self._connection = context
+            self._codec = context.codec
+            self._dumpers = {}
+            self._dumpers_maps.extend((self._dumpers, context.dumpers))
+            self._loaders = {}
+            self._loaders_maps.extend((self._loaders, context.loaders))
+
+        from .adapt import Dumper, Loader
+
+        self._dumpers_maps.append(Dumper.globals)
+        self._loaders_maps.append(Loader.globals)
+
+    @property
+    def connection(self) -> Optional["BaseConnection"]:
+        return self._connection
+
+    @property
+    def codec(self) -> codecs.CodecInfo:
+        return self._codec
+
+    @property
+    def pgresult(self) -> Optional[pq.proto.PGresult]:
+        return self._pgresult
+
+    @pgresult.setter
+    def pgresult(self, result: Optional[pq.proto.PGresult]) -> None:
+        self._pgresult = result
+        rc = self._row_loaders = []
+
+        self._ntuples: int
+        self._nfields: int
+        if result is None:
+            self._nfields = self._ntuples = 0
+            return
+
+        nf = self._nfields = result.nfields
+        self._ntuples = result.ntuples
+
+        for i in range(nf):
+            oid = result.ftype(i)
+            fmt = result.fformat(i)
+            rc.append(self.get_load_function(oid, fmt))
+
+    @property
+    def dumpers(self) -> DumpersMap:
+        return self._dumpers
+
+    @property
+    def loaders(self) -> LoadersMap:
+        return self._loaders
+
+    def set_row_types(self, types: Iterable[Tuple[int, Format]]) -> None:
+        rc = self._row_loaders = []
+        for oid, fmt in types:
+            rc.append(self.get_load_function(oid, fmt))
+
+    def dump_sequence(
+        self, objs: Iterable[Any], formats: Iterable[Format]
+    ) -> Tuple[List[Optional[bytes]], List[int]]:
+        out = []
+        types = []
+
+        for var, fmt in zip(objs, formats):
+            data = self.dump(var, fmt)
+            if isinstance(data, tuple):
+                oid = data[1]
+                data = data[0]
+            else:
+                oid = TEXT_OID
+
+            out.append(data)
+            types.append(oid)
+
+        return out, types
+
+    def dump(self, obj: None, format: Format = Format.TEXT) -> MaybeOid:
+        if obj is None:
+            return None, TEXT_OID
+
+        src = type(obj)
+        func = self.get_dump_function(src, format)
+        return func(obj)
+
+    def get_dump_function(self, src: type, format: Format) -> DumpFunc:
+        key = (src, format)
+        try:
+            return self._dump_funcs[key]
+        except KeyError:
+            pass
+
+        dumper = self.lookup_dumper(src, format)
+        func: DumpFunc
+        if isinstance(dumper, type):
+            func = dumper(src, self).dump
+        else:
+            func = dumper
+
+        self._dump_funcs[key] = func
+        return func
+
+    def lookup_dumper(self, src: type, format: Format) -> DumperType:
+        key = (src, format)
+        for amap in self._dumpers_maps:
+            if key in amap:
+                return amap[key]
+
+        raise e.ProgrammingError(
+            f"cannot adapt type {src.__name__} to format {Format(format).name}"
+        )
+
+    def load_row(self, row: int) -> Optional[Tuple[Any, ...]]:
+        res = self.pgresult
+        if res is None:
+            return None
+
+        if row >= self._ntuples:
+            return None
+
+        rv: List[Any] = []
+        for col in range(self._nfields):
+            val = res.get_value(row, col)
+            if val is None:
+                rv.append(None)
+            else:
+                rv.append(self._row_loaders[col](val))
+
+        return tuple(rv)
+
+    def load_sequence(
+        self, record: Sequence[Optional[bytes]]
+    ) -> Tuple[Any, ...]:
+        return tuple(
+            (self._row_loaders[i](val) if val is not None else None)
+            for i, val in enumerate(record)
+        )
+
+    def load(self, data: bytes, oid: int, format: Format = Format.TEXT) -> Any:
+        if data is not None:
+            f = self.get_load_function(oid, format)
+            return f(data)
+        else:
+            return None
+
+    def get_load_function(self, oid: int, format: Format) -> LoadFunc:
+        key = (oid, format)
+        try:
+            return self._load_funcs[key]
+        except KeyError:
+            pass
+
+        loader = self.lookup_loader(oid, format)
+        func: LoadFunc
+        if isinstance(loader, type):
+            func = loader(oid, self).load
+        else:
+            func = loader
+
+        self._load_funcs[key] = func
+        return func
+
+    def lookup_loader(self, oid: int, format: Format) -> LoaderType:
+        key = (oid, format)
+
+        for tcmap in self._loaders_maps:
+            if key in tcmap:
+                return tcmap[key]
+
+        from .adapt import Loader
+
+        return Loader.globals[INVALID_OID, format]
index 2c565fdcb07cc9ebe1738aafdc1aef1009c0261a..36fc99031b266b0cce695873b079246b323e587f 100644 (file)
@@ -5,7 +5,7 @@ from cpython.mem cimport PyMem_Malloc, PyMem_Free
 from cpython.tuple cimport PyTuple_New, PyTuple_SET_ITEM
 
 import codecs
-from typing import Any, Dict, Iterable, List, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
 
 from psycopg3.pq cimport libpq
 from psycopg3.pq.pq_cython cimport PGresult
@@ -18,7 +18,6 @@ from psycopg3.pq.enums import Format
 TEXT_OID = 25
 
 
-
 cdef class Transformer:
     """
     An object that can adapt efficiently between Python and PostgreSQL.
@@ -248,7 +247,6 @@ cdef class Transformer:
             Py_INCREF(pyval)
             PyTuple_SET_ITEM(rv, col, pyval)
 
-
         return rv
 
     def load_sequence(
index bfc137b82c0311f3b6b505e34ca49b59d0ee32ad..a7cf7b34c22e37bee8059d4f0e2cfa20891c872a 100644 (file)
@@ -10,8 +10,8 @@ from typing import Any, Generator, List, Optional, Tuple
 
 from .. import errors as e
 from ..adapt import Format, Dumper, Loader, Transformer
-from ..adapt import AdaptContext
 from .oids import builtins
+from ..utils.typing import AdaptContext
 
 TEXT_OID = builtins["text"].oid
 TEXT_ARRAY_OID = builtins["text"].array_oid
index c801e575ec90f717a8d3045d013bec857d6245d2..51cc2111b74b7d7c8ad76b2ee741d96f0b3731e1 100644 (file)
@@ -9,8 +9,9 @@ from typing import Any, Callable, Generator, Sequence, Tuple
 from typing import Optional, TYPE_CHECKING
 
 from . import array
-from ..adapt import Format, Dumper, Loader, Transformer, AdaptContext
+from ..adapt import Format, Dumper, Loader, Transformer
 from .oids import builtins, TypeInfo
+from ..utils.typing import AdaptContext
 
 if TYPE_CHECKING:
     from ..connection import Connection, AsyncConnection
@@ -256,7 +257,9 @@ class CompositeLoader(RecordLoader):
         )
 
     def _config_types(self, data: bytes) -> None:
-        self._tx.set_row_types((oid, Format.TEXT) for oid in self.fields_types)
+        self._tx.set_row_types(
+            [(oid, Format.TEXT) for oid in self.fields_types]
+        )
 
 
 class BinaryCompositeLoader(BinaryRecordLoader):
index 8600536c45681f7507fcf931dbd8cda3f32c944a..f770a812abe00fe430d60db0a0fed87eeed037f0 100644 (file)
@@ -5,13 +5,16 @@ Adapters for textual types.
 # Copyright (C) 2020 The Psycopg Team
 
 import codecs
-from typing import Optional, Tuple, Union
+from typing import Optional, Tuple, Union, TYPE_CHECKING
 
-from ..adapt import Dumper, Loader, AdaptContext
+from ..adapt import Dumper, Loader
 from ..utils.typing import EncodeFunc, DecodeFunc
 from ..pq import Escaping
 from .oids import builtins, INVALID_OID
 
+if TYPE_CHECKING:
+    from ..utils.typing import AdaptContext
+
 TEXT_OID = builtins["text"].oid
 BYTEA_OID = builtins["bytea"].oid
 
@@ -19,7 +22,7 @@ BYTEA_OID = builtins["bytea"].oid
 @Dumper.text(str)
 @Dumper.binary(str)
 class StringDumper(Dumper):
-    def __init__(self, src: type, context: AdaptContext):
+    def __init__(self, src: type, context: "AdaptContext"):
         super().__init__(src, context)
 
         self._encode: EncodeFunc
@@ -44,7 +47,7 @@ class StringLoader(Loader):
 
     decode: Optional[DecodeFunc]
 
-    def __init__(self, oid: int, context: AdaptContext):
+    def __init__(self, oid: int, context: "AdaptContext"):
         super().__init__(oid, context)
 
         if self.connection is not None:
@@ -68,7 +71,7 @@ class StringLoader(Loader):
 @Loader.text(builtins["bpchar"].oid)
 @Loader.binary(builtins["bpchar"].oid)
 class UnknownLoader(Loader):
-    def __init__(self, oid: int, context: AdaptContext):
+    def __init__(self, oid: int, context: "AdaptContext"):
         super().__init__(oid, context)
 
         self.decode: DecodeFunc
@@ -83,7 +86,7 @@ class UnknownLoader(Loader):
 
 @Dumper.text(bytes)
 class BytesDumper(Dumper):
-    def __init__(self, src: type, context: AdaptContext = None):
+    def __init__(self, src: type, context: "AdaptContext" = None):
         super().__init__(src, context)
         self.esc = Escaping(
             self.connection.pgconn if self.connection is not None else None
index 41c7fa88e49af79b5eac9e455237a9cb6ba3912a..c3a9a2356ca0d69f30619ca4eff6259cde819a4b 100644 (file)
@@ -14,7 +14,7 @@ from ..pq import Format
 from .typing import Query, Params
 
 if TYPE_CHECKING:
-    from ..adapt import Transformer
+    from ..proto import Transformer
 
 
 class PostgresQuery:
index f96576ce88be3b98c309c42ec84cd341c91bf66b..f8c41f13b178507ed3512b7e545f779eef323fef 100644 (file)
@@ -4,10 +4,34 @@ Additional types for checking
 
 # Copyright (C) 2020 The Psycopg Team
 
-from typing import Any, Callable, Mapping, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple
+from typing import Type, Union, TYPE_CHECKING
+
+from .. import pq
+
+if TYPE_CHECKING:
+    from ..connection import BaseConnection  # noqa
+    from ..cursor import BaseCursor  # noqa
+    from ..adapt import Dumper, Loader  # noqa
+    from ..proto import Transformer  # noqa
+
+# Part of the module interface (just importing it makes mypy unhappy)
+Format = pq.Format
+
 
 EncodeFunc = Callable[[str], Tuple[bytes, int]]
 DecodeFunc = Callable[[bytes], Tuple[str, int]]
 
 Query = Union[str, bytes]
 Params = Union[Sequence[Any], Mapping[str, Any]]
+
+AdaptContext = Union[None, "BaseConnection", "BaseCursor", "Transformer"]
+
+MaybeOid = Union[Optional[bytes], Tuple[Optional[bytes], int]]
+DumpFunc = Callable[[Any], MaybeOid]
+DumperType = Union[Type["Dumper"], DumpFunc]
+DumpersMap = Dict[Tuple[type, Format], DumperType]
+
+LoadFunc = Callable[[bytes], Any]
+LoaderType = Union[Type["Loader"], LoadFunc]
+LoadersMap = Dict[Tuple[int, Format], LoaderType]