From 21ffa6468f3e3da8fb9f7ec2f2a21e9777ca754a Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Mon, 19 Jul 2021 20:10:48 +0200 Subject: [PATCH] Fix Range get_key/upgrade The adapter doesn't need upgrading not when it knows its oid but when it adapts a Range strict subclass. The difference emerges if a query handles different types of ranges in composite types. Problem found by random testing: https://github.com/psycopg/psycopg/runs/3104464032 Thank you, faker! --- psycopg/psycopg/types/range.py | 8 ++++---- tests/types/test_range.py | 14 ++++++++++++++ 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/psycopg/psycopg/types/range.py b/psycopg/psycopg/types/range.py index 41e89806f..561f5f466 100644 --- a/psycopg/psycopg/types/range.py +++ b/psycopg/psycopg/types/range.py @@ -15,7 +15,7 @@ from ..pq import Format from ..abc import AdaptContext, Buffer, Dumper, DumperKey from ..adapt import RecursiveDumper, RecursiveLoader, PyFormat from .._struct import pack_len, unpack_len -from ..postgres import INVALID_OID +from ..postgres import INVALID_OID, TEXT_OID from .._typeinfo import RangeInfo as RangeInfo # exported here from .composite import SequenceDumper, BaseCompositeLoader @@ -261,7 +261,7 @@ class BaseRangeDumper(RecursiveDumper): def get_key(self, obj: Range[Any], format: PyFormat) -> DumperKey: # If we are a subclass whose oid is specified we don't need upgrade - if self.oid != INVALID_OID: + if self.cls is not Range: return self.cls item = self._get_item(obj) @@ -273,7 +273,7 @@ class BaseRangeDumper(RecursiveDumper): def upgrade(self, obj: Range[Any], format: PyFormat) -> "BaseRangeDumper": # If we are a subclass whose oid is specified we don't need upgrade - if self.oid != INVALID_OID: + if self.cls is not Range: return self item = self._get_item(obj) @@ -295,7 +295,7 @@ class BaseRangeDumper(RecursiveDumper): dumper.sub_dumper = sd if sd.oid == INVALID_OID and isinstance(item, str): # Work around the normal mapping where text is dumped as unknown - dumper.oid = self._get_range_oid(self._types["text"].oid) + dumper.oid = self._get_range_oid(TEXT_OID) else: dumper.oid = self._get_range_oid(sd.oid) diff --git a/tests/types/test_range.py b/tests/types/test_range.py index 634aadc15..9a35230ad 100644 --- a/tests/types/test_range.py +++ b/tests/types/test_range.py @@ -317,6 +317,20 @@ def test_load_quoting(conn, testrange, fmt_out): assert ord(got.upper) == i + 1 +@pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY]) +def test_mixed_array_types(conn, fmt_out): + conn.execute("create table testmix (a daterange[], b tstzrange[])") + r1 = Range(dt.date(2000, 1, 1), dt.date(2001, 1, 1), "[)") + r2 = Range( + dt.datetime(2000, 1, 1, tzinfo=dt.timezone.utc), + dt.datetime(2001, 1, 1, tzinfo=dt.timezone.utc), + "[)", + ) + conn.execute("insert into testmix values (%s, %s)", [[r1], [r2]]) + got = conn.execute("select * from testmix").fetchone() + assert got == ([r1], [r2]) + + class TestRangeObject: def test_noparam(self): r = Range() -- 2.47.3