]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Adaptation context reworked
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 26 Dec 2020 00:08:32 +0000 (01:08 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 26 Dec 2020 14:01:58 +0000 (15:01 +0100)
Don't use a sequence of dictionaries, but a single copy-on-write
structure, which is cheaper to explore.

Added a protocol representing all the objects that can be used as
adaptation context in order to avoid cascades of isinstance.

The difference in behaviour is that changing global adapters doesn't
affect already created connections which had been customised, and likely
customised cursors are not affected by global and connection changes.
The utility of the previous behaviour doesn't seem anything on which
people would build an empire on, the new behaviour is good as well.

20 files changed:
psycopg3/psycopg3/_queries.py
psycopg3/psycopg3/_transform.py
psycopg3/psycopg3/adapt.py
psycopg3/psycopg3/connection.py
psycopg3/psycopg3/cursor.py
psycopg3/psycopg3/proto.py
psycopg3/psycopg3/sql.py
psycopg3/psycopg3/types/array.py
psycopg3/psycopg3/types/composite.py
psycopg3/psycopg3/types/date.py
psycopg3/psycopg3/types/network.py
psycopg3/psycopg3/types/range.py
psycopg3/psycopg3/types/text.py
psycopg3/psycopg3/types/uuid.py
psycopg3_c/psycopg3_c/_psycopg3.pyi
psycopg3_c/psycopg3_c/adapt.pyx
psycopg3_c/psycopg3_c/transform.pyx
psycopg3_c/psycopg3_c/types/numeric.pyx
psycopg3_c/psycopg3_c/types/text.pyx
tests/types/test_composite.py

index 5df9237be56516ca0a3aa042c905545d6f6bd91d..ab3b784fd9da7e0b402b130f6b46727700065955 100644 (file)
@@ -34,6 +34,7 @@ class PostgresQuery:
 
     _parts: List[QueryPart]
     _query = b""
+    _encoding: str = "utf-8"
     params: Optional[List[Optional[bytes]]] = None
     # these are tuples so they can be used as keys e.g. in prepared stmts
     types: Tuple[int, ...] = ()
@@ -42,11 +43,11 @@ class PostgresQuery:
 
     def __init__(self, transformer: "Transformer"):
         self._tx = transformer
-        if (
-            self._tx.connection
-            and self._tx.connection.pgconn.server_version < 100000
-        ):
-            self._unknown_oid = TEXT_OID
+        conn = transformer.connection
+        if conn:
+            self._encoding = conn.client_encoding
+            if conn.pgconn.server_version < 100000:
+                self._unknown_oid = TEXT_OID
 
     def convert(self, query: Query, vars: Optional[Params]) -> None:
         """
@@ -60,11 +61,11 @@ class PostgresQuery:
 
         if vars is not None:
             self.query, self.formats, self._order, self._parts = _query2pg(
-                query, self._tx.encoding
+                query, self._encoding
             )
         else:
             if isinstance(query, str):
-                query = query.encode(self._tx.encoding)
+                query = query.encode(self._encoding)
             self.query = query
             self.formats = self._order = None
 
@@ -81,17 +82,15 @@ class PostgresQuery:
                 self._parts, vars, self._order
             )
             assert self.formats is not None
-            ps = self.params = []
-            ts = []
+            ps: List[Optional[bytes]] = [None] * len(params)
+            ts = [self._unknown_oid] * len(params)
             for i in range(len(params)):
                 param = params[i]
                 if param is not None:
                     dumper = self._tx.get_dumper(param, self.formats[i])
-                    ps.append(dumper.dump(param))
-                    ts.append(dumper.oid)
-                else:
-                    ps.append(None)
-                    ts.append(self._unknown_oid)
+                    ps[i] = dumper.dump(param)
+                    ts[i] = dumper.oid
+            self.params = ps
             self.types = tuple(ts)
         else:
             self.params = None
index 945497fa42db1c9d8172a6db2d147097b0f51c8c..55e6c187c8ae787884de04ae6ba65914df09a591 100644 (file)
@@ -10,17 +10,15 @@ from typing import TYPE_CHECKING
 from . import errors as e
 from .pq import Format
 from .oids import INVALID_OID
-from .proto import AdaptContext, DumpersMap
-from .proto import LoadFunc, LoadersMap
-from .cursor import BaseCursor
-from .connection import BaseConnection
+from .proto import LoadFunc, AdaptContext
 
 if TYPE_CHECKING:
     from .pq.proto import PGresult
-    from .adapt import Dumper, Loader
+    from .adapt import Dumper, Loader, AdaptersMap
+    from .connection import BaseConnection
 
 
-class Transformer:
+class Transformer(AdaptContext):
     """
     An object that can adapt efficiently between Python and PostgreSQL.
 
@@ -30,14 +28,20 @@ class Transformer:
     """
 
     __module__ = "psycopg3.adapt"
+    _adapters: "AdaptersMap"
+    _pgresult: Optional["PGresult"] = None
 
-    def __init__(self, context: AdaptContext = None):
-        self._dumpers: DumpersMap
-        self._loaders: LoadersMap
-        self._dumpers_maps: List[DumpersMap] = []
-        self._loaders_maps: List[LoadersMap] = []
-        self._setup_context(context)
-        self._pgresult: Optional["PGresult"] = None
+    def __init__(self, context: Optional[AdaptContext] = None):
+        # WARNING: don't store context, or you'll create a loop with the Cursor
+        if context:
+            self._adapters = context.adapters
+            self._connection = context.connection
+
+        else:
+            from .adapt import global_adapters
+
+            self._adapters = global_adapters
+            self._connection = None
 
         # mapping class, fmt -> Dumper instance
         self._dumpers_cache: Dict[Tuple[type, Format], "Dumper"] = {}
@@ -49,59 +53,13 @@ class Transformer:
         # the length of the result columns
         self._row_loaders: List[LoadFunc] = []
 
-    def _setup_context(self, context: AdaptContext) -> None:
-        if not context:
-            self._connection = None
-            self._encoding = "utf-8"
-            self._dumpers = {}
-            self._loaders = {}
-            self._dumpers_maps = [self._dumpers]
-            self._loaders_maps = [self._loaders]
-
-        elif isinstance(context, Transformer):
-            # A transformer created from a transformers: usually it happens
-            # for nested types: share the entire state of the parent
-            self._connection = context.connection
-            self._encoding = context.encoding
-            self._dumpers = context.dumpers
-            self._loaders = context.loaders
-            self._dumpers_maps.extend(context._dumpers_maps)
-            self._loaders_maps.extend(context._loaders_maps)
-            # the global maps are already in the lists
-            return
-
-        elif isinstance(context, BaseCursor):
-            self._connection = context.connection
-            self._encoding = context.connection.client_encoding
-            self._dumpers = {}
-            self._dumpers_maps.extend(
-                (self._dumpers, context.dumpers, context.connection.dumpers)
-            )
-            self._loaders = {}
-            self._loaders_maps.extend(
-                (self._loaders, context.loaders, context.connection.loaders)
-            )
-
-        elif isinstance(context, BaseConnection):
-            self._connection = context
-            self._encoding = context.client_encoding
-            self._dumpers = {}
-            self._dumpers_maps.extend((self._dumpers, context.dumpers))
-            self._loaders = {}
-            self._loaders_maps.extend((self._loaders, context.loaders))
-
-        from .adapt import Dumper, Loader
-
-        self._dumpers_maps.append(Dumper.globals)
-        self._loaders_maps.append(Loader.globals)
-
     @property
     def connection(self) -> Optional["BaseConnection"]:
         return self._connection
 
     @property
-    def encoding(self) -> str:
-        return self._encoding
+    def adapters(self) -> "AdaptersMap":
+        return self._adapters
 
     @property
     def pgresult(self) -> Optional["PGresult"]:
@@ -126,14 +84,6 @@ class Transformer:
             fmt = result.fformat(i)
             rc.append(self.get_loader(oid, fmt).load)
 
-    @property
-    def dumpers(self) -> DumpersMap:
-        return self._dumpers
-
-    @property
-    def loaders(self) -> LoadersMap:
-        return self._loaders
-
     def set_row_types(self, types: Iterable[Tuple[int, Format]]) -> None:
         rc = self._row_loaders = []
         for oid, fmt in types:
@@ -151,29 +101,25 @@ class Transformer:
         # in contexts from the most specific to the most generic.
         # Also look for superclasses: if you can adapt a type you should be
         # able to adapt its subtypes, otherwise Liskov is sad.
-        for dmap in self._dumpers_maps:
-            for scls in cls.__mro__:
-                dumper_class = dmap.get((scls, format))
-                if not dumper_class:
-                    continue
+        dmap = self._adapters._dumpers
+        for scls in cls.__mro__:
+            dumper_class = dmap.get((scls, format))
+            if not dumper_class:
+                continue
 
-                self._dumpers_cache[cls, format] = dumper = dumper_class(
-                    cls, self
-                )
-                return dumper
+            dumper = self._dumpers_cache[cls, format] = dumper_class(cls, self)
+            return dumper
 
         # If the adapter is not found, look for its name as a string
-        for dmap in self._dumpers_maps:
-            for scls in cls.__mro__:
-                fqn = f"{cls.__module__}.{scls.__qualname__}"
-                dumper_class = dmap.get((fqn, format))
-                if dumper_class is None:
-                    continue
-
-                key = (cls, format)
-                dmap[key] = dumper_class
-                self._dumpers_cache[key] = dumper = dumper_class(cls, self)
-                return dumper
+        for scls in cls.__mro__:
+            fqn = f"{cls.__module__}.{scls.__qualname__}"
+            dumper_class = dmap.get((fqn, format))
+            if dumper_class is None:
+                continue
+
+            dmap[cls, format] = dumper_class
+            dumper = self._dumpers_cache[cls, format] = dumper_class(cls, self)
+            return dumper
 
         raise e.ProgrammingError(
             f"cannot adapt type {type(obj).__name__}"
@@ -211,14 +157,8 @@ class Transformer:
         except KeyError:
             pass
 
-        for tcmap in self._loaders_maps:
-            if key in tcmap:
-                loader_cls = tcmap[key]
-                break
-        else:
-            from .adapt import Loader  # noqa
-
-            loader_cls = Loader.globals[INVALID_OID, format]
-
-        self._loaders_cache[key] = loader = loader_cls(key[0], self)
+        loader_cls = self._adapters._loaders.get(key)
+        if not loader_cls:
+            loader_cls = self._adapters._loaders[INVALID_OID, format]
+        loader = self._loaders_cache[key] = loader_cls(oid, self)
         return loader
index f9a03765bb95409ff00da40dce1522a67ce16296..9b5e89b03d3a8de0f940fca82cfa9fede3abbf0f 100644 (file)
@@ -5,15 +5,16 @@ Entry point into the adaptation system.
 # Copyright (C) 2020 The Psycopg Team
 
 from abc import ABC, abstractmethod
-from typing import Any, cast, Callable, Optional, Type, Union
+from typing import Any, Callable, Optional, Type, TYPE_CHECKING, Union
 
 from . import pq
 from . import proto
 from .pq import Format as Format
 from .oids import TEXT_OID
-from .proto import AdaptContext, DumpersMap, DumperType, LoadersMap, LoaderType
-from .cursor import BaseCursor
-from .connection import BaseConnection
+from .proto import DumpersMap, DumperType, LoadersMap, LoaderType, AdaptContext
+
+if TYPE_CHECKING:
+    from .connection import BaseConnection
 
 
 class Dumper(ABC):
@@ -21,17 +22,16 @@ class Dumper(ABC):
     Convert Python object of the type *src* to PostgreSQL representation.
     """
 
-    globals: DumpersMap = {}
-    connection: Optional[BaseConnection]
+    connection: Optional["BaseConnection"] = None
 
     # A class-wide oid, which will be used by default by instances unless
     # the subclass overrides it in init.
     _oid: int = 0
 
-    def __init__(self, src: type, context: AdaptContext = None):
+    def __init__(self, src: type, context: Optional[AdaptContext] = None):
         self.src = src
-        self.context = context
-        self.connection = connection_from_context(context)
+        self.connection = context.connection if context else None
+
         self.oid = self._oid
         """The oid to pass to the server, if known."""
 
@@ -65,19 +65,14 @@ class Dumper(ABC):
     def register(
         cls,
         src: Union[type, str],
-        context: AdaptContext = None,
+        context: Optional[AdaptContext] = None,
         format: Format = Format.TEXT,
     ) -> None:
         """
         Configure *context* to use this dumper to convert object of type *src*.
         """
-        if not isinstance(src, (str, type)):
-            raise TypeError(
-                f"dumpers should be registered on classes, got {src} instead"
-            )
-
-        where = context.dumpers if context else Dumper.globals
-        where[src, format] = cls
+        adapters = context.adapters if context else global_adapters
+        adapters.register_dumper(src, cls, format=format)
 
     @classmethod
     def text(cls, src: Union[type, str]) -> Callable[[DumperType], DumperType]:
@@ -103,13 +98,11 @@ class Loader(ABC):
     Convert PostgreSQL objects with OID *oid* to Python objects.
     """
 
-    globals: LoadersMap = {}
-    connection: Optional[BaseConnection]
+    connection: Optional["BaseConnection"]
 
-    def __init__(self, oid: int, context: AdaptContext = None):
+    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
         self.oid = oid
-        self.context = context
-        self.connection = connection_from_context(context)
+        self.connection = context.connection if context else None
 
     @abstractmethod
     def load(self, data: bytes) -> Any:
@@ -120,19 +113,14 @@ class Loader(ABC):
     def register(
         cls,
         oid: int,
-        context: AdaptContext = None,
+        context: Optional[AdaptContext] = None,
         format: Format = Format.TEXT,
     ) -> None:
         """
         Configure *context* to use this loader to convert values with OID *oid*.
         """
-        if not isinstance(oid, int):
-            raise TypeError(
-                f"loaders should be registered on oid, got {oid} instead"
-            )
-
-        where = context.loaders if context else Loader.globals
-        where[oid, format] = cls
+        adapters = context.adapters if context else global_adapters
+        adapters.register_loader(oid, cls, format=format)
 
     @classmethod
     def text(cls, oid: int) -> Callable[[LoaderType], LoaderType]:
@@ -151,19 +139,70 @@ class Loader(ABC):
         return binary_
 
 
-def connection_from_context(
-    context: AdaptContext,
-) -> Optional[BaseConnection]:
-    if not context:
-        return None
-    elif isinstance(context, BaseConnection):
-        return context
-    elif isinstance(context, BaseCursor):
-        return cast(BaseConnection, context.connection)
-    elif isinstance(context, Transformer):
-        return context.connection
-    else:
-        raise TypeError(f"can't get a connection from {type(context)}")
+class AdaptersMap:
+    """
+    Map oids to Loaders and types to Dumpers.
+
+    The object can start empty or copy from another object of the same class.
+    Copies are copy-on-write: if the maps are updated make a copy. This way
+    extending e.g. global map by a connection or a connection map from a cursor
+    is cheap: a copy is made only on customisation.
+    """
+
+    _dumpers: DumpersMap
+    _loaders: LoadersMap
+
+    def __init__(self, extend: Optional["AdaptersMap"] = None):
+        if extend:
+            self._dumpers = extend._dumpers
+            self._own_dumpers = False
+            self._loaders = extend._loaders
+            self._own_loaders = False
+        else:
+            self._dumpers = {}
+            self._own_dumpers = True
+            self._loaders = {}
+            self._own_loaders = True
+
+    def register_dumper(
+        self,
+        src: Union[type, str],
+        dumper: Type[Dumper],
+        format: Format = Format.TEXT,
+    ) -> None:
+        """
+        Configure the context to use *dumper* to convert object of type *src*.
+        """
+        if not isinstance(src, (str, type)):
+            raise TypeError(
+                f"dumpers should be registered on classes, got {src} instead"
+            )
+
+        if not self._own_dumpers:
+            self._dumpers = self._dumpers.copy()
+            self._own_dumpers = True
+
+        self._dumpers[src, format] = dumper
+
+    def register_loader(
+        self, oid: int, loader: Type[Loader], format: Format = Format.TEXT
+    ) -> None:
+        """
+        Configure the context to use *loader* to convert data of oid *oid*.
+        """
+        if not isinstance(oid, int):
+            raise TypeError(
+                f"loaders should be registered on oid, got {oid} instead"
+            )
+
+        if not self._own_loaders:
+            self._loaders = self._loaders.copy()
+            self._own_loaders = True
+
+        self._loaders[oid, format] = loader
+
+
+global_adapters = AdaptersMap()
 
 
 Transformer: Type[proto.Transformer]
index ee40d6b7de6a35ed5feb0768a167032930dcda23..319c3179a0b73c9c3e782fd1696da7302d425edb 100644 (file)
@@ -10,7 +10,7 @@ import logging
 import threading
 from types import TracebackType
 from typing import Any, AsyncIterator, Callable, Iterator, List, NamedTuple
-from typing import Optional, Type, TYPE_CHECKING, TypeVar
+from typing import Optional, Type, TYPE_CHECKING
 from weakref import ref, ReferenceType
 from functools import partial
 from contextlib import contextmanager
@@ -21,13 +21,15 @@ else:
     from .utils.context import asynccontextmanager
 
 from . import pq
+from . import adapt
 from . import cursor
 from . import errors as e
 from . import waiting
 from . import encodings
 from .pq import TransactionStatus, ExecStatus, Format
 from .sql import Composable
-from .proto import DumpersMap, LoadersMap, PQGen, PQGenConn, RV, Query, Params
+from .proto import PQGen, PQGenConn, RV, Query, Params, AdaptContext
+from .proto import ConnectionType
 from .conninfo import make_conninfo
 from .generators import notifies
 from .transaction import Transaction, AsyncTransaction
@@ -71,12 +73,11 @@ class Notify(NamedTuple):
 
 Notify.__module__ = "psycopg3"
 
-C = TypeVar("C", bound="BaseConnection")
 NoticeHandler = Callable[[e.Diagnostic], None]
 NotifyHandler = Callable[[Notify], None]
 
 
-class BaseConnection:
+class BaseConnection(AdaptContext):
     """
     Base class for different types of connections.
 
@@ -105,8 +106,7 @@ class BaseConnection:
     def __init__(self, pgconn: "PGconn"):
         self.pgconn = pgconn  # TODO: document this
         self._autocommit = False
-        self.dumpers: DumpersMap = {}
-        self.loaders: LoadersMap = {}
+        self._adapters = adapt.AdaptersMap(adapt.global_adapters)
         self._notice_handlers: List[NoticeHandler] = []
         self._notify_handlers: List[NotifyHandler] = []
 
@@ -115,7 +115,7 @@ class BaseConnection:
         # only a begin/commit and not a savepoint.
         self._savepoints: List[str] = []
 
-        self._prepared = PrepareManager()
+        self._prepared: PrepareManager = PrepareManager()
 
         wself = ref(self)
 
@@ -177,6 +177,15 @@ class BaseConnection:
         if result.status != ExecStatus.TUPLES_OK:
             raise e.error_from_result(result, encoding=self.client_encoding)
 
+    @property
+    def adapters(self) -> adapt.AdaptersMap:
+        return self._adapters
+
+    @property
+    def connection(self) -> "BaseConnection":
+        # implement the AdaptContext protocol
+        return self
+
     def cancel(self) -> None:
         """Cancel the current operation on the connection."""
         c = self.pgconn.get_cancel()
@@ -279,12 +288,12 @@ class BaseConnection:
 
     @classmethod
     def _connect_gen(
-        cls: Type[C],
+        cls: Type[ConnectionType],
         conninfo: str = "",
         *,
         autocommit: bool = False,
         **kwargs: Any,
-    ) -> PQGenConn[C]:
+    ) -> PQGenConn[ConnectionType]:
         """Generator to connect to the database and create a new instance."""
         conninfo = make_conninfo(conninfo, **kwargs)
         pgconn = yield from connect(conninfo)
