From: Daniele Varrazzo Date: Sat, 10 Jul 2021 15:13:21 +0000 (+0200) Subject: Add Dumper protocol X-Git-Tag: 3.0.dev1~28^2~11 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=7f9e63139d35048a3e0d5291eb014e614b97d532;p=thirdparty%2Fpsycopg.git Add Dumper protocol The presence of this protocol allows to implement custom dumper that don't inherit from psycopg ones and still be statically checked. --- diff --git a/psycopg/psycopg/_transform.py b/psycopg/psycopg/_transform.py index f9c5e58c5..00f1d555b 100644 --- a/psycopg/psycopg/_transform.py +++ b/psycopg/psycopg/_transform.py @@ -17,7 +17,8 @@ from ._enums import Format if TYPE_CHECKING: from .pq.proto import PGresult - from .adapt import Dumper, Loader, AdaptersMap + from .adapt import Loader, AdaptersMap + from .proto import Dumper from .connection import BaseConnection DumperKey = Union[type, Tuple[type, ...]] diff --git a/psycopg/psycopg/adapt.py b/psycopg/psycopg/adapt.py index 6acc9bb1f..ea2d78c59 100644 --- a/psycopg/psycopg/adapt.py +++ b/psycopg/psycopg/adapt.py @@ -9,7 +9,6 @@ from typing import Any, Dict, List, Optional, Type, Tuple, Union from typing import cast, TYPE_CHECKING, TypeVar from . import pq -from . import proto from . import errors as e from ._enums import Format as Format from .oids import postgres_types @@ -18,6 +17,7 @@ from ._cmodule import _psycopg from ._typeinfo import TypesRegistry if TYPE_CHECKING: + from . import proto from .connection import BaseConnection RV = TypeVar("RV") @@ -140,7 +140,7 @@ class AdaptersMap(AdaptContext): is cheap: a copy is made only on customisation. """ - _dumpers: Dict[Format, Dict[Union[type, str], Type["Dumper"]]] + _dumpers: Dict[Format, Dict[Union[type, str], Type["proto.Dumper"]]] _loaders: List[Dict[int, Type["Loader"]]] types: TypesRegistry @@ -190,7 +190,8 @@ class AdaptersMap(AdaptContext): if _psycopg: dumper = self._get_optimised(dumper) - # Register the dumper both as its format and as default + # Register the dumper both as its format and as auto + # so that the last dumper registered is used in auto (%s) format for fmt in (Format.from_pq(dumper.format), Format.AUTO): if not self._own_dumpers[fmt]: self._dumpers[fmt] = self._dumpers[fmt].copy() @@ -221,7 +222,7 @@ class AdaptersMap(AdaptContext): self._loaders[fmt][oid] = loader - def get_dumper(self, cls: type, format: Format) -> Type[Dumper]: + def get_dumper(self, cls: type, format: Format) -> Type["proto.Dumper"]: """ Return the dumper class for the given type and format. @@ -289,7 +290,7 @@ _dumpers_shared = dict.fromkeys(Format, False) global_adapters = AdaptersMap(types=postgres_types) -Transformer: Type[proto.Transformer] +Transformer: Type["proto.Transformer"] # Override it with fast object if available if _psycopg: diff --git a/psycopg/psycopg/proto.py b/psycopg/psycopg/proto.py index 697723574..8c96ff07b 100644 --- a/psycopg/psycopg/proto.py +++ b/psycopg/psycopg/proto.py @@ -9,16 +9,22 @@ from typing import List, Optional, Sequence, Tuple, TypeVar, Union from typing import TYPE_CHECKING from . import pq -from ._enums import Format +from . import _enums from .compat import Protocol if TYPE_CHECKING: from .sql import Composable from .rows import Row, RowMaker - from .adapt import Dumper, Loader, AdaptersMap + from . import adapt + from .adapt import AdaptersMap + from .pq.proto import PGresult + from .waiting import Wait, Ready from .connection import BaseConnection +# NOMERGE: change name of _enums.Format +PyFormat = _enums.Format + # An object implementing the buffer protocol Buffer = Union[bytes, bytearray, memoryview] @@ -50,9 +56,6 @@ Wait states. DumpFunc = Callable[[Any], bytes] LoadFunc = Callable[[bytes], Any] -# TODO: Loader, Dumper should probably become protocols -# as there are both C and a Python implementation - class AdaptContext(Protocol): """ @@ -70,6 +73,27 @@ class AdaptContext(Protocol): ... +class Dumper(Protocol): + format: pq.Format + oid: int + cls: type + + def __init__(self, cls: type, context: Optional[AdaptContext] = None): + ... + + def dump(self, obj: Any) -> Buffer: + ... + + def quote(self, obj: Any) -> Buffer: + ... + + def get_key(self, obj: Any, format: PyFormat) -> DumperKey: + ... + + def upgrade(self, obj: Any, format: PyFormat) -> "Dumper": + ... + + class Transformer(Protocol): def __init__(self, context: Optional[AdaptContext] = None): ... @@ -83,11 +107,11 @@ class Transformer(Protocol): ... @property - def pgresult(self) -> Optional[pq.proto.PGresult]: + def pgresult(self) -> Optional["PGresult"]: ... def set_pgresult( - self, result: Optional[pq.proto.PGresult], set_loaders: bool = True + self, result: Optional["PGresult"], set_loaders: bool = True ) -> None: ... @@ -97,11 +121,11 @@ class Transformer(Protocol): ... def dump_sequence( - self, params: Sequence[Any], formats: Sequence[Format] + self, params: Sequence[Any], formats: Sequence[PyFormat] ) -> Tuple[List[Any], Tuple[int, ...], Sequence[pq.Format]]: ... - def get_dumper(self, obj: Any, format: Format) -> "Dumper": + def get_dumper(self, obj: Any, format: PyFormat) -> Dumper: ... def load_rows( @@ -117,5 +141,5 @@ class Transformer(Protocol): ) -> Tuple[Any, ...]: ... - def get_loader(self, oid: int, format: pq.Format) -> "Loader": + def get_loader(self, oid: int, format: pq.Format) -> "adapt.Loader": ... diff --git a/psycopg/psycopg/types/array.py b/psycopg/psycopg/types/array.py index 85b2c7fcf..b5dfe2c9f 100644 --- a/psycopg/psycopg/types/array.py +++ b/psycopg/psycopg/types/array.py @@ -12,9 +12,8 @@ from typing import cast from .. import pq from .. import errors as e from ..oids import postgres_types, TEXT_OID, TEXT_ARRAY_OID, INVALID_OID -from ..adapt import Dumper, RecursiveDumper, RecursiveLoader -from ..adapt import Format as Pg3Format -from ..proto import AdaptContext, Buffer +from ..adapt import RecursiveDumper, RecursiveLoader, Format as Pg3Format +from ..proto import Dumper, AdaptContext, Buffer from .._struct import pack_len, unpack_len from .._typeinfo import TypeInfo diff --git a/psycopg/psycopg/types/range.py b/psycopg/psycopg/types/range.py index 212a4460e..a55954a23 100644 --- a/psycopg/psycopg/types/range.py +++ b/psycopg/psycopg/types/range.py @@ -12,9 +12,8 @@ from datetime import date, datetime from ..pq import Format from ..oids import postgres_types as builtins, INVALID_OID -from ..adapt import Dumper, RecursiveDumper, RecursiveLoader -from ..adapt import Format as Pg3Format -from ..proto import AdaptContext, Buffer +from ..adapt import RecursiveDumper, RecursiveLoader, Format as Pg3Format +from ..proto import Dumper, AdaptContext, Buffer from .._struct import pack_len, unpack_len from .._typeinfo import RangeInfo as RangeInfo # exported here diff --git a/psycopg_c/psycopg_c/_psycopg.pyi b/psycopg_c/psycopg_c/_psycopg.pyi index 9854b8bd1..90b372880 100644 --- a/psycopg_c/psycopg_c/_psycopg.pyi +++ b/psycopg_c/psycopg_c/_psycopg.pyi @@ -12,7 +12,8 @@ from typing import Any, Iterable, List, Optional, Sequence, Tuple from psycopg import pq from psycopg import proto from psycopg.rows import Row, RowMaker -from psycopg.adapt import Dumper, Loader, AdaptersMap, Format +from psycopg.adapt import Loader, AdaptersMap, Format +from psycopg.proto import Dumper from psycopg.pq.proto import PGconn, PGresult from psycopg.connection import BaseConnection diff --git a/tests/test_adapt.py b/tests/test_adapt.py index 6815d27f1..4f60228c7 100644 --- a/tests/test_adapt.py +++ b/tests/test_adapt.py @@ -4,7 +4,7 @@ from types import ModuleType import pytest import psycopg -from psycopg import pq +from psycopg import pq, sql from psycopg.adapt import Transformer, Format, Dumper, Loader from psycopg.oids import postgres_types as builtins, TEXT_OID from psycopg._cmodule import _psycopg @@ -105,6 +105,19 @@ def test_subclass_dumper(conn): assert conn.execute("select %t", ["hello"]).fetchone()[0] == "hellohello" +def test_dumper_protocol(conn): + + # This class doesn't inherit from adapt.Dumper but passes a mypy check + from .typing_example import MyStrDumper + + conn.adapters.register_dumper(str, MyStrDumper) + cur = conn.execute("select %s", ["hello"]) + assert cur.fetchone()[0] == "hellohello" + cur = conn.execute("select %s", [["hi", "ha"]]) + assert cur.fetchone()[0] == ["hihi", "haha"] + assert sql.Literal("hello").as_string(conn) == "'qelloqello'" + + def test_subclass_loader(conn): # This might be a C fast object: make sure that the Python code is called from psycopg.types.string import TextLoader diff --git a/tests/typing_example.py b/tests/typing_example.py index b0d4c3e61..ad4931355 100644 --- a/tests/typing_example.py +++ b/tests/typing_example.py @@ -6,6 +6,8 @@ from dataclasses import dataclass from typing import Any, Callable, Optional, Sequence, Tuple from psycopg import AnyCursor, Connection, Cursor, ServerCursor, connect +from psycopg import pq, adapt +from psycopg.proto import Dumper, AdaptContext def int_row_factory(cursor: AnyCursor[int]) -> Callable[[Sequence[int]], int]: @@ -85,3 +87,31 @@ def check_row_factory_connection() -> None: cur3.execute("select 42") r3 = cur3.fetchone() r3 and len(r3) + + +def f() -> None: + d: Dumper = MyStrDumper(str, None) + assert d.dump("abc") == b"abcabc" + assert d.quote("abc") == b"'abcabc'" + + +class MyStrDumper: + format = pq.Format.TEXT + + def __init__(self, cls: type, context: Optional[AdaptContext] = None): + self.cls = cls + self.oid = 25 # text + + def dump(self, obj: str) -> bytes: + return (obj * 2).encode("utf-8") + + def quote(self, obj: str) -> bytes: + value = self.dump(obj) + esc = pq.Escaping() + return b"'%s'" % esc.escape_string(value.replace(b"h", b"q")) + + def get_key(self, obj: str, format: adapt.Format) -> type: + return self.cls + + def upgrade(self, obj: str, format: adapt.Format) -> "MyStrDumper": + return self