]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
The Transformer is an adaptation context
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 7 Apr 2020 10:24:09 +0000 (22:24 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 7 Apr 2020 10:24:09 +0000 (22:24 +1200)
Dropped cursor from Adapter, TypeCaster, Transformer. Added adapters and
casters map on Trasformer.

Solves a problem of customization in composite types: if a type is
customized now composite types containing that type use the same
customization. Also fixed a circular reference between transformer and
cursor.

psycopg3/adapt.py
psycopg3/cursor.py

index 8eef3587bb02f1ded4902ce4d4db94c22388c301..37d6cc17676a24e1bbbf14e8e94e7a8f70664234 100644 (file)
@@ -22,7 +22,7 @@ Format = pq.Format
 
 # Type system
 
-AdaptContext = Union[None, BaseConnection, BaseCursor]
+AdaptContext = Union[None, BaseConnection, BaseCursor, "Transformer"]
 
 MaybeOid = Union[Optional[bytes], Tuple[Optional[bytes], int]]
 AdapterFunc = Callable[[Any], MaybeOid]
@@ -37,12 +37,11 @@ TypeCastersMap = Dict[Tuple[int, Format], TypeCasterType]
 class Adapter:
     globals: AdaptersMap = {}
     connection: Optional[BaseConnection]
-    cursor: Optional[BaseCursor]
 
     def __init__(self, src: type, context: AdaptContext = None):
         self.src = src
         self.context = context
-        self.connection, self.cursor = _solve_context(context)
+        self.connection = _connection_from_context(context)
 
     def adapt(self, obj: Any) -> Union[bytes, Tuple[bytes, int]]:
         raise NotImplementedError()
@@ -60,14 +59,6 @@ class Adapter:
                 f"adapters should be registered on classes, got {src} instead"
             )
 
-        if context is not None and not isinstance(
-            context, (BaseConnection, BaseCursor)
-        ):
-            raise TypeError(
-                f"the context should be a connection or cursor,"
-                f" got {type(context)}"
-            )
-
         if not (
             callable(adapter)
             or (isinstance(adapter, type) and issubclass(adapter, Adapter))
@@ -107,12 +98,11 @@ class Adapter:
 class TypeCaster:
     globals: TypeCastersMap = {}
     connection: Optional[BaseConnection]
-    cursor: Optional[BaseCursor]
 
     def __init__(self, oid: int, context: AdaptContext = None):
         self.oid = oid
         self.context = context
-        self.connection, self.cursor = _solve_context(context)
+        self.connection = _connection_from_context(context)
 
     def cast(self, data: bytes) -> Any:
         raise NotImplementedError()
@@ -130,14 +120,6 @@ class TypeCaster:
                 f"typecasters should be registered on oid, got {oid} instead"
             )
 
-        if context is not None and not isinstance(
-            context, (BaseConnection, BaseCursor)
-        ):
-            raise TypeError(
-                f"the context should be a connection or cursor,"
-                f" got {type(context)}"
-            )
-
         if not (
             callable(caster)
             or (isinstance(caster, type) and issubclass(caster, TypeCaster))
@@ -183,11 +165,13 @@ class Transformer:
     state so adapting several values of the same type can use optimisations.
     """
 
-    connection: Optional[BaseConnection]
-    cursor: Optional[BaseCursor]
-
     def __init__(self, context: AdaptContext = None):
-        self.connection, self.cursor = _solve_context(context)
+        self.connection: Optional[BaseConnection]
+        self.adapters: AdaptersMap
+        self.casters: TypeCastersMap
+        self._adapters_maps: List[AdaptersMap] = []
+        self._casters_maps: List[TypeCastersMap] = []
+        self._setup_context(context)
 
         # mapping class, fmt -> adaptation function
         self._adapt_funcs: Dict[Tuple[type, Format], AdapterFunc] = {}
@@ -199,6 +183,46 @@ class Transformer:
         # the length of the result columns
         self._row_casters: List[TypeCasterFunc] = []
 
+    def _setup_context(self, context: AdaptContext) -> None:
+        if context is None:
+            self.connection = None
+            self.adapters = {}
+            self.casters = {}
+            self._adapters_maps = [self.adapters]
+            self._casters_maps = [self.casters]
+
+        elif isinstance(context, Transformer):
+            # A transformer created from a transformers: usually it happens
+            # for nested types: share the entire state of the parent
+            self.connection = context.connection
+            self.adapters = context.adapters
+            self.casters = context.casters
+            self._adapters_maps.extend(context._adapters_maps)
+            self._casters_maps.extend(context._casters_maps)
+            # the global maps are already in the lists
+            return
+
+        elif isinstance(context, BaseCursor):
+            self.connection = context.conn
+            self.adapters = {}
+            self._adapters_maps.extend(
+                (self.adapters, context.adapters, self.connection.adapters)
+            )
+            self.casters = {}
+            self._casters_maps.extend(
+                (self.casters, context.casters, self.connection.casters)
+            )
+
+        elif isinstance(context, BaseConnection):
+            self.connection = context
+            self.adapters = {}
+            self._adapters_maps.extend((self.adapters, context.adapters))
+            self.casters = {}
+            self._casters_maps.extend((self.casters, context.casters))
+
+        self._adapters_maps.append(Adapter.globals)
+        self._casters_maps.append(TypeCaster.globals)
+
     def adapt_sequence(
         self, objs: Iterable[Any], formats: Iterable[Format]
     ) -> Tuple[List[Optional[bytes]], List[int]]:
@@ -245,20 +269,12 @@ class Transformer:
 
     def lookup_adapter(self, src: type, format: Format) -> AdapterType:
         key = (src, format)
-
-        cur = self.cursor
-        if cur is not None and key in cur.adapters:
-            return cur.adapters[key]
-
-        conn = self.connection
-        if conn is not None and key in conn.adapters:
-            return conn.adapters[key]
-
-        if key in Adapter.globals:
-            return Adapter.globals[key]
+        for amap in self._adapters_maps:
+            if key in amap:
+                return amap[key]
 
         raise e.ProgrammingError(
-            f"cannot adapt type {src} to format {Format(format).name}"
+            f"cannot adapt type {src.__name__} to format {Format(format).name}"
         )
 
     def set_row_types(self, types: Iterable[Tuple[int, Format]]) -> None:
@@ -302,16 +318,9 @@ class Transformer:
     def lookup_caster(self, oid: int, format: Format) -> TypeCasterType:
         key = (oid, format)
 
-        cur = self.cursor
-        if cur is not None and key in cur.casters:
-            return cur.casters[key]
-
-        conn = self.connection
-        if conn is not None and key in conn.casters:
-            return conn.casters[key]
-
-        if key in TypeCaster.globals:
-            return TypeCaster.globals[key]
+        for tcmap in self._casters_maps:
+            if key in tcmap:
+                return tcmap[key]
 
         return TypeCaster.globals[INVALID_OID, format]
 
@@ -339,17 +348,16 @@ def cast_unknown(data: bytes) -> bytes:
     return data
 
 
-def _solve_context(
+def _connection_from_context(
     context: AdaptContext,
-) -> Tuple[Optional[BaseConnection], Optional[BaseCursor]]:
+) -> Optional[BaseConnection]:
     if context is None:
-        return None, None
+        return None
     elif isinstance(context, BaseConnection):
-        return context, None
+        return context
     elif isinstance(context, BaseCursor):
-        return context.conn, context
+        return context.conn
+    elif isinstance(context, Transformer):
+        return context.connection
     else:
-        raise TypeError(
-            f"the context should be a connection or cursor,"
-            f" got {type(context)}"
-        )
+        raise TypeError(f"can't get a connection from {type(context)}")
index deb212059a082ef0f16d9e69618fbc4b54e5943d..6c809db73e5c4c097a37d779a9e21f8fc7eb7b42 100644 (file)
@@ -32,7 +32,7 @@ class BaseCursor:
     def _reset(self) -> None:
         from .adapt import Transformer
 
-        self._transformer = Transformer(self)  # TODO: circular reference
+        self._transformer = Transformer(self)
         self._results: List[PGresult] = []
         self.pgresult: Optional[PGresult] = None
         self._pos = 0