]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Adaptation context transferred to the adapters objects
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 3 Apr 2020 14:23:33 +0000 (03:23 +1300)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 3 Apr 2020 14:23:33 +0000 (03:23 +1300)
psycopg3/adapt.py
psycopg3/types/array.py
psycopg3/types/text.py

index d0f3b9d95bcac4e33fb2db68b5d3da7022542cab..bfe24251ccf78b8cdb4528729a40569d60237b52 100644 (file)
@@ -34,10 +34,13 @@ TypeCastersMap = Dict[Tuple[int, Format], TypeCasterType]
 
 class Adapter:
     globals: AdaptersMap = {}
+    connection: Optional[BaseConnection]
+    cursor: Optional[BaseCursor]
 
-    def __init__(self, src: type, conn: Optional[BaseConnection]):
+    def __init__(self, src: type, context: AdaptContext = None):
         self.src = src
-        self.conn = conn
+        self.context = context
+        self.connection, self.cursor = _solve_context(context)
 
     def adapt(self, obj: Any) -> Union[bytes, Tuple[bytes, int]]:
         raise NotImplementedError()
@@ -101,10 +104,13 @@ class Adapter:
 
 class TypeCaster:
     globals: TypeCastersMap = {}
+    connection: Optional[BaseConnection]
+    cursor: Optional[BaseCursor]
 
-    def __init__(self, oid: int, conn: Optional[BaseConnection]):
+    def __init__(self, oid: int, context: AdaptContext = None):
         self.oid = oid
-        self.conn = conn
+        self.context = context
+        self.connection, self.cursor = _solve_context(context)
 
     def cast(self, data: bytes) -> Any:
         raise NotImplementedError()
@@ -179,20 +185,7 @@ class Transformer:
     cursor: Optional[BaseCursor]
 
     def __init__(self, context: AdaptContext = None):
-        if context is None:
-            self.connection = None
-            self.cursor = None
-        elif isinstance(context, BaseConnection):
-            self.connection = context
-            self.cursor = None
-        elif isinstance(context, BaseCursor):
-            self.connection = context.conn
-            self.cursor = context
-        else:
-            raise TypeError(
-                f"the context should be a connection or cursor,"
-                f" got {type(context)}"
-            )
+        self.connection, self.cursor = _solve_context(context)
 
         # mapping class, fmt -> adaptation function
         self._adapt_funcs: Dict[Tuple[type, Format], AdapterFunc] = {}
