]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Fixed dumping of integer ranges
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 23 Jan 2021 01:49:14 +0000 (02:49 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 23 Jan 2021 01:55:41 +0000 (02:55 +0100)
psycopg3/psycopg3/types/range.py
tests/types/test_range.py

index 1bc17bffeb2c808b644b8222353dc699bda89143..37c7e0af0f599778c3fafd0c1cca7ad60247a5c4 100644 (file)
@@ -13,7 +13,7 @@ from datetime import date, datetime
 from .. import sql
 from .. import errors as e
 from ..pq import Format
-from ..oids import builtins, TypeInfo
+from ..oids import builtins, TypeInfo, INVALID_OID
 from ..adapt import Buffer, Dumper, Loader, Format as Pg3Format
 from ..proto import AdaptContext
 
@@ -259,32 +259,20 @@ class RangeDumper(SequenceDumper):
         sd = self._tx.get_dumper(item, Pg3Format.TEXT)
         dumper = type(self)(self.cls, self._tx)
         dumper.sub_dumper = sd
-        dumper.oid = self._get_range_oid(sd.oid)
+        if not isinstance(item, int):
+            dumper.oid = self._get_range_oid(sd.oid)
+        else:
+            # postgres won't cast int4range -> int8range so we must use
+            # text format and unknown oid here
+            dumper.oid = INVALID_OID
         return dumper
 
     def _get_item(self, obj: Range[Any]) -> Any:
         """
         Return a member representative of the range
-
-        If the range is numeric return the bigger number in absolute value, to
-        decide the best type to use. Otherwise return any non-null in the
-        range. Return None for an empty range.
         """
-        lo, up = obj._lower, obj._upper
-        if lo is None:
-            rv = up
-        elif up is None:
-            rv = lo
-        else:
-            if isinstance(lo, int):
-                rv = up if abs(up) > abs(lo) else lo
-            else:
-                rv = up
-
-        # Upgrade int2 -> int4 as there's no int2range
-        if isinstance(rv, int) and -(2 ** 15) <= rv < 2 ** 15:
-            rv = 2 ** 15
-        return rv
+        rv = obj.lower
+        return rv if rv is not None else obj.upper
 
     def _get_range_oid(self, sub_oid: int) -> int:
         """
@@ -296,10 +284,7 @@ class RangeDumper(SequenceDumper):
         contexts too
         """
         info = builtins.get_range(sub_oid)
-        if info:
-            return info.oid
-        else:
-            raise e.InterfaceError(f"range for type {sub_oid} unknown")
+        return info.oid if info else INVALID_OID
 
 
 class RangeLoader(BaseCompositeLoader, Generic[T]):
index ac601c230436cd894e3993e7ab8e82ddfcd6af63..294b6b5c8463ca579e04b8c35e4e47fd64456ba6 100644 (file)
@@ -202,8 +202,14 @@ def test_dump_quoting(conn, testrange):
     info.register(conn)
     cur = conn.cursor()
     for i in range(1, 254):
+        # TODO: when types registry is merged to adaptation context and we
+        # are able to establish "the type of the range whose element is text",
+        # this should work without ::testrange cast.
         cur.execute(
-            "select ascii(lower(%(r)s)) = %(low)s and ascii(upper(%(r)s)) = %(up)s",
+            """
+            select ascii(lower(%(r)s::testrange)) = %(low)s
+                and ascii(upper(%(r)s::testrange)) = %(up)s
+            """,
             {"r": Range(chr(i), chr(i + 1)), "low": i, "up": i + 1},
         )
         assert cur.fetchone()[0] is True