]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Fix Range get_key/upgrade
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 19 Jul 2021 18:10:48 +0000 (20:10 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 19 Jul 2021 18:31:22 +0000 (20:31 +0200)
The adapter doesn't need upgrading not when it knows its oid but when it
adapts a Range strict subclass. The difference emerges if a query
handles different types of ranges in composite types.

Problem found by random testing: https://github.com/psycopg/psycopg/runs/3104464032
Thank you, faker!

psycopg/psycopg/types/range.py
tests/types/test_range.py

index 41e89806fe427deffa0b2ea17130650c1b7fbfd8..561f5f466a11d0588f68dd8497dac60b0e23a99c 100644 (file)
@@ -15,7 +15,7 @@ from ..pq import Format
 from ..abc import AdaptContext, Buffer, Dumper, DumperKey
 from ..adapt import RecursiveDumper, RecursiveLoader, PyFormat
 from .._struct import pack_len, unpack_len
-from ..postgres import INVALID_OID
+from ..postgres import INVALID_OID, TEXT_OID
 from .._typeinfo import RangeInfo as RangeInfo  # exported here
 from .composite import SequenceDumper, BaseCompositeLoader
 
@@ -261,7 +261,7 @@ class BaseRangeDumper(RecursiveDumper):
 
     def get_key(self, obj: Range[Any], format: PyFormat) -> DumperKey:
         # If we are a subclass whose oid is specified we don't need upgrade
-        if self.oid != INVALID_OID:
+        if self.cls is not Range:
             return self.cls
 
         item = self._get_item(obj)
@@ -273,7 +273,7 @@ class BaseRangeDumper(RecursiveDumper):
 
     def upgrade(self, obj: Range[Any], format: PyFormat) -> "BaseRangeDumper":
         # If we are a subclass whose oid is specified we don't need upgrade
-        if self.oid != INVALID_OID:
+        if self.cls is not Range:
             return self
 
         item = self._get_item(obj)
@@ -295,7 +295,7 @@ class BaseRangeDumper(RecursiveDumper):
         dumper.sub_dumper = sd
         if sd.oid == INVALID_OID and isinstance(item, str):
             # Work around the normal mapping where text is dumped as unknown
-            dumper.oid = self._get_range_oid(self._types["text"].oid)
+            dumper.oid = self._get_range_oid(TEXT_OID)
         else:
             dumper.oid = self._get_range_oid(sd.oid)
 
index 634aadc1507dc87f2516fae7c28cb390fbf5f9e2..9a35230ad5944958ded33d0db57a4cffe9bb5613 100644 (file)
@@ -317,6 +317,20 @@ def test_load_quoting(conn, testrange, fmt_out):
         assert ord(got.upper) == i + 1
 
 
+@pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY])
+def test_mixed_array_types(conn, fmt_out):
+    conn.execute("create table testmix (a daterange[], b tstzrange[])")
+    r1 = Range(dt.date(2000, 1, 1), dt.date(2001, 1, 1), "[)")
+    r2 = Range(
+        dt.datetime(2000, 1, 1, tzinfo=dt.timezone.utc),
+        dt.datetime(2001, 1, 1, tzinfo=dt.timezone.utc),
+        "[)",
+    )
+    conn.execute("insert into testmix values (%s, %s)", [[r1], [r2]])
+    got = conn.execute("select * from testmix").fetchone()
+    assert got == ([r1], [r2])
+
+
 class TestRangeObject:
     def test_noparam(self):
         r = Range()