index b5301d3a27ba9bdb29c83f2571b259442c8e6bdb..08811862c4cece8b22380f9eecdc22e4ba286c81 100644 (file)
@@ -10,11 +10,13 @@ from typing import Any, AsyncIterator, Callable, Generic, Iterator, List
 from typing import Optional, Sequence, Type, TYPE_CHECKING
 from contextlib import contextmanager
 
-from . import errors as e
 from . import pq
+from . import adapt
+from . import errors as e
+
 from .pq import ExecStatus, Format
 from .copy import Copy, AsyncCopy
-from .proto import ConnectionType, Query, Params, DumpersMap, LoadersMap, PQGen
+from .proto import ConnectionType, Query, Params, PQGen
 from ._column import Column
 from ._queries import PostgresQuery
 from ._preparing import Prepare
@@ -55,8 +57,7 @@ class BaseCursor(Generic[ConnectionType]):
     ):
         self._conn = connection
         self.format = format
-        self.dumpers: DumpersMap = {}
-        self.loaders: LoadersMap = {}
+        self._adapters = adapt.AdaptersMap(connection.adapters)
         self._reset()
         self.arraysize = 1
         self._closed = False
@@ -75,6 +76,10 @@ class BaseCursor(Generic[ConnectionType]):
         """The connection this cursor is using."""
         return self._conn
 
