class Adapter:
globals: AdaptersMap = {}
+ connection: Optional[BaseConnection]
+ cursor: Optional[BaseCursor]
- def __init__(self, src: type, conn: Optional[BaseConnection]):
+ def __init__(self, src: type, context: AdaptContext = None):
self.src = src
- self.conn = conn
+ self.context = context
+ self.connection, self.cursor = _solve_context(context)
def adapt(self, obj: Any) -> Union[bytes, Tuple[bytes, int]]:
raise NotImplementedError()
class TypeCaster:
globals: TypeCastersMap = {}
+ connection: Optional[BaseConnection]
+ cursor: Optional[BaseCursor]
- def __init__(self, oid: int, conn: Optional[BaseConnection]):
+ def __init__(self, oid: int, context: AdaptContext = None):
self.oid = oid
- self.conn = conn
+ self.context = context
+ self.connection, self.cursor = _solve_context(context)
def cast(self, data: bytes) -> Any:
raise NotImplementedError()
cursor: Optional[BaseCursor]
def __init__(self, context: AdaptContext = None):
- if context is None:
- self.connection = None
- self.cursor = None
- elif isinstance(context, BaseConnection):
- self.connection = context
- self.cursor = None
- elif isinstance(context, BaseCursor):
- self.connection = context.conn
- self.cursor = context
- else:
- raise TypeError(
- f"the context should be a connection or cursor,"
- f" got {type(context)}"
- )
+ self.connection, self.cursor = _solve_context(context)
# mapping class, fmt -> adaptation function
self._adapt_funcs: Dict[Tuple[type, Format], AdapterFunc] = {}
Fallback object to convert unknown types to Python
"""
- def __init__(self, oid: int, conn: Optional[BaseConnection]):
- super().__init__(oid, conn)
+ def __init__(self, oid: int, context: AdaptContext):
+ super().__init__(oid, context)
self.decode: DecodeFunc
- if conn is not None:
- self.decode = conn.codec.decode
+ if self.connection is not None:
+ self.decode = self.connection.codec.decode
else:
self.decode = codecs.lookup("utf8").decode
@TypeCaster.binary(INVALID_OID)
def cast_unknown(data: bytes) -> bytes:
return data
+
+
+def _solve_context(
+ context: AdaptContext,
+) -> Tuple[Optional[BaseConnection], Optional[BaseCursor]]:
+ if context is None:
+ return None, None
+ elif isinstance(context, BaseConnection):
+ return context, None
+ elif isinstance(context, BaseCursor):
+ return context.conn, context
+ else:
+ raise TypeError(
+ f"the context should be a connection or cursor,"
+ f" got {type(context)}"
+ )
# Copyright (C) 2020 The Psycopg Team
import re
-from typing import Any, List, Optional, TYPE_CHECKING
+from typing import Any, List, Optional
from .. import errors as e
from ..pq import Format
from ..adapt import Adapter, TypeCaster, Transformer, UnknownCaster
from ..adapt import AdaptContext, TypeCasterType, TypeCasterFunc
-if TYPE_CHECKING:
- from ..connection import BaseConnection
-
# from https://www.postgresql.org/docs/current/arrays.html#ARRAYS-IO
#
@Adapter.text(list)
class ListAdapter(Adapter):
- def __init__(self, cls: type, conn: "BaseConnection"):
- super().__init__(cls, conn)
- self.tx = Transformer(conn)
+ def __init__(self, cls: type, context: AdaptContext = None):
+ super().__init__(cls, context)
+ self.tx = Transformer(context)
def adapt(self, obj: List[Any]) -> bytes:
tokens: List[bytes] = []
class ArrayCasterBase(TypeCaster):
base_caster: TypeCasterType
- def __init__(
- self, oid: int, conn: Optional["BaseConnection"],
- ):
- super().__init__(oid, conn)
+ def __init__(self, oid: int, context: AdaptContext = None):
+ super().__init__(oid, context)
self.caster_func = TypeCasterFunc # type: ignore
if isinstance(self.base_caster, type):
- self.caster_func = self.base_caster(oid, conn).cast
+ self.caster_func = self.base_caster(oid, context).cast
else:
self.caster_func = type(self).base_caster
import codecs
from typing import Optional, Tuple, Union
-from ..adapt import (
- Adapter,
- TypeCaster,
-)
-from ..connection import BaseConnection
+from ..adapt import Adapter, TypeCaster, AdaptContext
from ..utils.typing import EncodeFunc, DecodeFunc
from ..pq import Escaping
from .oids import builtins
@Adapter.text(str)
@Adapter.binary(str)
class StringAdapter(Adapter):
- def __init__(self, cls: type, conn: BaseConnection):
- super().__init__(cls, conn)
+ def __init__(self, cls: type, context: AdaptContext):
+ super().__init__(cls, context)
self._encode: EncodeFunc
- if conn is not None:
- if conn.encoding != "SQL_ASCII":
- self._encode = conn.codec.encode
+ if self.connection is not None:
+ if self.connection.encoding != "SQL_ASCII":
+ self._encode = self.connection.codec.encode
else:
self._encode = codecs.lookup("utf8").encode
else:
@TypeCaster.text(builtins["text"].oid)
@TypeCaster.binary(builtins["text"].oid)
@ArrayCaster.text(builtins["text"].array_oid)
+@ArrayCaster.binary(builtins["text"].array_oid)
class StringCaster(TypeCaster):
decode: Optional[DecodeFunc]
- def __init__(self, oid: int, conn: BaseConnection):
- super().__init__(oid, conn)
+ def __init__(self, oid: int, context: AdaptContext):
+ super().__init__(oid, context)
- if conn is not None:
- if conn.encoding != "SQL_ASCII":
- self.decode = conn.codec.decode
+ if self.connection is not None:
+ if self.connection.encoding != "SQL_ASCII":
+ self.decode = self.connection.codec.decode
else:
self.decode = None
else:
@Adapter.text(bytes)
class BytesAdapter(Adapter):
- def __init__(self, cls: type, conn: BaseConnection):
- super().__init__(cls, conn)
+ def __init__(self, cls: type, context: AdaptContext = None):
+ super().__init__(cls, context)
self.esc = Escaping(
- self.conn.pgconn if self.conn is not None else None
+ self.connection.pgconn if self.connection is not None else None
)
def adapt(self, obj: bytes) -> Tuple[bytes, int]: