]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added timedelta dump
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 27 Oct 2020 01:40:29 +0000 (02:40 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 28 Oct 2020 03:19:24 +0000 (04:19 +0100)
psycopg3/psycopg3/types/date.py
tests/types/test_date.py

index b54ca4560ae6e1320744728996cfba93d2aabdcc..23b34f1d17f69debc4e2314e395845fb8eb41b31 100644 (file)
@@ -5,7 +5,7 @@ Adapters for date/time types.
 # Copyright (C) 2020 The Psycopg Team
 
 import codecs
-from datetime import date, datetime, time
+from datetime import date, datetime, time, timedelta
 from typing import cast
 
 from ..adapt import Dumper, Loader
@@ -60,6 +60,37 @@ class DateTimeDumper(Dumper):
         return self.TIMESTAMPTZ_OID
 
 
+@Dumper.text(timedelta)
+class TimeDeltaDumper(Dumper):
+    _encode = codecs.lookup("ascii").encode
+    INTERVAL_OID = builtins["interval"].oid
+
+    def __init__(self, src: type, context: AdaptContext = None):
+        super().__init__(src, context)
+        if self.connection:
+            if (
+                self.connection.pgconn.parameter_status(b"IntervalStyle")
+                == b"sql_standard"
+            ):
+                self.dump = self._dump_sql  # type: ignore[assignment]
+
+    def dump(self, obj: timedelta) -> bytes:
+        return self._encode(str(obj))[0]
+
+    def _dump_sql(self, obj: timedelta) -> bytes:
+        # sql_standard format needs explicit signs
+        # otherwise -1 day 1 sec will mean -1 sec
+        return b"%+d day %+d second %+d microsecond" % (
+            obj.days,
+            obj.seconds,
+            obj.microseconds,
+        )
+
+    @property
+    def oid(self) -> int:
+        return self.INTERVAL_OID
+
+
 @Loader.text(builtins["date"].oid)
 class DateLoader(Loader):
 
index 1b76d7bf56348d312b11eacec3ad6c85eed370dc..9dc95e793d6c38ba5aa154dae242949b0fa46958 100644 (file)
@@ -392,6 +392,35 @@ def test_load_timetz_24(conn):
         cur.fetchone()[0]
 
 
+#
+# Interval
+#
+
+
+@pytest.mark.parametrize(
+    "val, expr",
+    [
+        ("min", "-999999999 days"),
+        ("1d", "1 day"),
+        ("-1d", "-1 day"),
+        ("1s", "1 s"),
+        ("-1s", "-1 s"),
+        ("-1m", "-0.000001 s"),
+        ("1m", "0.000001 s"),
+        ("max", "999999999 days 23:59:59.999999"),
+    ],
+)
+@pytest.mark.parametrize(
+    "intervalstyle",
+    ["sql_standard", "postgres", "postgres_verbose", "iso_8601"],
+)
+def test_dump_interval(conn, val, expr, intervalstyle):
+    cur = conn.cursor()
+    cur.execute(f"set IntervalStyle to '{intervalstyle}'")
+    cur.execute(f"select '{expr}'::interval = %s", (as_td(val),))
+    assert cur.fetchone()[0] is True
+
+
 #
 # Support
 #
@@ -444,3 +473,15 @@ def as_tzinfo(s):
         **dict(zip(("hours", "minutes", "seconds"), map(int, s.split(":"))))
     )
     return dt.timezone(tzoff)
+
+
+def as_td(s):
+    if s in ("min", "max"):
+        return getattr(dt.timedelta, s)
+
+    suffixes = {"d": "days", "s": "seconds", "m": "microseconds"}
+    kwargs = {}
+    for part in s.split(","):
+        kwargs[suffixes[part[-1]]] = int(part[:-1])
+
+    return dt.timedelta(**kwargs)