+    @property
+    def adapters(self) -> adapt.AdaptersMap:
+        return self._adapters
+
     @property
     def closed(self) -> bool:
         """`True` if the cursor is closed."""
@@ -227,8 +232,6 @@ class BaseCursor(Generic[ConnectionType]):
         It is implemented as generator because it may send additional queries,
         such as `begin`.
         """
-        from . import adapt
-
         if self.closed:
             raise e.InterfaceError("the cursor is closed")
 
index fff6ffee0aed5ae113506f9d369c1223d11f2dec..7be4cca3aae262b2556ae95934e4f514b8b03f89 100644 (file)
@@ -14,8 +14,7 @@ from .pq import Format
 
 if TYPE_CHECKING:
     from .connection import BaseConnection
-    from .cursor import BaseCursor
-    from .adapt import Dumper, Loader
+    from .adapt import Dumper, Loader, AdaptersMap
     from .waiting import Wait, Ready
     from .sql import Composable
 
@@ -43,8 +42,6 @@ Wait states.
 
 # Adaptation types
 
-AdaptContext = Union[None, "BaseConnection", "BaseCursor", "Transformer"]
-
 DumpFunc = Callable[[Any], bytes]
 DumperType = Type["Dumper"]
 DumpersMap = Dict[Tuple[Union[type, str], Format], DumperType]
@@ -57,32 +54,40 @@ LoadersMap = Dict[Tuple[int, Format], LoaderType]
 # as there are both C and a Python implementation
 
 
-class Transformer(Protocol):
-    def __init__(self, context: AdaptContext = None):
+class AdaptContext(Protocol):
+    """
+    A context describing how types are adapted.
+
+    Example of AdaptContext are connections, cursors, transformers.
+    """
+
+    @property
+    def adapters(self) -> "AdaptersMap":
         ...
 
     @property
     def connection(self) -> Optional["BaseConnection"]:
         ...
 
-    @property
-    def encoding(self) -> str:
+
+class Transformer(Protocol):
+    def __init__(self, context: Optional[AdaptContext] = None):
         ...
 
     @property
-    def pgresult(self) -> Optional[pq.proto.PGresult]:
+    def connection(self) -> Optional["BaseConnection"]:
         ...
 
-    @pgresult.setter
-    def pgresult(self, result: Optional[pq.proto.PGresult]) -> None:
+    @property
+    def adapters(self) -> "AdaptersMap":
         ...
 
     @property
-    def dumpers(self) -> DumpersMap:
+    def pgresult(self) -> Optional[pq.proto.PGresult]:
         ...
 
-    @property
-    def loaders(self) -> LoadersMap:
+    @pgresult.setter
+    def pgresult(self, result: Optional[pq.proto.PGresult]) -> None:
         ...
 
     def set_row_types(self, types: Sequence[Tuple[int, Format]]) -> None:
index ba7587ec3fe825b46dc59d695063454eed2035cc..d13a66412db4cd8867ed893ac435bcb91589f93f 100644 (file)
@@ -4,18 +4,16 @@ SQL composition utility module
 
 # Copyright (C) 2020 The Psycopg Team
 
+import codecs
 import string
 from typing import Any, Iterator, List, Optional, Sequence, Union
-from typing import TYPE_CHECKING
 
 from .pq import Escaping, Format
+from .adapt import Transformer
 from .proto import AdaptContext
 
-if TYPE_CHECKING:
-    from .connection import BaseConnection
 
-
-def quote(obj: Any, context: AdaptContext = None) -> str:
+def quote(obj: Any, context: Optional[AdaptContext] = None) -> str:
     """
     Adapt a Python object to a quoted SQL string.
 
@@ -28,11 +26,7 @@ def quote(obj: Any, context: AdaptContext = None) -> str:
     rules used, otherwise only global rules are used.
 
     """
-    from .adapt import connection_from_context
-
-    conn = connection_from_context(context)
-    enc = conn.client_encoding if conn else "utf-8"
-    return Literal(obj).as_bytes(context).decode(enc)
+    return Literal(obj).as_string(context)
 
 
 class Composable(object):
@@ -56,7 +50,7 @@ class Composable(object):
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}({self._obj!r})"
 
-    def as_bytes(self, context: AdaptContext) -> bytes:
+    def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
         """
         Return the value of the object as bytes.
 
@@ -71,7 +65,7 @@ class Composable(object):
         # TODO: add tests and docs for as_bytes
         raise NotImplementedError
 
-    def as_string(self, context: AdaptContext) -> str:
+    def as_string(self, context: Optional[AdaptContext]) -> str:
         """
         Return the value of the object as string.
 
@@ -79,8 +73,14 @@ class Composable(object):
         :type context: `connection` or `cursor`
 
         """
-        conn = _connection_from_context(context)
-        return self.as_bytes(context).decode(conn.client_encoding)
+        conn = context.connection if context else None
+        enc = conn.client_encoding if conn else "utf-8"
+        b = self.as_bytes(context)
+        if isinstance(b, bytes):
+            return b.decode(enc)
+        else:
+            # buffer object
+            return codecs.lookup(enc).decode(b)[0]
 
     def __add__(self, other: "Composable") -> "Composed":
         if isinstance(other, Composed):
@@ -128,7 +128,7 @@ class Composed(Composable):
         ]
         super().__init__(seq)
 
-    def as_bytes(self, context: AdaptContext) -> bytes:
+    def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
         return b"".join(obj.as_bytes(context) for obj in self._obj)
 
     def __iter__(self) -> Iterator[Composable]:
@@ -198,12 +198,16 @@ class SQL(Composable):
         if not isinstance(obj, str):
             raise TypeError(f"SQL values must be strings, got {obj!r} instead")
 
-    def as_string(self, context: AdaptContext) -> str:
+    def as_string(self, context: Optional[AdaptContext]) -> str:
         return self._obj
 
