if TYPE_CHECKING:
from .pq.proto import PGresult
- from .adapt import Dumper, Loader, AdaptersMap
+ from .adapt import Loader, AdaptersMap
+ from .proto import Dumper
from .connection import BaseConnection
DumperKey = Union[type, Tuple[type, ...]]
from typing import cast, TYPE_CHECKING, TypeVar
from . import pq
-from . import proto
from . import errors as e
from ._enums import Format as Format
from .oids import postgres_types
from ._typeinfo import TypesRegistry
if TYPE_CHECKING:
+ from . import proto
from .connection import BaseConnection
RV = TypeVar("RV")
is cheap: a copy is made only on customisation.
"""
- _dumpers: Dict[Format, Dict[Union[type, str], Type["Dumper"]]]
+ _dumpers: Dict[Format, Dict[Union[type, str], Type["proto.Dumper"]]]
_loaders: List[Dict[int, Type["Loader"]]]
types: TypesRegistry
if _psycopg:
dumper = self._get_optimised(dumper)
- # Register the dumper both as its format and as default
+ # Register the dumper both as its format and as auto
+ # so that the last dumper registered is used in auto (%s) format
for fmt in (Format.from_pq(dumper.format), Format.AUTO):
if not self._own_dumpers[fmt]:
self._dumpers[fmt] = self._dumpers[fmt].copy()
self._loaders[fmt][oid] = loader
- def get_dumper(self, cls: type, format: Format) -> Type[Dumper]:
+ def get_dumper(self, cls: type, format: Format) -> Type["proto.Dumper"]:
"""
Return the dumper class for the given type and format.
global_adapters = AdaptersMap(types=postgres_types)
-Transformer: Type[proto.Transformer]
+Transformer: Type["proto.Transformer"]
# Override it with fast object if available
if _psycopg:
from typing import TYPE_CHECKING
from . import pq
-from ._enums import Format
+from . import _enums
from .compat import Protocol
if TYPE_CHECKING:
from .sql import Composable
from .rows import Row, RowMaker
- from .adapt import Dumper, Loader, AdaptersMap
+ from . import adapt
+ from .adapt import AdaptersMap
+ from .pq.proto import PGresult
+
from .waiting import Wait, Ready
from .connection import BaseConnection
+# NOMERGE: change name of _enums.Format
+PyFormat = _enums.Format
+
# An object implementing the buffer protocol
Buffer = Union[bytes, bytearray, memoryview]
DumpFunc = Callable[[Any], bytes]
LoadFunc = Callable[[bytes], Any]
-# TODO: Loader, Dumper should probably become protocols
-# as there are both C and a Python implementation
-
class AdaptContext(Protocol):
"""
...
+class Dumper(Protocol):
+ format: pq.Format
+ oid: int
+ cls: type
+
+ def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+ ...
+
+ def dump(self, obj: Any) -> Buffer:
+ ...
+
+ def quote(self, obj: Any) -> Buffer:
+ ...
+
+ def get_key(self, obj: Any, format: PyFormat) -> DumperKey:
+ ...
+
+ def upgrade(self, obj: Any, format: PyFormat) -> "Dumper":
+ ...
+
+
class Transformer(Protocol):
def __init__(self, context: Optional[AdaptContext] = None):
...
...
@property
- def pgresult(self) -> Optional[pq.proto.PGresult]:
+ def pgresult(self) -> Optional["PGresult"]:
...
def set_pgresult(
- self, result: Optional[pq.proto.PGresult], set_loaders: bool = True
+ self, result: Optional["PGresult"], set_loaders: bool = True
) -> None:
...
...
def dump_sequence(
- self, params: Sequence[Any], formats: Sequence[Format]
+ self, params: Sequence[Any], formats: Sequence[PyFormat]
) -> Tuple[List[Any], Tuple[int, ...], Sequence[pq.Format]]:
...
- def get_dumper(self, obj: Any, format: Format) -> "Dumper":
+ def get_dumper(self, obj: Any, format: PyFormat) -> Dumper:
...
def load_rows(
) -> Tuple[Any, ...]:
...
- def get_loader(self, oid: int, format: pq.Format) -> "Loader":
+ def get_loader(self, oid: int, format: pq.Format) -> "adapt.Loader":
...
from .. import pq
from .. import errors as e
from ..oids import postgres_types, TEXT_OID, TEXT_ARRAY_OID, INVALID_OID
-from ..adapt import Dumper, RecursiveDumper, RecursiveLoader
-from ..adapt import Format as Pg3Format
-from ..proto import AdaptContext, Buffer
+from ..adapt import RecursiveDumper, RecursiveLoader, Format as Pg3Format
+from ..proto import Dumper, AdaptContext, Buffer
from .._struct import pack_len, unpack_len
from .._typeinfo import TypeInfo
from ..pq import Format
from ..oids import postgres_types as builtins, INVALID_OID
-from ..adapt import Dumper, RecursiveDumper, RecursiveLoader
-from ..adapt import Format as Pg3Format
-from ..proto import AdaptContext, Buffer
+from ..adapt import RecursiveDumper, RecursiveLoader, Format as Pg3Format
+from ..proto import Dumper, AdaptContext, Buffer
from .._struct import pack_len, unpack_len
from .._typeinfo import RangeInfo as RangeInfo # exported here
from psycopg import pq
from psycopg import proto
from psycopg.rows import Row, RowMaker
-from psycopg.adapt import Dumper, Loader, AdaptersMap, Format
+from psycopg.adapt import Loader, AdaptersMap, Format
+from psycopg.proto import Dumper
from psycopg.pq.proto import PGconn, PGresult
from psycopg.connection import BaseConnection
import pytest
import psycopg
-from psycopg import pq
+from psycopg import pq, sql
from psycopg.adapt import Transformer, Format, Dumper, Loader
from psycopg.oids import postgres_types as builtins, TEXT_OID
from psycopg._cmodule import _psycopg
assert conn.execute("select %t", ["hello"]).fetchone()[0] == "hellohello"
+def test_dumper_protocol(conn):
+
+ # This class doesn't inherit from adapt.Dumper but passes a mypy check
+ from .typing_example import MyStrDumper
+
+ conn.adapters.register_dumper(str, MyStrDumper)
+ cur = conn.execute("select %s", ["hello"])
+ assert cur.fetchone()[0] == "hellohello"
+ cur = conn.execute("select %s", [["hi", "ha"]])
+ assert cur.fetchone()[0] == ["hihi", "haha"]
+ assert sql.Literal("hello").as_string(conn) == "'qelloqello'"
+
+
def test_subclass_loader(conn):
# This might be a C fast object: make sure that the Python code is called
from psycopg.types.string import TextLoader
from typing import Any, Callable, Optional, Sequence, Tuple
from psycopg import AnyCursor, Connection, Cursor, ServerCursor, connect
+from psycopg import pq, adapt
+from psycopg.proto import Dumper, AdaptContext
def int_row_factory(cursor: AnyCursor[int]) -> Callable[[Sequence[int]], int]:
cur3.execute("select 42")
r3 = cur3.fetchone()
r3 and len(r3)
+
+
+def f() -> None:
+ d: Dumper = MyStrDumper(str, None)
+ assert d.dump("abc") == b"abcabc"
+ assert d.quote("abc") == b"'abcabc'"
+
+
+class MyStrDumper:
+ format = pq.Format.TEXT
+
+ def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+ self.cls = cls
+ self.oid = 25 # text
+
+ def dump(self, obj: str) -> bytes:
+ return (obj * 2).encode("utf-8")
+
+ def quote(self, obj: str) -> bytes:
+ value = self.dump(obj)
+ esc = pq.Escaping()
+ return b"'%s'" % esc.escape_string(value.replace(b"h", b"q"))
+
+ def get_key(self, obj: str, format: adapt.Format) -> type:
+ return self.cls
+
+ def upgrade(self, obj: str, format: adapt.Format) -> "MyStrDumper":
+ return self