]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Subclasses of dumpable objects are dumpable too
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 27 Oct 2020 17:30:50 +0000 (18:30 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 27 Oct 2020 17:35:37 +0000 (18:35 +0100)
psycopg3/psycopg3/transform.py
psycopg3_c/psycopg3_c/transform.pyx
tests/test_adapt.py

index 83593e866f54fcdded79b2401fceb4ff8047fce6..345b7d8ea55e3b7bed5ceeb271c9582084a73962 100644 (file)
@@ -144,24 +144,28 @@ class Transformer:
             rc.append(self.get_loader(oid, fmt).load)
 
     def get_dumper(self, obj: Any, format: Format) -> "Dumper":
-        key = (type(obj), format)
+        # Fast path: return a Dumper class already instantiated from the same type
+        cls = type(obj)
         try:
-            return self._dumpers_cache[key]
+            return self._dumpers_cache[cls, format]
         except KeyError:
             pass
 
-        for amap in self._dumpers_maps:
-            if key in amap:
-                dumper_cls = amap[key]
-                break
-        else:
-            raise e.ProgrammingError(
-                f"cannot adapt type {type(obj).__name__}"
-                f" to format {Format(format).name}"
-            )
-
-        self._dumpers_cache[key] = dumper = dumper_cls(key[0], self)
-        return dumper
+        # We haven't seen this type in this query yet. Look for an adapter
+        # in contexts from the most specific to the most generic.
+        # Also look for superclasses: if you can adapt a type you should be
+        # able to adapt its subtypes, otherwise Liskov is sad.
+        for dmap in self._dumpers_maps:
+            for scls in cls.__mro__:
+                key = (scls, format)
+                if key in dmap:
+                    self._dumpers_cache[key] = dumper = dmap[key](scls, self)
+                    return dumper
+
+        raise e.ProgrammingError(
+            f"cannot adapt type {type(obj).__name__}"
+            f" to format {Format(format).name}"
+        )
 
     def load_row(self, row: int) -> Optional[Tuple[Any, ...]]:
         res = self.pgresult
index ca5a711da6554bedd82af8e7270c3cfc24ffcadb..e79a32dbad4092e6c812de3a58af9bb22e133070 100644 (file)
@@ -178,24 +178,28 @@ cdef class Transformer:
         return row_loader
 
     def get_dumper(self, obj: Any, format: Format) -> "Dumper":
-        key = (type(obj), format)
+        # Fast path: return a Dumper class already instantiated from the same type
+        cls = type(obj)
         try:
-            return self._dumpers_cache[key]
+            return self._dumpers_cache[cls, format]
         except KeyError:
             pass
 
-        for amap in self._dumpers_maps:
-            if key in amap:
-                dumper_cls = amap[key]
-                break
-        else:
-            raise e.ProgrammingError(
-                f"cannot adapt type {type(obj).__name__}"
-                f" to format {Format(format).name}"
-            )
-
-        self._dumpers_cache[key] = dumper = dumper_cls(key[0], self)
-        return dumper
+        # We haven't seen this type in this query yet. Look for an adapter
+        # in contexts from the most specific to the most generic.
+        # Also look for superclasses: if you can adapt a type you should be
+        # able to adapt its subtypes, otherwise Liskov is sad.
+        for dmap in self._dumpers_maps:
+            for scls in cls.__mro__:
+                key = (scls, format)
+                if key in dmap:
+                    self._dumpers_cache[key] = dumper = dmap[key](scls, self)
+                    return dumper
+
+        raise e.ProgrammingError(
+            f"cannot adapt type {type(obj).__name__}"
+            f" to format {Format(format).name}"
+        )
 
     def load_row(self, row: int) -> Optional[Tuple[Any, ...]]:
         if self._pgresult is None:
index 9bb0745531caed90dcbb0bfd7b849e046cd05597..b566c1859990b582d7e3f40a09868864a006de1f 100644 (file)
@@ -45,6 +45,16 @@ def test_dump_cursor_ctx(conn):
     assert cur.fetchone() == ("hellot", "worldb")
 
 
+@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
+def test_dump_subclass(conn, fmt_out):
+    class MyString(str):
+        pass
+
+    cur = conn.cursor()
+    cur.execute("select %s, %b", [MyString("hello"), MyString("world")])
+    assert cur.fetchone() == ("hello", "world")
+
+
 @pytest.mark.parametrize(
     "data, format, type, result",
     [