-    def as_bytes(self, context: AdaptContext) -> bytes:
-        conn = _connection_from_context(context)
-        return self._obj.encode(conn.client_encoding)
+    def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+        enc = "utf-8"
+        if context:
+            conn = context.connection
+            if conn:
+                enc = conn.client_encoding
+        return self._obj.encode(enc)
 
     def format(self, *args: Any, **kwargs: Any) -> Composed:
         """
@@ -356,8 +360,10 @@ class Identifier(Composable):
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}({', '.join(map(repr, self._obj))})"
 
-    def as_bytes(self, context: AdaptContext) -> bytes:
-        conn = _connection_from_context(context)
+    def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+        conn = context.connection if context else None
+        if not conn:
+            raise ValueError("a connection is necessary for Identifier")
         esc = Escaping(conn.pgconn)
         enc = conn.client_encoding
         escs = [esc.escape_identifier(s.encode(enc)) for s in self._obj]
@@ -385,10 +391,8 @@ class Literal(Composable):
 
     """
 
-    def as_bytes(self, context: AdaptContext) -> bytes:
-        from .adapt import Transformer
-
-        tx = context if isinstance(context, Transformer) else Transformer()
+    def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+        tx = Transformer(context)
         dumper = tx.get_dumper(self._obj, Format.TEXT)
         return dumper.quote(self._obj)
 
@@ -440,25 +444,16 @@ class Placeholder(Composable):
 
         return f"{self.__class__.__name__}({', '.join(parts)})"
 
-    def as_string(self, context: AdaptContext) -> str:
+    def as_string(self, context: Optional[AdaptContext]) -> str:
         code = "s" if self._format == Format.TEXT else "b"
         return f"%({self._obj}){code}" if self._obj else f"%{code}"
 
-    def as_bytes(self, context: AdaptContext) -> bytes:
-        conn = _connection_from_context(context)
-        return self.as_string(context).encode(conn.client_encoding)
+    def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+        conn = context.connection if context else None
+        enc = conn.client_encoding if conn else "utf-8"
+        return self.as_string(context).encode(enc)
 
 
 # Literals
 NULL = SQL("NULL")
 DEFAULT = SQL("DEFAULT")
-
-
-def _connection_from_context(context: AdaptContext) -> "BaseConnection":
-    from .adapt import connection_from_context
-
-    conn = connection_from_context(context)
-    if not conn:
-        raise ValueError(f"no connection in the context: {context}")
-
-    return conn
index b552675fd9021b960c784d4540a3072ba8970d8f..129c62d9832f9a792a474188c16da32c030d9f9d 100644 (file)
@@ -18,7 +18,7 @@ class BaseListDumper(Dumper):
 
     _oid = TEXT_ARRAY_OID
 
-    def __init__(self, src: type, context: AdaptContext = None):
+    def __init__(self, src: type, context: Optional[AdaptContext] = None):
         super().__init__(src, context)
         self._tx = Transformer(context)
 
@@ -159,7 +159,7 @@ class ListBinaryDumper(BaseListDumper):
 class BaseArrayLoader(Loader):
     base_oid: int
 
-    def __init__(self, oid: int, context: AdaptContext = None):
+    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
         super().__init__(oid, context)
         self._tx = Transformer(context)
 
@@ -263,7 +263,7 @@ class ArrayBinaryLoader(BaseArrayLoader):
 def register(
     array_oid: int,
     base_oid: int,
-    context: AdaptContext = None,
+    context: Optional[AdaptContext] = None,
     name: Optional[str] = None,
 ) -> None:
     if not name:
index 3f26525d0ad57d4d65e4bb5ec0b9458f127bed86..ebebf649e21952eb5a0ee5a6815d91750aef3a19 100644 (file)
@@ -70,7 +70,7 @@ class CompositeInfo(TypeInfo):
 
     def register(
         self,
-        context: AdaptContext = None,
+        context: Optional[AdaptContext] = None,
         factory: Optional[Callable[..., Any]] = None,
     ) -> None:
         if not factory:
@@ -144,7 +144,7 @@ where t.oid = %(name)s::regtype
 
 
 class SequenceDumper(Dumper):
-    def __init__(self, src: type, context: AdaptContext = None):
+    def __init__(self, src: type, context: Optional[AdaptContext] = None):
         super().__init__(src, context)
         self._tx = Transformer(context)
 
@@ -190,7 +190,7 @@ class TupleDumper(SequenceDumper):
 
 
 class BaseCompositeLoader(Loader):
-    def __init__(self, oid: int, context: AdaptContext = None):
+    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
         super().__init__(oid, context)
         self._tx = Transformer(context)
 
index b47f782310b8083939b06bd40577865c3e1f3b80..d41855c802bd03cb1efa338b83482ac6d1fdf678 100644 (file)
@@ -7,7 +7,7 @@ Adapters for date/time types.
 import re
 import sys
 from datetime import date, datetime, time, timedelta
-from typing import cast
+from typing import cast, Optional
 
 from ..oids import builtins
 from ..adapt import Dumper, Loader
@@ -51,7 +51,7 @@ class TimeDeltaDumper(Dumper):
 
     _oid = builtins["interval"].oid
 
-    def __init__(self, src: type, context: AdaptContext = None):
+    def __init__(self, src: type, context: Optional[AdaptContext] = None):
         super().__init__(src, context)
         if self.connection:
             if (
@@ -75,7 +75,7 @@ class TimeDeltaDumper(Dumper):
 
 @Loader.text(builtins["date"].oid)
 class DateLoader(Loader):
-    def __init__(self, oid: int, context: AdaptContext):
+    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
         super().__init__(oid, context)
         self._format = self._format_from_context()
 
@@ -161,7 +161,7 @@ class TimeTzLoader(TimeLoader):
     _format = "%H:%M:%S.%f%z"
     _format_no_micro = _format.replace(".%f", "")
 
-    def __init__(self, oid: int, context: AdaptContext):
+    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
         if sys.version_info < (3, 7):
             setattr(self, "load", self._load_py36)
 
@@ -193,7 +193,7 @@ class TimeTzLoader(TimeLoader):
 
 @Loader.text(builtins["timestamp"].oid)
 class TimestampLoader(DateLoader):
-    def __init__(self, oid: int, context: AdaptContext):
+    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
         super().__init__(oid, context)
         self._format_no_micro = self._format.replace(".%f", "")
 
@@ -245,7 +245,7 @@ class TimestampLoader(DateLoader):
 
 @Loader.text(builtins["timestamptz"].oid)
 class TimestamptzLoader(TimestampLoader):
-    def __init__(self, oid: int, context: AdaptContext):
+    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
         if sys.version_info < (3, 7):
             setattr(self, "load", self._load_py36)
 
@@ -321,7 +321,7 @@ class IntervalLoader(Loader):
         re.VERBOSE,
     )
 
-    def __init__(self, oid: int, context: AdaptContext):
+    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
         super().__init__(oid, context)
         if self.connection:
             ints = self.connection.pgconn.parameter_status(b"IntervalStyle")
index a04e6117a9d28fc79d589d51d2eaee82794e6654..ff5df320b97753ea723e36517d12c574e926fc40 100644 (file)
@@ -4,7 +4,7 @@ Adapters for network types.
 
 # Copyright (C) 2020 The Psycopg Team
 
-from typing import Callable, Union, TYPE_CHECKING
+from typing import Callable, Optional, Union, TYPE_CHECKING
 
 from ..oids import builtins
 from ..adapt import Dumper, Loader
@@ -46,7 +46,7 @@ class NetworkDumper(Dumper):
 
 
 class _LazyIpaddress(Loader):
-    def __init__(self, oid: int, context: AdaptContext = None):
+    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
         super().__init__(oid, context)
         global ip_address, ip_interface, ip_network
         from ipaddress import ip_address, ip_interface, ip_network
index 7b2d396ecf62a7d5d773aa5c44a6dc8de5f62253..c480c688b2ee41492f7aaa820d48b6fae5dcd009 100644 (file)
@@ -399,7 +399,7 @@ class RangeInfo(TypeInfo):
 
     def register(
         self,
-        context: AdaptContext = None,
+        context: Optional[AdaptContext] = None,
         range_class: Optional[Type[Range[Any]]] = None,
     ) -> None:
         if not range_class:
index 1e393523150dd3f57ccc1a8059e40cd33b299758..b215989209559d6b1fbf4a6dd18aeaacfedd8d0d 100644 (file)
@@ -4,7 +4,7 @@ Adapters for textual types.
 
 # Copyright (C) 2020 The Psycopg Team
 
-from typing import Union, TYPE_CHECKING
+from typing import Optional, Union, TYPE_CHECKING
 
 from ..pq import Escaping
 from ..oids import builtins, INVALID_OID
@@ -20,7 +20,7 @@ class _StringDumper(Dumper):
 
     _encoding = "utf-8"
 
-    def __init__(self, src: type, context: AdaptContext):
+    def __init__(self, src: type, context: Optional[AdaptContext] = None):
         super().__init__(src, context)
 
         conn = self.connection
@@ -61,7 +61,7 @@ class TextLoader(Loader):
 
     _encoding = "utf-8"
 
-    def __init__(self, oid: int, context: AdaptContext):
+    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
         super().__init__(oid, context)
         conn = self.connection
         if conn:
@@ -83,7 +83,7 @@ class BytesDumper(Dumper):
 
     _oid = builtins["bytea"].oid
 
-    def __init__(self, src: type, context: AdaptContext = None):
+    def __init__(self, src: type, context: Optional[AdaptContext] = None):
         super().__init__(src, context)
         self._esc = Escaping(
             self.connection.pgconn if self.connection else None
@@ -113,7 +113,7 @@ class BytesBinaryDumper(Dumper):
 class ByteaLoader(Loader):
     _escaping: "EscapingProto"
 
-    def __init__(self, oid: int, context: AdaptContext = None):
+    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
         super().__init__(oid, context)
         if not hasattr(self.__class__, "_escaping"):
             self.__class__._escaping = Escaping()
index 040e51b927c5aff80441266f54425a2fac767b81..819c4b3fe163b797473a818d1af741c3340cb09d 100644 (file)
@@ -4,7 +4,7 @@ Adapters for the UUID type.
 
 # Copyright (C) 2020 The Psycopg Team
 
-from typing import Callable, TYPE_CHECKING
+from typing import Callable, Optional, TYPE_CHECKING
 
 from ..oids import builtins
 from ..adapt import Dumper, Loader
@@ -34,7 +34,7 @@ class UUIDBinaryDumper(UUIDDumper):
 
 @Loader.text(builtins["uuid"].oid)
 class UUIDLoader(Loader):
-    def __init__(self, oid: int, context: AdaptContext = None):
+    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
         super().__init__(oid, context)
         global UUID
         from uuid import UUID
index d765fa41cb8e61daf306480b907d18054d02152d..87bf301034a868803a62b85fa0b4968e031bc649 100644 (file)
@@ -9,38 +9,32 @@ information. Will submit a bug.
 
 from typing import Any, Iterable, List, Optional, Sequence, Tuple
 
-from psycopg3.adapt import Dumper, Loader
-from psycopg3.proto import AdaptContext, DumpFunc, DumpersMap, DumperType
-from psycopg3.proto import LoadFunc, LoadersMap, LoaderType, PQGen, PQGenConn
+from psycopg3.adapt import Dumper, Loader, AdaptersMap
+from psycopg3.proto import AdaptContext, PQGen, PQGenConn
 from psycopg3.connection import BaseConnection
-from psycopg3 import pq
+from psycopg3.pq import Format
+from psycopg3.pq.proto import PGconn, PGresult
 
-class Transformer:
-    def __init__(self, context: AdaptContext = None): ...
+class Transformer(AdaptContext):
+    def __init__(self, context: Optional[AdaptContext] = None): ...
     @property
     def connection(self) -> Optional[BaseConnection]: ...
     @property
-    def encoding(self) -> str: ...
+    def adapters(self) -> AdaptersMap: ...
     @property
-    def dumpers(self) -> DumpersMap: ...
-    @property
-    def loaders(self) -> LoadersMap: ...
-    @property
-    def pgresult(self) -> Optional[pq.proto.PGresult]: ...
+    def pgresult(self) -> Optional[PGresult]: ...
     @pgresult.setter
-    def pgresult(self, result: Optional[pq.proto.PGresult]) -> None: ...
-    def set_row_types(
-        self, types: Sequence[Tuple[int, pq.Format]]
-    ) -> None: ...
-    def get_dumper(self, obj: Any, format: pq.Format) -> Dumper: ...
+    def pgresult(self, result: Optional[PGresult]) -> None: ...
+    def set_row_types(self, types: Sequence[Tuple[int, Format]]) -> None: ...
+    def get_dumper(self, obj: Any, format: Format) -> Dumper: ...
     def load_row(self, row: int) -> Optional[Tuple[Any, ...]]: ...
     def load_sequence(
         self, record: Sequence[Optional[bytes]]
     ) -> Tuple[Any, ...]: ...
-    def get_loader(self, oid: int, format: pq.Format) -> Loader: ...
+    def get_loader(self, oid: int, format: Format) -> Loader: ...
 
 def register_builtin_c_adapters() -> None: ...
-def connect(conninfo: str) -> PQGenConn[pq.proto.PGconn]: ...
-def execute(pgconn: pq.proto.PGconn) -> PQGen[List[pq.proto.PGresult]]: ...
+def connect(conninfo: str) -> PQGenConn[PGconn]: ...
+def execute(pgconn: PGconn) -> PQGen[List[PGresult]]: ...
 
 # vim: set syntax=python:
index a4afe3e7b9a65e62d66d2655f87c4cca6e0a9431..5f644375c37a5e05a6b63841e26421934c317867 100644 (file)
@@ -34,15 +34,13 @@ logger = logging.getLogger("psycopg3.adapt")
 
 cdef class CDumper:
     cdef object src
-    cdef public object context
-    cdef public object connection
     cdef public libpq.Oid oid
+    cdef readonly object connection
     cdef PGconn _pgconn
 
-    def __init__(self, src: type, context: AdaptContext = None):
+    def __init__(self, src: type, context: Optional["AdaptContext"] = None):
         self.src = src
-        self.context = context
-        self.connection = _connection_from_context(context)
+        self.connection = context.connection if context is not None else None
         self._pgconn = (
             self.connection.pgconn if self.connection is not None else None
         )
@@ -82,7 +80,7 @@ cdef class CDumper:
             )
             if error:
                 raise e.OperationalError(
-                    f"escape_string failed: {error_message(self.connection)}"
+                    f"escape_string failed: {error_message(self._pgconn)}"
                 )
         else:
             len_out = libpq.PQescapeString(ptr_out + 1, ptr, length)
@@ -97,28 +95,24 @@ cdef class CDumper:
     def register(
         cls,
         src: Union[type, str],
-        context: AdaptContext = None,
+        context: Optional[AdaptContext] = None,
         format: Format = Format.TEXT,
     ) -> None:
-        if not isinstance(src, (str, type)):
-            raise TypeError(
-                f"dumpers should be registered on classes, got {src} instead"
-            )
-        from psycopg3.adapt import Dumper
+        if context is not None:
+            adapters = context.adapters
+        else:
+            from psycopg3.adapt import global_adapters as adapters
 
-        where = context.dumpers if context else Dumper.globals
-        where[src, format] = cls
+        adapters.register_dumper(src, cls, format=format)
 
 
 cdef class CLoader:
     cdef public libpq.Oid oid
-    cdef public object context
-    cdef public object connection
+    cdef public connection
 
-    def __init__(self, oid: int, context: "AdaptContext" = None):
+    def __init__(self, oid: int, context: Optional["AdaptContext"] = None):
         self.oid = oid
-        self.context = context
-        self.connection = _connection_from_context(context)
+        self.connection = context.connection if context is not None else None
 
     cdef object cload(self, const char *data, size_t length):
         raise NotImplementedError()
@@ -133,23 +127,15 @@ cdef class CLoader:
     def register(
         cls,
         oid: int,
-        context: "AdaptContext" = None,
+        context: Optional["AdaptContext"] = None,
         format: Format = Format.TEXT,
     ) -> None:
-        if not isinstance(oid, int):
-            raise TypeError(
-                f"loaders should be registered on oid, got {oid} instead"
-            )
-
-        from psycopg3.adapt import Loader
-
-        where = context.loaders if context else Loader.globals
-        where[oid, format] = cls
-
+        if context is not None:
+            adapters = context.adapters
+        else:
+            from psycopg3.adapt import global_adapters as adapters
 
-cdef _connection_from_context(object context):
-    from psycopg3.adapt import connection_from_context
-    return connection_from_context(context)
+        adapters.register_loader(oid, cls, format=format)
 
 
 def register_builtin_c_adapters():
index a32e542f31bd7216561e8e7afde121407377a683..d1947edcd05e95311021d3aac6cdde5ebdcdc8ff 100644 (file)
@@ -35,20 +35,21 @@ cdef class Transformer:
     state so adapting several values of the same type can use optimisations.
     """
 
-    cdef readonly dict dumpers, loaders
     cdef readonly object connection
-    cdef readonly str encoding
-    cdef list _dumpers_maps, _loaders_maps
+    cdef readonly object adapters
     cdef dict _dumpers_cache, _loaders_cache
     cdef PGresult _pgresult
     cdef int _nfields, _ntuples
-
     cdef list _row_loaders
 
-    def __cinit__(self, context: "AdaptContext" = None):
-        self._dumpers_maps: List["DumpersMap"] = []
-        self._loaders_maps: List["LoadersMap"] = []
-        self._setup_context(context)
+    def __cinit__(self, context: Optional["AdaptContext"] = None):
+        if context is not None:
+            self.adapters = context.adapters
+            self.connection = context.connection
+        else:
+            from psycopg3.adapt import global_adapters
+            self.adapters = global_adapters
+            self.connection = None
 
         # mapping class, fmt -> Dumper instance
         self._dumpers_cache: Dict[Tuple[type, Format], "Dumper"] = {}
@@ -59,56 +60,6 @@ cdef class Transformer:
         self.pgresult = None
         self._row_loaders = []
 
-    def _setup_context(self, context: "AdaptContext") -> None:
-        from psycopg3.adapt import Dumper, Loader
-        from psycopg3.cursor import BaseCursor
-        from psycopg3.connection import BaseConnection
-
-        cdef Transformer ctx
-        if context is None:
-            self.connection = None
-            self.encoding = "utf-8"
-            self.dumpers = {}
-            self.loaders = {}
-            self._dumpers_maps = [self.dumpers]
-            self._loaders_maps = [self.loaders]
-
-        elif isinstance(context, Transformer):
-            # A transformer created from a transformers: usually it happens
-            # for nested types: share the entire state of the parent
-            ctx = context
-            self.connection = ctx.connection
-            self.encoding = ctx.encoding
-            self.dumpers = ctx.dumpers
-            self.loaders = ctx.loaders
-            self._dumpers_maps.extend(ctx._dumpers_maps)
-            self._loaders_maps.extend(ctx._loaders_maps)
-            # the global maps are already in the lists
-            return
-
-        elif isinstance(context, BaseCursor):
-            self.connection = context.connection
-            self.encoding = context.connection.client_encoding
-            self.dumpers = {}
-            self._dumpers_maps.extend(
-                (self.dumpers, context.dumpers, self.connection.dumpers)
-            )
-            self.loaders = {}
-            self._loaders_maps.extend(
-                (self.loaders, context.loaders, self.connection.loaders)
-            )
-
-        elif isinstance(context, BaseConnection):
-            self.connection = context
-            self.encoding = context.client_encoding
-            self.dumpers = {}
-            self._dumpers_maps.extend((self.dumpers, context.dumpers))
-            self.loaders = {}
-            self._loaders_maps.extend((self.loaders, context.loaders))
-
-        self._dumpers_maps.append(Dumper.globals)
-        self._loaders_maps.append(Loader.globals)
-
     @property
     def pgresult(self) -> Optional[PGresult]:
         return self._pgresult
@@ -170,27 +121,25 @@ cdef class Transformer:
         # in contexts from the most specific to the most generic.
         # Also look for superclasses: if you can adapt a type you should be
         # able to adapt its subtypes, otherwise Liskov is sad.
-        for dmap in self._dumpers_maps:
-            for scls in cls.__mro__:
-                dumper_class = dmap.get((scls, format))
-                if not dumper_class:
-                    continue
+        cdef dict dmap = self.adapters._dumpers
+        for scls in cls.__mro__:
+            dumper_class = dmap.get((scls, format))
+            if not dumper_class:
+                continue
 
-                self._dumpers_cache[cls, format] = dumper = dumper_class(cls, self)
-                return dumper
+            self._dumpers_cache[cls, format] = dumper = dumper_class(cls, self)
+            return dumper
 
         # If the adapter is not found, look for its name as a string
-        for dmap in self._dumpers_maps:
-            for scls in cls.__mro__:
-                fqn = f"{cls.__module__}.{scls.__qualname__}"
-                dumper_class = dmap.get((fqn, format))
-                if dumper_class is None:
-                    continue
+        for scls in cls.__mro__:
+            fqn = f"{cls.__module__}.{scls.__qualname__}"
+            dumper_class = dmap.get((fqn, format))
+            if dumper_class is None:
+                continue
 
-                key = (cls, format)
-                dmap[key] = dumper_class
-                self._dumpers_cache[key] = dumper = dumper_class(cls, self)
-                return dumper
+            dmap[cls, format] = dumper_class
+            self._dumpers_cache[cls, format] = dumper = dumper_class(cls, self)
+            return dumper
 
         raise e.ProgrammingError(
             f"cannot adapt type {type(obj).__name__}"
@@ -256,13 +205,8 @@ cdef class Transformer:
         except KeyError:
             pass
 
-        for tcmap in self._loaders_maps:
-            if key in tcmap:
-                loader_cls = tcmap[key]
-                break
-        else:
-            from psycopg3.adapt import Loader
-            loader_cls = Loader.globals[oids.INVALID_OID, format]
-
-        self._loaders_cache[key] = loader = loader_cls(key[0], self)
+        loader_cls = self.adapters._loaders.get(key)
+        if loader_cls is None:
+            loader_cls = self.adapters._loaders[oids.INVALID_OID, format]
+        loader = self._loaders_cache[key] = loader_cls(oid, self)
         return loader
index 80947bc4ce3b09dfb62fc3e7f3368c7c33e11050..075ad91ef884044c1b0959e1523bd3de6711743e 100644 (file)
@@ -25,7 +25,7 @@ cdef class IntDumper(CDumper):
     def __cinit__(self):
         self.oid = oids.INT8_OID
 
-    def __init__(self, src: type, context: AdaptContext = None):
+    def __init__(self, src: type, context: Optional[AdaptContext] = None):
         super().__init__(src, context)
 
     def dump(self, obj) -> bytes:
index 1c48d80ff1b15f83d7285a33582809b18669be8d..a7d7c019ed2768f969ae0bcd5b03fc3daf7bce76 100644 (file)
@@ -18,7 +18,7 @@ cdef class _StringDumper(CDumper):
     cdef char *encoding
     cdef bytes _bytes_encoding  # needed to keep `encoding` alive
 
-    def __init__(self, src: type, context: AdaptContext):
+    def __init__(self, src: type, context: Optional[AdaptContext]):
         super().__init__(src, context)
 
         self.is_utf8 = 0
@@ -72,7 +72,7 @@ cdef class TextLoader(CLoader):
     cdef char *encoding
     cdef bytes _bytes_encoding  # needed to keep `encoding` alive
 
-    def __init__(self, oid: int, context: "AdaptContext" = None):
+    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
         super().__init__(oid, context)
 
         self.is_utf8 = 0
@@ -102,7 +102,7 @@ cdef class BytesDumper(CDumper):
     def __cinit__(self):
         self.oid = oids.BYTEA_OID
 
-    def __init__(self, src: type, context: AdaptContext):
+    def __init__(self, src: type, context: Optional[AdaptContext] = None):
         super().__init__(src, context)
         self.esc = Escaping(self._pgconn)
 
index 94721eda65d7869d230416d1b8c00f22dbc75ee6..dc0bf156ece1a6d17f8a30e7f2134fdcf8b99e55 100644 (file)
@@ -2,7 +2,7 @@ import pytest
 
 from psycopg3.sql import Identifier
 from psycopg3.oids import builtins
-from psycopg3.adapt import Format, Loader
+from psycopg3.adapt import Format, global_adapters
 from psycopg3.types.composite import CompositeInfo
 
 
@@ -40,7 +40,7 @@ def test_dump_tuple(conn, rec, obj):
     info = CompositeInfo.fetch(conn, "tmptype")
     info.register(context=conn)
 
-    res = cur.execute("select %s::tmptype", [obj]).fetchone()[0]
+    res = conn.execute("select %s::tmptype", [obj]).fetchone()[0]
     assert res == obj
 
 
@@ -169,10 +169,10 @@ def test_dump_composite_all_chars(conn, fmt_in, testcomp):
 
 @pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
 def test_load_composite(conn, testcomp, fmt_out):
-    cur = conn.cursor(format=fmt_out)
     info = CompositeInfo.fetch(conn, "testcomp")
     info.register(conn)
 
+    cur = conn.cursor(format=fmt_out)
     res = cur.execute("select row('hello', 10, 20)::testcomp").fetchone()[0]
     assert res.foo == "hello"
     assert res.bar == 10
@@ -189,7 +189,6 @@ def test_load_composite(conn, testcomp, fmt_out):
 
 @pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
 def test_load_composite_factory(conn, testcomp, fmt_out):
-    cur = conn.cursor(format=fmt_out)
     info = CompositeInfo.fetch(conn, "testcomp")
 
     class MyThing:
@@ -198,6 +197,7 @@ def test_load_composite_factory(conn, testcomp, fmt_out):
 
     info.register(conn, factory=MyThing)
 
+    cur = conn.cursor(format=fmt_out)
     res = cur.execute("select row('hello', 10, 20)::testcomp").fetchone()[0]
     assert isinstance(res, MyThing)
     assert res.baz == 20.0
@@ -216,20 +216,20 @@ def test_register_scope(conn):
     info.register()
     for fmt in (Format.TEXT, Format.BINARY):
         for oid in (info.oid, info.array_oid):
-            assert Loader.globals.pop((oid, fmt))
+            assert global_adapters._loaders.pop((oid, fmt))
 
     cur = conn.cursor()
     info.register(cur)
     for fmt in (Format.TEXT, Format.BINARY):
         for oid in (info.oid, info.array_oid):
             key = oid, fmt
-            assert key not in Loader.globals
-            assert key not in conn.loaders
-            assert key in cur.loaders
+            assert key not in global_adapters._loaders
+            assert key not in conn.adapters._loaders
+            assert key in cur.adapters._loaders
 
     info.register(conn)
     for fmt in (Format.TEXT, Format.BINARY):
         for oid in (info.oid, info.array_oid):
             key = oid, fmt
-            assert key not in Loader.globals
-            assert key in conn.loaders
+            assert key not in global_adapters._loaders
+            assert key in conn.adapters._loaders