_parts: List[QueryPart]
_query = b""
+ _encoding: str = "utf-8"
params: Optional[List[Optional[bytes]]] = None
# these are tuples so they can be used as keys e.g. in prepared stmts
types: Tuple[int, ...] = ()
def __init__(self, transformer: "Transformer"):
self._tx = transformer
- if (
- self._tx.connection
- and self._tx.connection.pgconn.server_version < 100000
- ):
- self._unknown_oid = TEXT_OID
+ conn = transformer.connection
+ if conn:
+ self._encoding = conn.client_encoding
+ if conn.pgconn.server_version < 100000:
+ self._unknown_oid = TEXT_OID
def convert(self, query: Query, vars: Optional[Params]) -> None:
"""
if vars is not None:
self.query, self.formats, self._order, self._parts = _query2pg(
- query, self._tx.encoding
+ query, self._encoding
)
else:
if isinstance(query, str):
- query = query.encode(self._tx.encoding)
+ query = query.encode(self._encoding)
self.query = query
self.formats = self._order = None
self._parts, vars, self._order
)
assert self.formats is not None
- ps = self.params = []
- ts = []
+ ps: List[Optional[bytes]] = [None] * len(params)
+ ts = [self._unknown_oid] * len(params)
for i in range(len(params)):
param = params[i]
if param is not None:
dumper = self._tx.get_dumper(param, self.formats[i])
- ps.append(dumper.dump(param))
- ts.append(dumper.oid)
- else:
- ps.append(None)
- ts.append(self._unknown_oid)
+ ps[i] = dumper.dump(param)
+ ts[i] = dumper.oid
+ self.params = ps
self.types = tuple(ts)
else:
self.params = None
from . import errors as e
from .pq import Format
from .oids import INVALID_OID
-from .proto import AdaptContext, DumpersMap
-from .proto import LoadFunc, LoadersMap
-from .cursor import BaseCursor
-from .connection import BaseConnection
+from .proto import LoadFunc, AdaptContext
if TYPE_CHECKING:
from .pq.proto import PGresult
- from .adapt import Dumper, Loader
+ from .adapt import Dumper, Loader, AdaptersMap
+ from .connection import BaseConnection
-class Transformer:
+class Transformer(AdaptContext):
"""
An object that can adapt efficiently between Python and PostgreSQL.
"""
__module__ = "psycopg3.adapt"
+ _adapters: "AdaptersMap"
+ _pgresult: Optional["PGresult"] = None
- def __init__(self, context: AdaptContext = None):
- self._dumpers: DumpersMap
- self._loaders: LoadersMap
- self._dumpers_maps: List[DumpersMap] = []
- self._loaders_maps: List[LoadersMap] = []
- self._setup_context(context)
- self._pgresult: Optional["PGresult"] = None
+ def __init__(self, context: Optional[AdaptContext] = None):
+ # WARNING: don't store context, or you'll create a loop with the Cursor
+ if context:
+ self._adapters = context.adapters
+ self._connection = context.connection
+
+ else:
+ from .adapt import global_adapters
+
+ self._adapters = global_adapters
+ self._connection = None
# mapping class, fmt -> Dumper instance
self._dumpers_cache: Dict[Tuple[type, Format], "Dumper"] = {}
# the length of the result columns
self._row_loaders: List[LoadFunc] = []
- def _setup_context(self, context: AdaptContext) -> None:
- if not context:
- self._connection = None
- self._encoding = "utf-8"
- self._dumpers = {}
- self._loaders = {}
- self._dumpers_maps = [self._dumpers]
- self._loaders_maps = [self._loaders]
-
- elif isinstance(context, Transformer):
- # A transformer created from a transformers: usually it happens
- # for nested types: share the entire state of the parent
- self._connection = context.connection
- self._encoding = context.encoding
- self._dumpers = context.dumpers
- self._loaders = context.loaders
- self._dumpers_maps.extend(context._dumpers_maps)
- self._loaders_maps.extend(context._loaders_maps)
- # the global maps are already in the lists
- return
-
- elif isinstance(context, BaseCursor):
- self._connection = context.connection
- self._encoding = context.connection.client_encoding
- self._dumpers = {}
- self._dumpers_maps.extend(
- (self._dumpers, context.dumpers, context.connection.dumpers)
- )
- self._loaders = {}
- self._loaders_maps.extend(
- (self._loaders, context.loaders, context.connection.loaders)
- )
-
- elif isinstance(context, BaseConnection):
- self._connection = context
- self._encoding = context.client_encoding
- self._dumpers = {}
- self._dumpers_maps.extend((self._dumpers, context.dumpers))
- self._loaders = {}
- self._loaders_maps.extend((self._loaders, context.loaders))
-
- from .adapt import Dumper, Loader
-
- self._dumpers_maps.append(Dumper.globals)
- self._loaders_maps.append(Loader.globals)
-
@property
def connection(self) -> Optional["BaseConnection"]:
return self._connection
@property
- def encoding(self) -> str:
- return self._encoding
+ def adapters(self) -> "AdaptersMap":
+ return self._adapters
@property
def pgresult(self) -> Optional["PGresult"]:
fmt = result.fformat(i)
rc.append(self.get_loader(oid, fmt).load)
- @property
- def dumpers(self) -> DumpersMap:
- return self._dumpers
-
- @property
- def loaders(self) -> LoadersMap:
- return self._loaders
-
def set_row_types(self, types: Iterable[Tuple[int, Format]]) -> None:
rc = self._row_loaders = []
for oid, fmt in types:
# in contexts from the most specific to the most generic.
# Also look for superclasses: if you can adapt a type you should be
# able to adapt its subtypes, otherwise Liskov is sad.
- for dmap in self._dumpers_maps:
- for scls in cls.__mro__:
- dumper_class = dmap.get((scls, format))
- if not dumper_class:
- continue
+ dmap = self._adapters._dumpers
+ for scls in cls.__mro__:
+ dumper_class = dmap.get((scls, format))
+ if not dumper_class:
+ continue
- self._dumpers_cache[cls, format] = dumper = dumper_class(
- cls, self
- )
- return dumper
+ dumper = self._dumpers_cache[cls, format] = dumper_class(cls, self)
+ return dumper
# If the adapter is not found, look for its name as a string
- for dmap in self._dumpers_maps:
- for scls in cls.__mro__:
- fqn = f"{cls.__module__}.{scls.__qualname__}"
- dumper_class = dmap.get((fqn, format))
- if dumper_class is None:
- continue
-
- key = (cls, format)
- dmap[key] = dumper_class
- self._dumpers_cache[key] = dumper = dumper_class(cls, self)
- return dumper
+ for scls in cls.__mro__:
+ fqn = f"{cls.__module__}.{scls.__qualname__}"
+ dumper_class = dmap.get((fqn, format))
+ if dumper_class is None:
+ continue
+
+ dmap[cls, format] = dumper_class
+ dumper = self._dumpers_cache[cls, format] = dumper_class(cls, self)
+ return dumper
raise e.ProgrammingError(
f"cannot adapt type {type(obj).__name__}"
except KeyError:
pass
- for tcmap in self._loaders_maps:
- if key in tcmap:
- loader_cls = tcmap[key]
- break
- else:
- from .adapt import Loader # noqa
-
- loader_cls = Loader.globals[INVALID_OID, format]
-
- self._loaders_cache[key] = loader = loader_cls(key[0], self)
+ loader_cls = self._adapters._loaders.get(key)
+ if not loader_cls:
+ loader_cls = self._adapters._loaders[INVALID_OID, format]
+ loader = self._loaders_cache[key] = loader_cls(oid, self)
return loader
# Copyright (C) 2020 The Psycopg Team
from abc import ABC, abstractmethod
-from typing import Any, cast, Callable, Optional, Type, Union
+from typing import Any, Callable, Optional, Type, TYPE_CHECKING, Union
from . import pq
from . import proto
from .pq import Format as Format
from .oids import TEXT_OID
-from .proto import AdaptContext, DumpersMap, DumperType, LoadersMap, LoaderType
-from .cursor import BaseCursor
-from .connection import BaseConnection
+from .proto import DumpersMap, DumperType, LoadersMap, LoaderType, AdaptContext
+
+if TYPE_CHECKING:
+ from .connection import BaseConnection
class Dumper(ABC):
Convert Python object of the type *src* to PostgreSQL representation.
"""
- globals: DumpersMap = {}
- connection: Optional[BaseConnection]
+ connection: Optional["BaseConnection"] = None
# A class-wide oid, which will be used by default by instances unless
# the subclass overrides it in init.
_oid: int = 0
- def __init__(self, src: type, context: AdaptContext = None):
+ def __init__(self, src: type, context: Optional[AdaptContext] = None):
self.src = src
- self.context = context
- self.connection = connection_from_context(context)
+ self.connection = context.connection if context else None
+
self.oid = self._oid
"""The oid to pass to the server, if known."""
def register(
cls,
src: Union[type, str],
- context: AdaptContext = None,
+ context: Optional[AdaptContext] = None,
format: Format = Format.TEXT,
) -> None:
"""
Configure *context* to use this dumper to convert object of type *src*.
"""
- if not isinstance(src, (str, type)):
- raise TypeError(
- f"dumpers should be registered on classes, got {src} instead"
- )
-
- where = context.dumpers if context else Dumper.globals
- where[src, format] = cls
+ adapters = context.adapters if context else global_adapters
+ adapters.register_dumper(src, cls, format=format)
@classmethod
def text(cls, src: Union[type, str]) -> Callable[[DumperType], DumperType]:
Convert PostgreSQL objects with OID *oid* to Python objects.
"""
- globals: LoadersMap = {}
- connection: Optional[BaseConnection]
+ connection: Optional["BaseConnection"]
- def __init__(self, oid: int, context: AdaptContext = None):
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
self.oid = oid
- self.context = context
- self.connection = connection_from_context(context)
+ self.connection = context.connection if context else None
@abstractmethod
def load(self, data: bytes) -> Any:
def register(
cls,
oid: int,
- context: AdaptContext = None,
+ context: Optional[AdaptContext] = None,
format: Format = Format.TEXT,
) -> None:
"""
Configure *context* to use this loader to convert values with OID *oid*.
"""
- if not isinstance(oid, int):
- raise TypeError(
- f"loaders should be registered on oid, got {oid} instead"
- )
-
- where = context.loaders if context else Loader.globals
- where[oid, format] = cls
+ adapters = context.adapters if context else global_adapters
+ adapters.register_loader(oid, cls, format=format)
@classmethod
def text(cls, oid: int) -> Callable[[LoaderType], LoaderType]:
return binary_
-def connection_from_context(
- context: AdaptContext,
-) -> Optional[BaseConnection]:
- if not context:
- return None
- elif isinstance(context, BaseConnection):
- return context
- elif isinstance(context, BaseCursor):
- return cast(BaseConnection, context.connection)
- elif isinstance(context, Transformer):
- return context.connection
- else:
- raise TypeError(f"can't get a connection from {type(context)}")
+class AdaptersMap:
+ """
+ Map oids to Loaders and types to Dumpers.
+
+ The object can start empty or copy from another object of the same class.
+ Copies are copy-on-write: if the maps are updated, a copy is made first.
+ This way extending e.g. the global map from a connection, or a
+ connection's map from a cursor, is cheap: a copy is made only on
+ customisation.
+ """
+
+ _dumpers: DumpersMap
+ _loaders: LoadersMap
+
+ def __init__(self, extend: Optional["AdaptersMap"] = None):
+ if extend:
+ self._dumpers = extend._dumpers
+ self._own_dumpers = False
+ self._loaders = extend._loaders
+ self._own_loaders = False
+ else:
+ self._dumpers = {}
+ self._own_dumpers = True
+ self._loaders = {}
+ self._own_loaders = True
+
+ def register_dumper(
+ self,
+ src: Union[type, str],
+ dumper: Type[Dumper],
+ format: Format = Format.TEXT,
+ ) -> None:
+ """
+ Configure the context to use *dumper* to convert objects of type *src*.
+ """
+ if not isinstance(src, (str, type)):
+ raise TypeError(
+ f"dumpers should be registered on classes, got {src} instead"
+ )
+
+ if not self._own_dumpers:
+ self._dumpers = self._dumpers.copy()
+ self._own_dumpers = True
+
+ self._dumpers[src, format] = dumper
+
+ def register_loader(
+ self, oid: int, loader: Type[Loader], format: Format = Format.TEXT
+ ) -> None:
+ """
+ Configure the context to use *loader* to convert data of oid *oid*.
+ """
+ if not isinstance(oid, int):
+ raise TypeError(
+ f"loaders should be registered on oid, got {oid} instead"
+ )
+
+ if not self._own_loaders:
+ self._loaders = self._loaders.copy()
+ self._own_loaders = True
+
+ self._loaders[oid, format] = loader
+
+
+global_adapters = AdaptersMap()
Transformer: Type[proto.Transformer]
import threading
from types import TracebackType
from typing import Any, AsyncIterator, Callable, Iterator, List, NamedTuple
-from typing import Optional, Type, TYPE_CHECKING, TypeVar
+from typing import Optional, Type, TYPE_CHECKING
from weakref import ref, ReferenceType
from functools import partial
from contextlib import contextmanager
from .utils.context import asynccontextmanager
from . import pq
+from . import adapt
from . import cursor
from . import errors as e
from . import waiting
from . import encodings
from .pq import TransactionStatus, ExecStatus, Format
from .sql import Composable
-from .proto import DumpersMap, LoadersMap, PQGen, PQGenConn, RV, Query, Params
+from .proto import PQGen, PQGenConn, RV, Query, Params, AdaptContext
+from .proto import ConnectionType
from .conninfo import make_conninfo
from .generators import notifies
from .transaction import Transaction, AsyncTransaction
Notify.__module__ = "psycopg3"
-C = TypeVar("C", bound="BaseConnection")
NoticeHandler = Callable[[e.Diagnostic], None]
NotifyHandler = Callable[[Notify], None]
-class BaseConnection:
+class BaseConnection(AdaptContext):
"""
Base class for different types of connections.
def __init__(self, pgconn: "PGconn"):
self.pgconn = pgconn # TODO: document this
self._autocommit = False
- self.dumpers: DumpersMap = {}
- self.loaders: LoadersMap = {}
+ self._adapters = adapt.AdaptersMap(adapt.global_adapters)
self._notice_handlers: List[NoticeHandler] = []
self._notify_handlers: List[NotifyHandler] = []
# only a begin/commit and not a savepoint.
self._savepoints: List[str] = []
- self._prepared = PrepareManager()
+ self._prepared: PrepareManager = PrepareManager()
wself = ref(self)
if result.status != ExecStatus.TUPLES_OK:
raise e.error_from_result(result, encoding=self.client_encoding)
+ @property
+ def adapters(self) -> adapt.AdaptersMap:
+ return self._adapters
+
+ @property
+ def connection(self) -> "BaseConnection":
+ # implement the AdaptContext protocol
+ return self
+
def cancel(self) -> None:
"""Cancel the current operation on the connection."""
c = self.pgconn.get_cancel()
@classmethod
def _connect_gen(
- cls: Type[C],
+ cls: Type[ConnectionType],
conninfo: str = "",
*,
autocommit: bool = False,
**kwargs: Any,
- ) -> PQGenConn[C]:
+ ) -> PQGenConn[ConnectionType]:
"""Generator to connect to the database and create a new instance."""
conninfo = make_conninfo(conninfo, **kwargs)
pgconn = yield from connect(conninfo)
from typing import Optional, Sequence, Type, TYPE_CHECKING
from contextlib import contextmanager
-from . import errors as e
from . import pq
+from . import adapt
+from . import errors as e
+
from .pq import ExecStatus, Format
from .copy import Copy, AsyncCopy
-from .proto import ConnectionType, Query, Params, DumpersMap, LoadersMap, PQGen
+from .proto import ConnectionType, Query, Params, PQGen
from ._column import Column
from ._queries import PostgresQuery
from ._preparing import Prepare
):
self._conn = connection
self.format = format
- self.dumpers: DumpersMap = {}
- self.loaders: LoadersMap = {}
+ self._adapters = adapt.AdaptersMap(connection.adapters)
self._reset()
self.arraysize = 1
self._closed = False
"""The connection this cursor is using."""
return self._conn
+ @property
+ def adapters(self) -> adapt.AdaptersMap:
+ return self._adapters
+
@property
def closed(self) -> bool:
"""`True` if the cursor is closed."""
It is implemented as generator because it may send additional queries,
such as `begin`.
"""
- from . import adapt
-
if self.closed:
raise e.InterfaceError("the cursor is closed")
if TYPE_CHECKING:
from .connection import BaseConnection
- from .cursor import BaseCursor
- from .adapt import Dumper, Loader
+ from .adapt import Dumper, Loader, AdaptersMap
from .waiting import Wait, Ready
from .sql import Composable
# Adaptation types
-AdaptContext = Union[None, "BaseConnection", "BaseCursor", "Transformer"]
-
DumpFunc = Callable[[Any], bytes]
DumperType = Type["Dumper"]
DumpersMap = Dict[Tuple[Union[type, str], Format], DumperType]
# as there are both C and a Python implementation
-class Transformer(Protocol):
- def __init__(self, context: AdaptContext = None):
+class AdaptContext(Protocol):
+ """
+ A context describing how types are adapted.
+
+ Examples of AdaptContext are connections, cursors, transformers.
+ """
+
+ @property
+ def adapters(self) -> "AdaptersMap":
...
@property
def connection(self) -> Optional["BaseConnection"]:
...
- @property
- def encoding(self) -> str:
+
+class Transformer(Protocol):
+ def __init__(self, context: Optional[AdaptContext] = None):
...
@property
- def pgresult(self) -> Optional[pq.proto.PGresult]:
+ def connection(self) -> Optional["BaseConnection"]:
...
- @pgresult.setter
- def pgresult(self, result: Optional[pq.proto.PGresult]) -> None:
+ @property
+ def adapters(self) -> "AdaptersMap":
...
@property
- def dumpers(self) -> DumpersMap:
+ def pgresult(self) -> Optional[pq.proto.PGresult]:
...
- @property
- def loaders(self) -> LoadersMap:
+ @pgresult.setter
+ def pgresult(self, result: Optional[pq.proto.PGresult]) -> None:
...
def set_row_types(self, types: Sequence[Tuple[int, Format]]) -> None:
# Copyright (C) 2020 The Psycopg Team
+import codecs
import string
from typing import Any, Iterator, List, Optional, Sequence, Union
-from typing import TYPE_CHECKING
from .pq import Escaping, Format
+from .adapt import Transformer
from .proto import AdaptContext
-if TYPE_CHECKING:
- from .connection import BaseConnection
-
-def quote(obj: Any, context: AdaptContext = None) -> str:
+def quote(obj: Any, context: Optional[AdaptContext] = None) -> str:
"""
Adapt a Python object to a quoted SQL string.
rules used, otherwise only global rules are used.
"""
- from .adapt import connection_from_context
-
- conn = connection_from_context(context)
- enc = conn.client_encoding if conn else "utf-8"
- return Literal(obj).as_bytes(context).decode(enc)
+ return Literal(obj).as_string(context)
class Composable(object):
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self._obj!r})"
- def as_bytes(self, context: AdaptContext) -> bytes:
+ def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
"""
Return the value of the object as bytes.
# TODO: add tests and docs for as_bytes
raise NotImplementedError
- def as_string(self, context: AdaptContext) -> str:
+ def as_string(self, context: Optional[AdaptContext]) -> str:
"""
Return the value of the object as string.
:type context: `connection` or `cursor`
"""
- conn = _connection_from_context(context)
- return self.as_bytes(context).decode(conn.client_encoding)
+ conn = context.connection if context else None
+ enc = conn.client_encoding if conn else "utf-8"
+ b = self.as_bytes(context)
+ if isinstance(b, bytes):
+ return b.decode(enc)
+ else:
+ # buffer object
+ return codecs.lookup(enc).decode(b)[0]
def __add__(self, other: "Composable") -> "Composed":
if isinstance(other, Composed):
]
super().__init__(seq)
- def as_bytes(self, context: AdaptContext) -> bytes:
+ def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
return b"".join(obj.as_bytes(context) for obj in self._obj)
def __iter__(self) -> Iterator[Composable]:
if not isinstance(obj, str):
raise TypeError(f"SQL values must be strings, got {obj!r} instead")
- def as_string(self, context: AdaptContext) -> str:
+ def as_string(self, context: Optional[AdaptContext]) -> str:
return self._obj
- def as_bytes(self, context: AdaptContext) -> bytes:
- conn = _connection_from_context(context)
- return self._obj.encode(conn.client_encoding)
+ def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+ enc = "utf-8"
+ if context:
+ conn = context.connection
+ if conn:
+ enc = conn.client_encoding
+ return self._obj.encode(enc)
def format(self, *args: Any, **kwargs: Any) -> Composed:
"""
def __repr__(self) -> str:
return f"{self.__class__.__name__}({', '.join(map(repr, self._obj))})"
- def as_bytes(self, context: AdaptContext) -> bytes:
- conn = _connection_from_context(context)
+ def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+ conn = context.connection if context else None
+ if not conn:
+ raise ValueError("a connection is necessary for Identifier")
esc = Escaping(conn.pgconn)
enc = conn.client_encoding
escs = [esc.escape_identifier(s.encode(enc)) for s in self._obj]
"""
- def as_bytes(self, context: AdaptContext) -> bytes:
- from .adapt import Transformer
-
- tx = context if isinstance(context, Transformer) else Transformer()
+ def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+ tx = Transformer(context)
dumper = tx.get_dumper(self._obj, Format.TEXT)
return dumper.quote(self._obj)
return f"{self.__class__.__name__}({', '.join(parts)})"
- def as_string(self, context: AdaptContext) -> str:
+ def as_string(self, context: Optional[AdaptContext]) -> str:
code = "s" if self._format == Format.TEXT else "b"
return f"%({self._obj}){code}" if self._obj else f"%{code}"
- def as_bytes(self, context: AdaptContext) -> bytes:
- conn = _connection_from_context(context)
- return self.as_string(context).encode(conn.client_encoding)
+ def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+ conn = context.connection if context else None
+ enc = conn.client_encoding if conn else "utf-8"
+ return self.as_string(context).encode(enc)
# Literals
NULL = SQL("NULL")
DEFAULT = SQL("DEFAULT")
-
-
-def _connection_from_context(context: AdaptContext) -> "BaseConnection":
- from .adapt import connection_from_context
-
- conn = connection_from_context(context)
- if not conn:
- raise ValueError(f"no connection in the context: {context}")
-
- return conn
_oid = TEXT_ARRAY_OID
- def __init__(self, src: type, context: AdaptContext = None):
+ def __init__(self, src: type, context: Optional[AdaptContext] = None):
super().__init__(src, context)
self._tx = Transformer(context)
class BaseArrayLoader(Loader):
base_oid: int
- def __init__(self, oid: int, context: AdaptContext = None):
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
super().__init__(oid, context)
self._tx = Transformer(context)
def register(
array_oid: int,
base_oid: int,
- context: AdaptContext = None,
+ context: Optional[AdaptContext] = None,
name: Optional[str] = None,
) -> None:
if not name:
def register(
self,
- context: AdaptContext = None,
+ context: Optional[AdaptContext] = None,
factory: Optional[Callable[..., Any]] = None,
) -> None:
if not factory:
class SequenceDumper(Dumper):
- def __init__(self, src: type, context: AdaptContext = None):
+ def __init__(self, src: type, context: Optional[AdaptContext] = None):
super().__init__(src, context)
self._tx = Transformer(context)
class BaseCompositeLoader(Loader):
- def __init__(self, oid: int, context: AdaptContext = None):
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
super().__init__(oid, context)
self._tx = Transformer(context)
import re
import sys
from datetime import date, datetime, time, timedelta
-from typing import cast
+from typing import cast, Optional
from ..oids import builtins
from ..adapt import Dumper, Loader
_oid = builtins["interval"].oid
- def __init__(self, src: type, context: AdaptContext = None):
+ def __init__(self, src: type, context: Optional[AdaptContext] = None):
super().__init__(src, context)
if self.connection:
if (
@Loader.text(builtins["date"].oid)
class DateLoader(Loader):
- def __init__(self, oid: int, context: AdaptContext):
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
super().__init__(oid, context)
self._format = self._format_from_context()
_format = "%H:%M:%S.%f%z"
_format_no_micro = _format.replace(".%f", "")
- def __init__(self, oid: int, context: AdaptContext):
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
if sys.version_info < (3, 7):
setattr(self, "load", self._load_py36)
@Loader.text(builtins["timestamp"].oid)
class TimestampLoader(DateLoader):
- def __init__(self, oid: int, context: AdaptContext):
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
super().__init__(oid, context)
self._format_no_micro = self._format.replace(".%f", "")
@Loader.text(builtins["timestamptz"].oid)
class TimestamptzLoader(TimestampLoader):
- def __init__(self, oid: int, context: AdaptContext):
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
if sys.version_info < (3, 7):
setattr(self, "load", self._load_py36)
re.VERBOSE,
)
- def __init__(self, oid: int, context: AdaptContext):
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
super().__init__(oid, context)
if self.connection:
ints = self.connection.pgconn.parameter_status(b"IntervalStyle")
# Copyright (C) 2020 The Psycopg Team
-from typing import Callable, Union, TYPE_CHECKING
+from typing import Callable, Optional, Union, TYPE_CHECKING
from ..oids import builtins
from ..adapt import Dumper, Loader
class _LazyIpaddress(Loader):
- def __init__(self, oid: int, context: AdaptContext = None):
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
super().__init__(oid, context)
global ip_address, ip_interface, ip_network
from ipaddress import ip_address, ip_interface, ip_network
def register(
self,
- context: AdaptContext = None,
+ context: Optional[AdaptContext] = None,
range_class: Optional[Type[Range[Any]]] = None,
) -> None:
if not range_class:
# Copyright (C) 2020 The Psycopg Team
-from typing import Union, TYPE_CHECKING
+from typing import Optional, Union, TYPE_CHECKING
from ..pq import Escaping
from ..oids import builtins, INVALID_OID
_encoding = "utf-8"
- def __init__(self, src: type, context: AdaptContext):
+ def __init__(self, src: type, context: Optional[AdaptContext] = None):
super().__init__(src, context)
conn = self.connection
_encoding = "utf-8"
- def __init__(self, oid: int, context: AdaptContext):
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
super().__init__(oid, context)
conn = self.connection
if conn:
_oid = builtins["bytea"].oid
- def __init__(self, src: type, context: AdaptContext = None):
+ def __init__(self, src: type, context: Optional[AdaptContext] = None):
super().__init__(src, context)
self._esc = Escaping(
self.connection.pgconn if self.connection else None
class ByteaLoader(Loader):
_escaping: "EscapingProto"
- def __init__(self, oid: int, context: AdaptContext = None):
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
super().__init__(oid, context)
if not hasattr(self.__class__, "_escaping"):
self.__class__._escaping = Escaping()
# Copyright (C) 2020 The Psycopg Team
-from typing import Callable, TYPE_CHECKING
+from typing import Callable, Optional, TYPE_CHECKING
from ..oids import builtins
from ..adapt import Dumper, Loader
@Loader.text(builtins["uuid"].oid)
class UUIDLoader(Loader):
- def __init__(self, oid: int, context: AdaptContext = None):
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
super().__init__(oid, context)
global UUID
from uuid import UUID
from typing import Any, Iterable, List, Optional, Sequence, Tuple
-from psycopg3.adapt import Dumper, Loader
-from psycopg3.proto import AdaptContext, DumpFunc, DumpersMap, DumperType
-from psycopg3.proto import LoadFunc, LoadersMap, LoaderType, PQGen, PQGenConn
+from psycopg3.adapt import Dumper, Loader, AdaptersMap
+from psycopg3.proto import AdaptContext, PQGen, PQGenConn
from psycopg3.connection import BaseConnection
-from psycopg3 import pq
+from psycopg3.pq import Format
+from psycopg3.pq.proto import PGconn, PGresult
-class Transformer:
- def __init__(self, context: AdaptContext = None): ...
+class Transformer(AdaptContext):
+ def __init__(self, context: Optional[AdaptContext] = None): ...
@property
def connection(self) -> Optional[BaseConnection]: ...
@property
- def encoding(self) -> str: ...
+ def adapters(self) -> AdaptersMap: ...
@property
- def dumpers(self) -> DumpersMap: ...
- @property
- def loaders(self) -> LoadersMap: ...
- @property
- def pgresult(self) -> Optional[pq.proto.PGresult]: ...
+ def pgresult(self) -> Optional[PGresult]: ...
@pgresult.setter
- def pgresult(self, result: Optional[pq.proto.PGresult]) -> None: ...
- def set_row_types(
- self, types: Sequence[Tuple[int, pq.Format]]
- ) -> None: ...
- def get_dumper(self, obj: Any, format: pq.Format) -> Dumper: ...
+ def pgresult(self, result: Optional[PGresult]) -> None: ...
+ def set_row_types(self, types: Sequence[Tuple[int, Format]]) -> None: ...
+ def get_dumper(self, obj: Any, format: Format) -> Dumper: ...
def load_row(self, row: int) -> Optional[Tuple[Any, ...]]: ...
def load_sequence(
self, record: Sequence[Optional[bytes]]
) -> Tuple[Any, ...]: ...
- def get_loader(self, oid: int, format: pq.Format) -> Loader: ...
+ def get_loader(self, oid: int, format: Format) -> Loader: ...
def register_builtin_c_adapters() -> None: ...
-def connect(conninfo: str) -> PQGenConn[pq.proto.PGconn]: ...
-def execute(pgconn: pq.proto.PGconn) -> PQGen[List[pq.proto.PGresult]]: ...
+def connect(conninfo: str) -> PQGenConn[PGconn]: ...
+def execute(pgconn: PGconn) -> PQGen[List[PGresult]]: ...
# vim: set syntax=python:
cdef class CDumper:
cdef object src
- cdef public object context
- cdef public object connection
cdef public libpq.Oid oid
+ cdef readonly object connection
cdef PGconn _pgconn
- def __init__(self, src: type, context: AdaptContext = None):
+ def __init__(self, src: type, context: Optional["AdaptContext"] = None):
self.src = src
- self.context = context
- self.connection = _connection_from_context(context)
+ self.connection = context.connection if context is not None else None
self._pgconn = (
self.connection.pgconn if self.connection is not None else None
)
)
if error:
raise e.OperationalError(
- f"escape_string failed: {error_message(self.connection)}"
+ f"escape_string failed: {error_message(self._pgconn)}"
)
else:
len_out = libpq.PQescapeString(ptr_out + 1, ptr, length)
def register(
cls,
src: Union[type, str],
- context: AdaptContext = None,
+ context: Optional[AdaptContext] = None,
format: Format = Format.TEXT,
) -> None:
- if not isinstance(src, (str, type)):
- raise TypeError(
- f"dumpers should be registered on classes, got {src} instead"
- )
- from psycopg3.adapt import Dumper
+ if context is not None:
+ adapters = context.adapters
+ else:
+ from psycopg3.adapt import global_adapters as adapters
- where = context.dumpers if context else Dumper.globals
- where[src, format] = cls
+ adapters.register_dumper(src, cls, format=format)
cdef class CLoader:
cdef public libpq.Oid oid
- cdef public object context
- cdef public object connection
+ cdef public connection
- def __init__(self, oid: int, context: "AdaptContext" = None):
+ def __init__(self, oid: int, context: Optional["AdaptContext"] = None):
self.oid = oid
- self.context = context
- self.connection = _connection_from_context(context)
+ self.connection = context.connection if context is not None else None
cdef object cload(self, const char *data, size_t length):
raise NotImplementedError()
def register(
cls,
oid: int,
- context: "AdaptContext" = None,
+ context: Optional["AdaptContext"] = None,
format: Format = Format.TEXT,
) -> None:
- if not isinstance(oid, int):
- raise TypeError(
- f"loaders should be registered on oid, got {oid} instead"
- )
-
- from psycopg3.adapt import Loader
-
- where = context.loaders if context else Loader.globals
- where[oid, format] = cls
-
+ if context is not None:
+ adapters = context.adapters
+ else:
+ from psycopg3.adapt import global_adapters as adapters
-cdef _connection_from_context(object context):
- from psycopg3.adapt import connection_from_context
- return connection_from_context(context)
+ adapters.register_loader(oid, cls, format=format)
def register_builtin_c_adapters():
state so adapting several values of the same type can use optimisations.
"""
- cdef readonly dict dumpers, loaders
cdef readonly object connection
- cdef readonly str encoding
- cdef list _dumpers_maps, _loaders_maps
+ cdef readonly object adapters
cdef dict _dumpers_cache, _loaders_cache
cdef PGresult _pgresult
cdef int _nfields, _ntuples
-
cdef list _row_loaders
- def __cinit__(self, context: "AdaptContext" = None):
- self._dumpers_maps: List["DumpersMap"] = []
- self._loaders_maps: List["LoadersMap"] = []
- self._setup_context(context)
+ def __cinit__(self, context: Optional["AdaptContext"] = None):
+ if context is not None:
+ self.adapters = context.adapters
+ self.connection = context.connection
+ else:
+ from psycopg3.adapt import global_adapters
+ self.adapters = global_adapters
+ self.connection = None
# mapping class, fmt -> Dumper instance
self._dumpers_cache: Dict[Tuple[type, Format], "Dumper"] = {}
self.pgresult = None
self._row_loaders = []
- def _setup_context(self, context: "AdaptContext") -> None:
- from psycopg3.adapt import Dumper, Loader
- from psycopg3.cursor import BaseCursor
- from psycopg3.connection import BaseConnection
-
- cdef Transformer ctx
- if context is None:
- self.connection = None
- self.encoding = "utf-8"
- self.dumpers = {}
- self.loaders = {}
- self._dumpers_maps = [self.dumpers]
- self._loaders_maps = [self.loaders]
-
- elif isinstance(context, Transformer):
- # A transformer created from a transformers: usually it happens
- # for nested types: share the entire state of the parent
- ctx = context
- self.connection = ctx.connection
- self.encoding = ctx.encoding
- self.dumpers = ctx.dumpers
- self.loaders = ctx.loaders
- self._dumpers_maps.extend(ctx._dumpers_maps)
- self._loaders_maps.extend(ctx._loaders_maps)
- # the global maps are already in the lists
- return
-
- elif isinstance(context, BaseCursor):
- self.connection = context.connection
- self.encoding = context.connection.client_encoding
- self.dumpers = {}
- self._dumpers_maps.extend(
- (self.dumpers, context.dumpers, self.connection.dumpers)
- )
- self.loaders = {}
- self._loaders_maps.extend(
- (self.loaders, context.loaders, self.connection.loaders)
- )
-
- elif isinstance(context, BaseConnection):
- self.connection = context
- self.encoding = context.client_encoding
- self.dumpers = {}
- self._dumpers_maps.extend((self.dumpers, context.dumpers))
- self.loaders = {}
- self._loaders_maps.extend((self.loaders, context.loaders))
-
- self._dumpers_maps.append(Dumper.globals)
- self._loaders_maps.append(Loader.globals)
-
@property
def pgresult(self) -> Optional[PGresult]:
return self._pgresult
# in contexts from the most specific to the most generic.
# Also look for superclasses: if you can adapt a type you should be
# able to adapt its subtypes, otherwise Liskov is sad.
- for dmap in self._dumpers_maps:
- for scls in cls.__mro__:
- dumper_class = dmap.get((scls, format))
- if not dumper_class:
- continue
+ cdef dict dmap = self.adapters._dumpers
+ for scls in cls.__mro__:
+ dumper_class = dmap.get((scls, format))
+ if not dumper_class:
+ continue
- self._dumpers_cache[cls, format] = dumper = dumper_class(cls, self)
- return dumper
+ self._dumpers_cache[cls, format] = dumper = dumper_class(cls, self)
+ return dumper
# If the adapter is not found, look for its name as a string
- for dmap in self._dumpers_maps:
- for scls in cls.__mro__:
- fqn = f"{cls.__module__}.{scls.__qualname__}"
- dumper_class = dmap.get((fqn, format))
- if dumper_class is None:
- continue
+ for scls in cls.__mro__:
+ fqn = f"{scls.__module__}.{scls.__qualname__}"
+ dumper_class = dmap.get((fqn, format))
+ if dumper_class is None:
+ continue
- key = (cls, format)
- dmap[key] = dumper_class
- self._dumpers_cache[key] = dumper = dumper_class(cls, self)
- return dumper
+ dmap[cls, format] = dumper_class
+ self._dumpers_cache[cls, format] = dumper = dumper_class(cls, self)
+ return dumper
raise e.ProgrammingError(
f"cannot adapt type {type(obj).__name__}"
except KeyError:
pass
- for tcmap in self._loaders_maps:
- if key in tcmap:
- loader_cls = tcmap[key]
- break
- else:
- from psycopg3.adapt import Loader
- loader_cls = Loader.globals[oids.INVALID_OID, format]
-
- self._loaders_cache[key] = loader = loader_cls(key[0], self)
+ loader_cls = self.adapters._loaders.get(key)
+ if loader_cls is None:
+ loader_cls = self.adapters._loaders[oids.INVALID_OID, format]
+ loader = self._loaders_cache[key] = loader_cls(oid, self)
return loader
def __cinit__(self):
self.oid = oids.INT8_OID
- def __init__(self, src: type, context: AdaptContext = None):
+ def __init__(self, src: type, context: Optional[AdaptContext] = None):
super().__init__(src, context)
def dump(self, obj) -> bytes:
cdef char *encoding
cdef bytes _bytes_encoding # needed to keep `encoding` alive
- def __init__(self, src: type, context: AdaptContext):
+ def __init__(self, src: type, context: Optional[AdaptContext]):
super().__init__(src, context)
self.is_utf8 = 0
cdef char *encoding
cdef bytes _bytes_encoding # needed to keep `encoding` alive
- def __init__(self, oid: int, context: "AdaptContext" = None):
+ def __init__(self, oid: int, context: Optional["AdaptContext"] = None):
super().__init__(oid, context)
self.is_utf8 = 0
def __cinit__(self):
self.oid = oids.BYTEA_OID
- def __init__(self, src: type, context: AdaptContext):
+ def __init__(self, src: type, context: Optional[AdaptContext] = None):
super().__init__(src, context)
self.esc = Escaping(self._pgconn)
from psycopg3.sql import Identifier
from psycopg3.oids import builtins
-from psycopg3.adapt import Format, Loader
+from psycopg3.adapt import Format, global_adapters
from psycopg3.types.composite import CompositeInfo
info = CompositeInfo.fetch(conn, "tmptype")
info.register(context=conn)
- res = cur.execute("select %s::tmptype", [obj]).fetchone()[0]
+ res = conn.execute("select %s::tmptype", [obj]).fetchone()[0]
assert res == obj
@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
def test_load_composite(conn, testcomp, fmt_out):
- cur = conn.cursor(format=fmt_out)
info = CompositeInfo.fetch(conn, "testcomp")
info.register(conn)
+ cur = conn.cursor(format=fmt_out)
res = cur.execute("select row('hello', 10, 20)::testcomp").fetchone()[0]
assert res.foo == "hello"
assert res.bar == 10
@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
def test_load_composite_factory(conn, testcomp, fmt_out):
- cur = conn.cursor(format=fmt_out)
info = CompositeInfo.fetch(conn, "testcomp")
class MyThing:
info.register(conn, factory=MyThing)
+ cur = conn.cursor(format=fmt_out)
res = cur.execute("select row('hello', 10, 20)::testcomp").fetchone()[0]
assert isinstance(res, MyThing)
assert res.baz == 20.0
info.register()
for fmt in (Format.TEXT, Format.BINARY):
for oid in (info.oid, info.array_oid):
- assert Loader.globals.pop((oid, fmt))
+ assert global_adapters._loaders.pop((oid, fmt))
cur = conn.cursor()
info.register(cur)
for fmt in (Format.TEXT, Format.BINARY):
for oid in (info.oid, info.array_oid):
key = oid, fmt
- assert key not in Loader.globals
- assert key not in conn.loaders
- assert key in cur.loaders
+ assert key not in global_adapters._loaders
+ assert key not in conn.adapters._loaders
+ assert key in cur.adapters._loaders
info.register(conn)
for fmt in (Format.TEXT, Format.BINARY):
for oid in (info.oid, info.array_oid):
key = oid, fmt
- assert key not in Loader.globals
- assert key in conn.loaders
+ assert key not in global_adapters._loaders
+ assert key in conn.adapters._loaders