]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add Dumper protocol
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 10 Jul 2021 15:13:21 +0000 (17:13 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 10 Jul 2021 16:46:28 +0000 (18:46 +0200)
The presence of this protocol allows to implement custom dumper that
don't inherit from psycopg ones and still be statically checked.

psycopg/psycopg/_transform.py
psycopg/psycopg/adapt.py
psycopg/psycopg/proto.py
psycopg/psycopg/types/array.py
psycopg/psycopg/types/range.py
psycopg_c/psycopg_c/_psycopg.pyi
tests/test_adapt.py
tests/typing_example.py

index f9c5e58c57db993c75bfcb85cb32fa71d90dfdd1..00f1d555bbc2947443e811b2d6c65462fa789916 100644 (file)
@@ -17,7 +17,8 @@ from ._enums import Format
 
 if TYPE_CHECKING:
     from .pq.proto import PGresult
-    from .adapt import Dumper, Loader, AdaptersMap
+    from .adapt import Loader, AdaptersMap
+    from .proto import Dumper
     from .connection import BaseConnection
 
 DumperKey = Union[type, Tuple[type, ...]]
index 6acc9bb1f87ebea97ed739311a568aba936123bd..ea2d78c5967950a069677afd2daf522a4ee89cbf 100644 (file)
@@ -9,7 +9,6 @@ from typing import Any, Dict, List, Optional, Type, Tuple, Union
 from typing import cast, TYPE_CHECKING, TypeVar
 
 from . import pq
-from . import proto
 from . import errors as e
 from ._enums import Format as Format
 from .oids import postgres_types
@@ -18,6 +17,7 @@ from ._cmodule import _psycopg
 from ._typeinfo import TypesRegistry
 
 if TYPE_CHECKING:
+    from . import proto
     from .connection import BaseConnection
 
 RV = TypeVar("RV")
@@ -140,7 +140,7 @@ class AdaptersMap(AdaptContext):
     is cheap: a copy is made only on customisation.
     """
 
-    _dumpers: Dict[Format, Dict[Union[type, str], Type["Dumper"]]]
+    _dumpers: Dict[Format, Dict[Union[type, str], Type["proto.Dumper"]]]
     _loaders: List[Dict[int, Type["Loader"]]]
     types: TypesRegistry
 
@@ -190,7 +190,8 @@ class AdaptersMap(AdaptContext):
         if _psycopg:
             dumper = self._get_optimised(dumper)
 
-        # Register the dumper both as its format and as default
+        # Register the dumper both as its format and as auto
+        # so that the last dumper registered is used in auto (%s) format
         for fmt in (Format.from_pq(dumper.format), Format.AUTO):
             if not self._own_dumpers[fmt]:
                 self._dumpers[fmt] = self._dumpers[fmt].copy()
@@ -221,7 +222,7 @@ class AdaptersMap(AdaptContext):
 
         self._loaders[fmt][oid] = loader
 
-    def get_dumper(self, cls: type, format: Format) -> Type[Dumper]:
+    def get_dumper(self, cls: type, format: Format) -> Type["proto.Dumper"]:
         """
         Return the dumper class for the given type and format.
 
@@ -289,7 +290,7 @@ _dumpers_shared = dict.fromkeys(Format, False)
 
 global_adapters = AdaptersMap(types=postgres_types)
 
-Transformer: Type[proto.Transformer]
+Transformer: Type["proto.Transformer"]
 
 # Override it with fast object if available
 if _psycopg:
index 6977235748be9535e1c5defd8a538ae0381d3e17..8c96ff07b68269adea97c91ee91b91825a86a312 100644 (file)
@@ -9,16 +9,22 @@ from typing import List, Optional, Sequence, Tuple, TypeVar, Union
 from typing import TYPE_CHECKING
 
 from . import pq
-from ._enums import Format
+from . import _enums
 from .compat import Protocol
 
 if TYPE_CHECKING:
     from .sql import Composable
     from .rows import Row, RowMaker
-    from .adapt import Dumper, Loader, AdaptersMap
+    from . import adapt
+    from .adapt import AdaptersMap
+    from .pq.proto import PGresult
+
     from .waiting import Wait, Ready
     from .connection import BaseConnection
 
+# NOMERGE: change name of _enums.Format
+PyFormat = _enums.Format
+
 # An object implementing the buffer protocol
 Buffer = Union[bytes, bytearray, memoryview]
 
@@ -50,9 +56,6 @@ Wait states.
 DumpFunc = Callable[[Any], bytes]
 LoadFunc = Callable[[bytes], Any]
 
-# TODO: Loader, Dumper should probably become protocols
-# as there are both C and a Python implementation
-
 
 class AdaptContext(Protocol):
     """
@@ -70,6 +73,27 @@ class AdaptContext(Protocol):
         ...
 
 
+class Dumper(Protocol):
+    format: pq.Format
+    oid: int
+    cls: type
+
+    def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+        ...
+
+    def dump(self, obj: Any) -> Buffer:
+        ...
+
+    def quote(self, obj: Any) -> Buffer:
+        ...
+
+    def get_key(self, obj: Any, format: PyFormat) -> DumperKey:
+        ...
+
+    def upgrade(self, obj: Any, format: PyFormat) -> "Dumper":
+        ...
+
+
 class Transformer(Protocol):
     def __init__(self, context: Optional[AdaptContext] = None):
         ...
@@ -83,11 +107,11 @@ class Transformer(Protocol):
         ...
 
     @property
-    def pgresult(self) -> Optional[pq.proto.PGresult]:
+    def pgresult(self) -> Optional["PGresult"]:
         ...
 
     def set_pgresult(
-        self, result: Optional[pq.proto.PGresult], set_loaders: bool = True
+        self, result: Optional["PGresult"], set_loaders: bool = True
     ) -> None:
         ...
 
@@ -97,11 +121,11 @@ class Transformer(Protocol):
         ...
 
     def dump_sequence(
-        self, params: Sequence[Any], formats: Sequence[Format]
+        self, params: Sequence[Any], formats: Sequence[PyFormat]
     ) -> Tuple[List[Any], Tuple[int, ...], Sequence[pq.Format]]:
         ...
 
-    def get_dumper(self, obj: Any, format: Format) -> "Dumper":
+    def get_dumper(self, obj: Any, format: PyFormat) -> Dumper:
         ...
 
     def load_rows(
@@ -117,5 +141,5 @@ class Transformer(Protocol):
     ) -> Tuple[Any, ...]:
         ...
 
-    def get_loader(self, oid: int, format: pq.Format) -> "Loader":
+    def get_loader(self, oid: int, format: pq.Format) -> "adapt.Loader":
         ...
index 85b2c7fcfb46306b2b1f2a2a96ffb397e7445395..b5dfe2c9f2af3d5c922a2ed981849daa33052c2c 100644 (file)
@@ -12,9 +12,8 @@ from typing import cast
 from .. import pq
 from .. import errors as e
 from ..oids import postgres_types, TEXT_OID, TEXT_ARRAY_OID, INVALID_OID
-from ..adapt import Dumper, RecursiveDumper, RecursiveLoader
-from ..adapt import Format as Pg3Format
-from ..proto import AdaptContext, Buffer
+from ..adapt import RecursiveDumper, RecursiveLoader, Format as Pg3Format
+from ..proto import Dumper, AdaptContext, Buffer
 from .._struct import pack_len, unpack_len
 from .._typeinfo import TypeInfo
 
index 212a4460e967993777c5bc5f4b86098c840a400c..a55954a23eb8e03b5493f5f0a1f13a88315f339c 100644 (file)
@@ -12,9 +12,8 @@ from datetime import date, datetime
 
 from ..pq import Format
 from ..oids import postgres_types as builtins, INVALID_OID
-from ..adapt import Dumper, RecursiveDumper, RecursiveLoader
-from ..adapt import Format as Pg3Format
-from ..proto import AdaptContext, Buffer
+from ..adapt import RecursiveDumper, RecursiveLoader, Format as Pg3Format
+from ..proto import Dumper, AdaptContext, Buffer
 from .._struct import pack_len, unpack_len
 from .._typeinfo import RangeInfo as RangeInfo  # exported here
 
index 9854b8bd1589a9b7576c4d23ef9e29d168b28df8..90b3728806a60667382a1e7c38506aeb6b529759 100644 (file)
@@ -12,7 +12,8 @@ from typing import Any, Iterable, List, Optional, Sequence, Tuple
 from psycopg import pq
 from psycopg import proto
 from psycopg.rows import Row, RowMaker
-from psycopg.adapt import Dumper, Loader, AdaptersMap, Format
+from psycopg.adapt import  Loader, AdaptersMap, Format
+from psycopg.proto import Dumper
 from psycopg.pq.proto import PGconn, PGresult
 from psycopg.connection import BaseConnection
 
index 6815d27f1eae00fb139a6d9f7b390f949f0bc014..4f60228c7f88f2f8a921e472850b5c4e5928c850 100644 (file)
@@ -4,7 +4,7 @@ from types import ModuleType
 import pytest
 
 import psycopg
-from psycopg import pq
+from psycopg import pq, sql
 from psycopg.adapt import Transformer, Format, Dumper, Loader
 from psycopg.oids import postgres_types as builtins, TEXT_OID
 from psycopg._cmodule import _psycopg
@@ -105,6 +105,19 @@ def test_subclass_dumper(conn):
     assert conn.execute("select %t", ["hello"]).fetchone()[0] == "hellohello"
 
 
+def test_dumper_protocol(conn):
+
+    # This class doesn't inherit from adapt.Dumper but passes a mypy check
+    from .typing_example import MyStrDumper
+
+    conn.adapters.register_dumper(str, MyStrDumper)
+    cur = conn.execute("select %s", ["hello"])
+    assert cur.fetchone()[0] == "hellohello"
+    cur = conn.execute("select %s", [["hi", "ha"]])
+    assert cur.fetchone()[0] == ["hihi", "haha"]
+    assert sql.Literal("hello").as_string(conn) == "'qelloqello'"
+
+
 def test_subclass_loader(conn):
     # This might be a C fast object: make sure that the Python code is called
     from psycopg.types.string import TextLoader
index b0d4c3e611fc48397eb22a93df73e86771cdb30e..ad4931355692b87160fb4417738df8e3b9483f0a 100644 (file)
@@ -6,6 +6,8 @@ from dataclasses import dataclass
 from typing import Any, Callable, Optional, Sequence, Tuple
 
 from psycopg import AnyCursor, Connection, Cursor, ServerCursor, connect
+from psycopg import pq, adapt
+from psycopg.proto import Dumper, AdaptContext
 
 
 def int_row_factory(cursor: AnyCursor[int]) -> Callable[[Sequence[int]], int]:
@@ -85,3 +87,31 @@ def check_row_factory_connection() -> None:
         cur3.execute("select 42")
         r3 = cur3.fetchone()
         r3 and len(r3)
+
+
+def f() -> None:
+    d: Dumper = MyStrDumper(str, None)
+    assert d.dump("abc") == b"abcabc"
+    assert d.quote("abc") == b"'abcabc'"
+
+
+class MyStrDumper:
+    format = pq.Format.TEXT
+
+    def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+        self.cls = cls
+        self.oid = 25  # text
+
+    def dump(self, obj: str) -> bytes:
+        return (obj * 2).encode("utf-8")
+
+    def quote(self, obj: str) -> bytes:
+        value = self.dump(obj)
+        esc = pq.Escaping()
+        return b"'%s'" % esc.escape_string(value.replace(b"h", b"q"))
+
+    def get_key(self, obj: str, format: adapt.Format) -> type:
+        return self.cls
+
+    def upgrade(self, obj: str, format: adapt.Format) -> "MyStrDumper":
+        return self