]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add timedelta/interval binary adapter
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 13 May 2021 02:10:31 +0000 (04:10 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 13 May 2021 02:10:31 +0000 (04:10 +0200)
psycopg3/psycopg3/types/__init__.py
psycopg3/psycopg3/types/date.py
tests/fix_faker.py
tests/types/test_date.py

index aecec425b6378fa96280517aaac758a12713f276..3826058bc287051c6f2b8a9295f501fde8ccebde 100644 (file)
@@ -91,6 +91,7 @@ from .date import (
     DateTimeDumper as DateTimeDumper,
     DateTimeBinaryDumper as DateTimeBinaryDumper,
     TimeDeltaDumper as TimeDeltaDumper,
+    TimeDeltaBinaryDumper as TimeDeltaBinaryDumper,
     DateLoader as DateLoader,
     DateBinaryLoader as DateBinaryLoader,
     TimeLoader as TimeLoader,
@@ -102,6 +103,7 @@ from .date import (
     TimestampTzLoader as TimestampTzLoader,
     TimestampTzBinaryLoader as TimestampTzBinaryLoader,
     IntervalLoader as IntervalLoader,
+    IntervalBinaryLoader as IntervalBinaryLoader,
 )
 from .json import (
     JsonDumper as JsonDumper,
@@ -220,6 +222,7 @@ def register_default_globals(ctx: AdaptContext) -> None:
     DateTimeTzDumper.register("datetime.datetime", ctx)
     DateTimeTzBinaryDumper.register("datetime.datetime", ctx)
     TimeDeltaDumper.register("datetime.timedelta", ctx)
+    TimeDeltaBinaryDumper.register("datetime.timedelta", ctx)
     DateLoader.register("date", ctx)
     DateBinaryLoader.register("date", ctx)
     TimeLoader.register("time", ctx)
@@ -231,6 +234,7 @@ def register_default_globals(ctx: AdaptContext) -> None:
     TimestampTzLoader.register("timestamptz", ctx)
     TimestampTzBinaryLoader.register("timestamptz", ctx)
     IntervalLoader.register("interval", ctx)
+    IntervalBinaryLoader.register("interval", ctx)
 
     # Currently json binary format is nothing different than text, maybe with
     # an extra memcopy we can avoid.
index 8dcef7bb46879bdf70720056ae4e0c6df2ce26ac..dfd6c9bce1b6ef8022acdd2ac2372e6150522a95 100644 (file)
@@ -24,10 +24,17 @@ _pack_int8 = cast(_PackInt, struct.Struct("!q").pack)
 _unpack_int4 = cast(_UnpackInt, struct.Struct("!i").unpack)
 _unpack_int8 = cast(_UnpackInt, struct.Struct("!q").unpack)
 
+_pack_timetz = cast(Callable[[int, int], bytes], struct.Struct("!qi").pack)
 _unpack_timetz = cast(
     Callable[[bytes], Tuple[int, int]], struct.Struct("!qi").unpack
 )
-_pack_timetz = cast(Callable[[int, int], bytes], struct.Struct("!qi").pack)
+_pack_interval = cast(
+    Callable[[int, int, int], bytes], struct.Struct("!qii").pack
+)
+_unpack_interval = cast(
+    Callable[[bytes], Tuple[int, int, int]], struct.Struct("!qii").unpack
+)
+
 
 _pg_date_epoch_days = date(2000, 1, 1).toordinal()
 _pg_datetime_epoch = datetime(2000, 1, 1)
@@ -234,6 +241,16 @@ class TimeDeltaDumper(Dumper):
         )
 
 
+class TimeDeltaBinaryDumper(Dumper):
+
+    format = Format.BINARY
+    _oid = builtins["interval"].oid
+
+    def dump(self, obj: timedelta) -> bytes:
+        micros = 1_000_000 * obj.seconds + obj.microseconds
+        return _pack_interval(micros, obj.days, 0)
+
+
 class DateLoader(Loader):
 
     format = Format.TEXT
@@ -651,3 +668,18 @@ class IntervalLoader(Loader):
             "can't parse interval with IntervalStyle"
             f" {ints.decode('ascii')}: {data.decode('ascii')}"
         )
+
+
+class IntervalBinaryLoader(Loader):
+
+    format = Format.BINARY
+
+    def load(self, data: Buffer) -> timedelta:
+        micros, days, months = _unpack_interval(data)
+        if months > 0:
+            years, months = divmod(months, 12)
+            days = days + 30 * months + 365 * years
+        elif months < 0:
+            years, months = divmod(-months, 12)
+            days = days - 30 * months - 365 * years
+        return timedelta(days=days, microseconds=micros)
index 73dd1025ee5ead0dd7239f52c1494bf2ce202435..2a74f2c8e6a28c389291f126d90d25b9bd7b9082 100644 (file)
@@ -365,6 +365,9 @@ class Faker:
         h, m = divmod(val, 60)
         return dt.time(h, m, s, ms)
 
+    def make_timedelta(self, spec):
+        return choice([dt.timedelta.min, dt.timedelta.max]) * random()
+
     def make_TimeTz(self, spec):
         rv = self.make_time(spec)
         return rv.replace(tzinfo=self._make_tz(spec))
index b95e932d42053ab9de79145219b6c2e568f90783..2b3d4f816160185658478e87f402ec06b7576ab0 100644 (file)
@@ -428,20 +428,19 @@ def test_dump_time_tz_or_not_tz(conn, val, type, fmt_in):
 # Interval
 #
 
+dump_timedelta_samples = [
+    ("min", "-999999999 days"),
+    ("1d", "1 day"),
+    ("-1d", "-1 day"),
+    ("1s", "1 s"),
+    ("-1s", "-1 s"),
+    ("-1m", "-0.000001 s"),
+    ("1m", "0.000001 s"),
+    ("max", "999999999 days 23:59:59.999999"),
+]
 
-@pytest.mark.parametrize(
-    "val, expr",
-    [
-        ("min", "-999999999 days"),
-        ("1d", "1 day"),
-        ("-1d", "-1 day"),
-        ("1s", "1 s"),
-        ("-1s", "-1 s"),
-        ("-1m", "-0.000001 s"),
-        ("1m", "0.000001 s"),
-        ("max", "999999999 days 23:59:59.999999"),
-    ],
-)
+
+@pytest.mark.parametrize("val, expr", dump_timedelta_samples)
 @pytest.mark.parametrize(
     "intervalstyle",
     ["sql_standard", "postgres", "postgres_verbose", "iso_8601"],
@@ -453,11 +452,10 @@ def test_dump_interval(conn, val, expr, intervalstyle):
     assert cur.fetchone()[0] is True
 
 
-@pytest.mark.xfail  # TODO: binary dump
-@pytest.mark.parametrize("val, expr", [("1s", "1s")])
+@pytest.mark.parametrize("val, expr", dump_timedelta_samples)
 def test_dump_interval_binary(conn, val, expr):
-    cur = conn.cursor()
-    cur.execute(f"select '{expr}'::interval = %b", (as_td(val),))
+    cur = conn.cursor(binary=True)
+    cur.execute(f"select '{expr}'::interval = %t", (as_td(val),))
     assert cur.fetchone()[0] is True
 
 
@@ -488,8 +486,9 @@ def test_dump_interval_binary(conn, val, expr):
         ("-90d", "-3 month"),
     ],
 )
-def test_load_interval(conn, val, expr):
-    cur = conn.cursor()
+@pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY])
+def test_load_interval(conn, val, expr, fmt_out):
+    cur = conn.cursor(binary=fmt_out)
     cur.execute(f"select '{expr}'::interval")
     assert cur.fetchone()[0] == as_td(val)