@@ -333,11 +326,11 @@ class UnknownCaster(TypeCaster):
     Fallback object to convert unknown types to Python
     """
 
-    def __init__(self, oid: int, conn: Optional[BaseConnection]):
-        super().__init__(oid, conn)
+    def __init__(self, oid: int, context: AdaptContext):
+        super().__init__(oid, context)
         self.decode: DecodeFunc
-        if conn is not None:
-            self.decode = conn.codec.decode
+        if self.connection is not None:
+            self.decode = self.connection.codec.decode
         else:
             self.decode = codecs.lookup("utf8").decode
 
@@ -348,3 +341,19 @@ class UnknownCaster(TypeCaster):
 @TypeCaster.binary(INVALID_OID)
 def cast_unknown(data: bytes) -> bytes:
     return data
+
+
+def _solve_context(
+    context: AdaptContext,
+) -> Tuple[Optional[BaseConnection], Optional[BaseCursor]]:
+    if context is None:
+        return None, None
+    elif isinstance(context, BaseConnection):
+        return context, None
+    elif isinstance(context, BaseCursor):
+        return context.conn, context
+    else:
+        raise TypeError(
+            f"the context should be a connection or cursor,"
+            f" got {type(context)}"
+        )
index c0373a4bc1eced75e78b838bc1172cff00320694..df7ec02cdf4637fc873bf0e5dd2e040a59cc9530 100644 (file)
@@ -5,16 +5,13 @@ Adapters for arrays
 # Copyright (C) 2020 The Psycopg Team
 
 import re
-from typing import Any, List, Optional, TYPE_CHECKING
+from typing import Any, List, Optional
 
 from .. import errors as e
 from ..pq import Format
 from ..adapt import Adapter, TypeCaster, Transformer, UnknownCaster
 from ..adapt import AdaptContext, TypeCasterType, TypeCasterFunc
 
-if TYPE_CHECKING:
-    from ..connection import BaseConnection
-
 
 # from https://www.postgresql.org/docs/current/arrays.html#ARRAYS-IO
 #
@@ -59,9 +56,9 @@ def escape_item(item: Optional[bytes]) -> bytes:
 
 @Adapter.text(list)
 class ListAdapter(Adapter):
-    def __init__(self, cls: type, conn: "BaseConnection"):
-        super().__init__(cls, conn)
-        self.tx = Transformer(conn)
+    def __init__(self, cls: type, context: AdaptContext = None):
+        super().__init__(cls, context)
+        self.tx = Transformer(context)
 
     def adapt(self, obj: List[Any]) -> bytes:
         tokens: List[bytes] = []
@@ -93,14 +90,12 @@ class ListAdapter(Adapter):
 class ArrayCasterBase(TypeCaster):
     base_caster: TypeCasterType
 
-    def __init__(
-        self, oid: int, conn: Optional["BaseConnection"],
-    ):
-        super().__init__(oid, conn)
+    def __init__(self, oid: int, context: AdaptContext = None):
+        super().__init__(oid, context)
         self.caster_func = TypeCasterFunc  # type: ignore
 
         if isinstance(self.base_caster, type):
-            self.caster_func = self.base_caster(oid, conn).cast
+            self.caster_func = self.base_caster(oid, context).cast
         else:
             self.caster_func = type(self).base_caster
 
index 5533fe80f3e4657df02dc4744c1998e909b8748c..91fdbd3056d1a0e97d9858dc3181faa0f63ecc4a 100644 (file)
@@ -7,11 +7,7 @@ Adapters of textual types.
 import codecs
 from typing import Optional, Tuple, Union
 
-from ..adapt import (
-    Adapter,
-    TypeCaster,
-)
-from ..connection import BaseConnection
+from ..adapt import Adapter, TypeCaster, AdaptContext
 from ..utils.typing import EncodeFunc, DecodeFunc
 from ..pq import Escaping
 from .oids import builtins
@@ -24,13 +20,13 @@ BYTEA_OID = builtins["bytea"].oid
 @Adapter.text(str)
 @Adapter.binary(str)
 class StringAdapter(Adapter):
-    def __init__(self, cls: type, conn: BaseConnection):
-        super().__init__(cls, conn)
+    def __init__(self, cls: type, context: AdaptContext):
+        super().__init__(cls, context)
 
         self._encode: EncodeFunc
-        if conn is not None:
-            if conn.encoding != "SQL_ASCII":
-                self._encode = conn.codec.encode
+        if self.connection is not None:
+            if self.connection.encoding != "SQL_ASCII":
+                self._encode = self.connection.codec.encode
             else:
                 self._encode = codecs.lookup("utf8").encode
         else:
@@ -43,16 +39,17 @@ class StringAdapter(Adapter):
 @TypeCaster.text(builtins["text"].oid)
 @TypeCaster.binary(builtins["text"].oid)
 @ArrayCaster.text(builtins["text"].array_oid)
+@ArrayCaster.binary(builtins["text"].array_oid)
 class StringCaster(TypeCaster):
 
     decode: Optional[DecodeFunc]
 
-    def __init__(self, oid: int, conn: BaseConnection):
-        super().__init__(oid, conn)
+    def __init__(self, oid: int, context: AdaptContext):
+        super().__init__(oid, context)
 
-        if conn is not None:
-            if conn.encoding != "SQL_ASCII":
-                self.decode = conn.codec.decode
+        if self.connection is not None:
+            if self.connection.encoding != "SQL_ASCII":
+                self.decode = self.connection.codec.decode
             else:
                 self.decode = None
         else:
@@ -68,10 +65,10 @@ class StringCaster(TypeCaster):
 
 @Adapter.text(bytes)
 class BytesAdapter(Adapter):
-    def __init__(self, cls: type, conn: BaseConnection):
-        super().__init__(cls, conn)
+    def __init__(self, cls: type, context: AdaptContext = None):
+        super().__init__(cls, context)
         self.esc = Escaping(
-            self.conn.pgconn if self.conn is not None else None
+            self.connection.pgconn if self.connection is not None else None
         )
 
     def adapt(self, obj: bytes) -> Tuple[bytes, int]: