]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add Loader protocol
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 10 Jul 2021 16:46:07 +0000 (18:46 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 11 Jul 2021 12:43:58 +0000 (14:43 +0200)
psycopg/psycopg/_transform.py
psycopg/psycopg/adapt.py
psycopg/psycopg/proto.py
psycopg_c/psycopg_c/_psycopg.pyi
tests/test_adapt.py
tests/typing_example.py

index d2351e2b34eee01e20e2e27c0baf4ea39c303630..c821ea859b95813583be5ef411760000745ab360 100644 (file)
@@ -16,8 +16,8 @@ from .proto import LoadFunc, AdaptContext, PyFormat, DumperKey
 
 if TYPE_CHECKING:
     from .pq.proto import PGresult
-    from .adapt import Loader, AdaptersMap
-    from .proto import Dumper
+    from .adapt import AdaptersMap
+    from .proto import Dumper, Loader
     from .connection import BaseConnection
 
 DumperCache = Dict[DumperKey, "Dumper"]
index f6b1c46d32887943bead9e904ede4496da9656bd..da8d2cccdc27df96f34b8598a2269ebc92fb24f2 100644 (file)
@@ -141,7 +141,7 @@ class AdaptersMap(AdaptContext):
     """
 
     _dumpers: Dict[PyFormat, Dict[Union[type, str], Type["proto.Dumper"]]]
-    _loaders: List[Dict[int, Type["Loader"]]]
+    _loaders: List[Dict[int, Type["proto.Loader"]]]
     types: TypesRegistry
 
     # Record if a dumper or loader has an optimised version.
@@ -200,7 +200,7 @@ class AdaptersMap(AdaptContext):
             self._dumpers[fmt][cls] = dumper
 
     def register_loader(
-        self, oid: Union[int, str], loader: Type[Loader]
+        self, oid: Union[int, str], loader: Type["proto.Loader"]
     ) -> None:
         """
         Configure the context to use *loader* to convert data of oid *oid*.
@@ -252,7 +252,7 @@ class AdaptersMap(AdaptContext):
 
     def get_loader(
         self, oid: int, format: pq.Format
-    ) -> Optional[Type[Loader]]:
+    ) -> Optional[Type["proto.Loader"]]:
         """
         Return the loader class for the given oid and format.
 
index b0c92286a291df4fe22caf2e712649f91686cf8c..f318404a9041b5cde8f0a22966bda6b761ed7bf9 100644 (file)
@@ -15,7 +15,6 @@ from .compat import Protocol
 if TYPE_CHECKING:
     from .sql import Composable
     from .rows import Row, RowMaker
-    from . import adapt
     from .adapt import AdaptersMap
     from .pq.proto import PGresult
 
@@ -93,6 +92,16 @@ class Dumper(Protocol):
         ...
 
 
+class Loader(Protocol):
+    format: pq.Format
+
+    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+        ...
+
+    def load(self, data: Buffer) -> Any:
+        ...
+
+
 class Transformer(Protocol):
     def __init__(self, context: Optional[AdaptContext] = None):
         ...
@@ -140,5 +149,5 @@ class Transformer(Protocol):
     ) -> Tuple[Any, ...]:
         ...
 
-    def get_loader(self, oid: int, format: pq.Format) -> "adapt.Loader":
+    def get_loader(self, oid: int, format: pq.Format) -> Loader:
         ...
index e5ba1ad1edcd2f1ef593e36611192d635179dc37..f0e9846d061afd0ec871a9dbd0c7710589ce4285 100644 (file)
@@ -12,8 +12,8 @@ from typing import Any, Iterable, List, Optional, Sequence, Tuple
 from psycopg import pq
 from psycopg import proto
 from psycopg.rows import Row, RowMaker
-from psycopg.adapt import  Loader, AdaptersMap, PyFormat
-from psycopg.proto import Dumper
+from psycopg.adapt import AdaptersMap, PyFormat
+from psycopg.proto import Dumper, Loader
 from psycopg.pq.proto import PGconn, PGresult
 from psycopg.connection import BaseConnection
 
index caf14d48e7e4ca990fdc83f8fc7eb83863ac2f16..9a5f1b6891e4dad8f332d44c720d29c2a2c10b93 100644 (file)
@@ -118,6 +118,18 @@ def test_dumper_protocol(conn):
     assert sql.Literal("hello").as_string(conn) == "'qelloqello'"
 
 
+def test_loader_protocol(conn):
+
+    # This class doesn't inherit from adapt.Loader but passes a mypy check
+    from .typing_example import MyTextLoader
+
+    conn.adapters.register_loader("text", MyTextLoader)
+    cur = conn.execute("select 'hello'::text")
+    assert cur.fetchone()[0] == "hellohello"
+    cur = conn.execute("select '{hi,ha}'::text[]")
+    assert cur.fetchone()[0] == ["hihi", "haha"]
+
+
 def test_subclass_loader(conn):
     # This might be a C fast object: make sure that the Python code is called
     from psycopg.types.string import TextLoader
index 42be3826297b9156d7e8b53b1b97a01a58c265bf..6746125faf4a0e767d5a70ca74ca660381739540 100644 (file)
@@ -7,7 +7,7 @@ from typing import Any, Callable, Optional, Sequence, Tuple, Union
 
 from psycopg import AnyCursor, Connection, Cursor, ServerCursor, connect
 from psycopg import pq
-from psycopg.proto import Dumper, AdaptContext, PyFormat
+from psycopg.proto import Dumper, Loader, AdaptContext, PyFormat, Buffer
 
 
 def int_row_factory(cursor: AnyCursor[int]) -> Callable[[Sequence[int]], int]:
@@ -94,6 +94,9 @@ def f() -> None:
     assert d.dump("abc") == b"abcabc"
     assert d.quote("abc") == b"'abcabc'"
 
+    lo: Loader = MyTextLoader(0, None)
+    assert lo.load(b"abc") == "abcabc"
+
 
 class MyStrDumper:
     format = pq.Format.TEXT
@@ -117,6 +120,16 @@ class MyStrDumper:
         return self
 
 
+class MyTextLoader:
+    format = pq.Format.TEXT
+
+    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+        pass
+
+    def load(self, data: Buffer) -> str:
+        return (bytes(data) * 2).decode("utf-8")
+
+
 # This should be the definition of psycopg.adapt.DumperKey, but mypy doesn't
 # support recursive types. When it will, this statement will give an error
 # (unused type: ignore) so we can fix our definition.