From: Daniele Varrazzo Date: Fri, 3 Apr 2020 14:23:33 +0000 (+1300) Subject: Adaptation context transferred to the adapters objects X-Git-Tag: 3.0.dev0~613 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=677e535f69b51a9aa87c652086c5006388b63b52;p=thirdparty%2Fpsycopg.git Adaptation context transferred to the adapters objects --- diff --git a/psycopg3/adapt.py b/psycopg3/adapt.py index d0f3b9d95..bfe24251c 100644 --- a/psycopg3/adapt.py +++ b/psycopg3/adapt.py @@ -34,10 +34,13 @@ TypeCastersMap = Dict[Tuple[int, Format], TypeCasterType] class Adapter: globals: AdaptersMap = {} + connection: Optional[BaseConnection] + cursor: Optional[BaseCursor] - def __init__(self, src: type, conn: Optional[BaseConnection]): + def __init__(self, src: type, context: AdaptContext = None): self.src = src - self.conn = conn + self.context = context + self.connection, self.cursor = _solve_context(context) def adapt(self, obj: Any) -> Union[bytes, Tuple[bytes, int]]: raise NotImplementedError() @@ -101,10 +104,13 @@ class Adapter: class TypeCaster: globals: TypeCastersMap = {} + connection: Optional[BaseConnection] + cursor: Optional[BaseCursor] - def __init__(self, oid: int, conn: Optional[BaseConnection]): + def __init__(self, oid: int, context: AdaptContext = None): self.oid = oid - self.conn = conn + self.context = context + self.connection, self.cursor = _solve_context(context) def cast(self, data: bytes) -> Any: raise NotImplementedError() @@ -179,20 +185,7 @@ class Transformer: cursor: Optional[BaseCursor] def __init__(self, context: AdaptContext = None): - if context is None: - self.connection = None - self.cursor = None - elif isinstance(context, BaseConnection): - self.connection = context - self.cursor = None - elif isinstance(context, BaseCursor): - self.connection = context.conn - self.cursor = context - else: - raise TypeError( - f"the context should be a connection or cursor," - f" got {type(context)}" - ) + self.connection, self.cursor = _solve_context(context) # mapping class, fmt -> adaptation function self._adapt_funcs: Dict[Tuple[type, Format], AdapterFunc] = {} @@ -333,11 +326,11 @@ class UnknownCaster(TypeCaster): Fallback object to convert unknown types to Python """ - def __init__(self, oid: int, conn: Optional[BaseConnection]): - super().__init__(oid, conn) + def __init__(self, oid: int, context: AdaptContext): + super().__init__(oid, context) self.decode: DecodeFunc - if conn is not None: - self.decode = conn.codec.decode + if self.connection is not None: + self.decode = self.connection.codec.decode else: self.decode = codecs.lookup("utf8").decode @@ -348,3 +341,19 @@ class UnknownCaster(TypeCaster): @TypeCaster.binary(INVALID_OID) def cast_unknown(data: bytes) -> bytes: return data + + +def _solve_context( + context: AdaptContext, +) -> Tuple[Optional[BaseConnection], Optional[BaseCursor]]: + if context is None: + return None, None + elif isinstance(context, BaseConnection): + return context, None + elif isinstance(context, BaseCursor): + return context.conn, context + else: + raise TypeError( + f"the context should be a connection or cursor," + f" got {type(context)}" + ) diff --git a/psycopg3/types/array.py b/psycopg3/types/array.py index c0373a4bc..df7ec02cd 100644 --- a/psycopg3/types/array.py +++ b/psycopg3/types/array.py @@ -5,16 +5,13 @@ Adapters for arrays # Copyright (C) 2020 The Psycopg Team import re -from typing import Any, List, Optional, TYPE_CHECKING +from typing import Any, List, Optional from .. import errors as e from ..pq import Format from ..adapt import Adapter, TypeCaster, Transformer, UnknownCaster from ..adapt import AdaptContext, TypeCasterType, TypeCasterFunc -if TYPE_CHECKING: - from ..connection import BaseConnection - # from https://www.postgresql.org/docs/current/arrays.html#ARRAYS-IO # @@ -59,9 +56,9 @@ def escape_item(item: Optional[bytes]) -> bytes: @Adapter.text(list) class ListAdapter(Adapter): - def __init__(self, cls: type, conn: "BaseConnection"): - super().__init__(cls, conn) - self.tx = Transformer(conn) + def __init__(self, cls: type, context: AdaptContext = None): + super().__init__(cls, context) + self.tx = Transformer(context) def adapt(self, obj: List[Any]) -> bytes: tokens: List[bytes] = [] @@ -93,14 +90,12 @@ class ListAdapter(Adapter): class ArrayCasterBase(TypeCaster): base_caster: TypeCasterType - def __init__( - self, oid: int, conn: Optional["BaseConnection"], - ): - super().__init__(oid, conn) + def __init__(self, oid: int, context: AdaptContext = None): + super().__init__(oid, context) self.caster_func = TypeCasterFunc # type: ignore if isinstance(self.base_caster, type): - self.caster_func = self.base_caster(oid, conn).cast + self.caster_func = self.base_caster(oid, context).cast else: self.caster_func = type(self).base_caster diff --git a/psycopg3/types/text.py b/psycopg3/types/text.py index 5533fe80f..91fdbd305 100644 --- a/psycopg3/types/text.py +++ b/psycopg3/types/text.py @@ -7,11 +7,7 @@ Adapters of textual types. import codecs from typing import Optional, Tuple, Union -from ..adapt import ( - Adapter, - TypeCaster, -) -from ..connection import BaseConnection +from ..adapt import Adapter, TypeCaster, AdaptContext from ..utils.typing import EncodeFunc, DecodeFunc from ..pq import Escaping from .oids import builtins @@ -24,13 +20,13 @@ BYTEA_OID = builtins["bytea"].oid @Adapter.text(str) @Adapter.binary(str) class StringAdapter(Adapter): - def __init__(self, cls: type, conn: BaseConnection): - super().__init__(cls, conn) + def __init__(self, cls: type, context: AdaptContext): + super().__init__(cls, context) self._encode: EncodeFunc - if conn is not None: - if conn.encoding != "SQL_ASCII": - self._encode = conn.codec.encode + if self.connection is not None: + if self.connection.encoding != "SQL_ASCII": + self._encode = self.connection.codec.encode else: self._encode = codecs.lookup("utf8").encode else: @@ -43,16 +39,17 @@ class StringAdapter(Adapter): @TypeCaster.text(builtins["text"].oid) @TypeCaster.binary(builtins["text"].oid) @ArrayCaster.text(builtins["text"].array_oid) +@ArrayCaster.binary(builtins["text"].array_oid) class StringCaster(TypeCaster): decode: Optional[DecodeFunc] - def __init__(self, oid: int, conn: BaseConnection): - super().__init__(oid, conn) + def __init__(self, oid: int, context: AdaptContext): + super().__init__(oid, context) - if conn is not None: - if conn.encoding != "SQL_ASCII": - self.decode = conn.codec.decode + if self.connection is not None: + if self.connection.encoding != "SQL_ASCII": + self.decode = self.connection.codec.decode else: self.decode = None else: @@ -68,10 +65,10 @@ class StringCaster(TypeCaster): @Adapter.text(bytes) class BytesAdapter(Adapter): - def __init__(self, cls: type, conn: BaseConnection): - super().__init__(cls, conn) + def __init__(self, cls: type, context: AdaptContext = None): + super().__init__(cls, context) self.esc = Escaping( - self.conn.pgconn if self.conn is not None else None + self.connection.pgconn if self.connection is not None else None ) def adapt(self, obj: bytes) -> Tuple[bytes, int]: