From: Daniele Varrazzo Date: Tue, 27 Oct 2020 01:40:29 +0000 (+0100) Subject: Added timedelta dump X-Git-Tag: 3.0.dev0~430 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=f76daa832fb7a3cddf17c6e54dbd9c9017be51cb;p=thirdparty%2Fpsycopg.git Added timedelta dump --- diff --git a/psycopg3/psycopg3/types/date.py b/psycopg3/psycopg3/types/date.py index b54ca4560..23b34f1d1 100644 --- a/psycopg3/psycopg3/types/date.py +++ b/psycopg3/psycopg3/types/date.py @@ -5,7 +5,7 @@ Adapters for date/time types. # Copyright (C) 2020 The Psycopg Team import codecs -from datetime import date, datetime, time +from datetime import date, datetime, time, timedelta from typing import cast from ..adapt import Dumper, Loader @@ -60,6 +60,37 @@ class DateTimeDumper(Dumper): return self.TIMESTAMPTZ_OID +@Dumper.text(timedelta) +class TimeDeltaDumper(Dumper): + _encode = codecs.lookup("ascii").encode + INTERVAL_OID = builtins["interval"].oid + + def __init__(self, src: type, context: AdaptContext = None): + super().__init__(src, context) + if self.connection: + if ( + self.connection.pgconn.parameter_status(b"IntervalStyle") + == b"sql_standard" + ): + self.dump = self._dump_sql # type: ignore[assignment] + + def dump(self, obj: timedelta) -> bytes: + return self._encode(str(obj))[0] + + def _dump_sql(self, obj: timedelta) -> bytes: + # sql_standard format needs explicit signs + # otherwise -1 day 1 sec will mean -1 sec + return b"%+d day %+d second %+d microsecond" % ( + obj.days, + obj.seconds, + obj.microseconds, + ) + + @property + def oid(self) -> int: + return self.INTERVAL_OID + + @Loader.text(builtins["date"].oid) class DateLoader(Loader): diff --git a/tests/types/test_date.py b/tests/types/test_date.py index 1b76d7bf5..9dc95e793 100644 --- a/tests/types/test_date.py +++ b/tests/types/test_date.py @@ -392,6 +392,35 @@ def test_load_timetz_24(conn): cur.fetchone()[0] +# +# Interval +# + + +@pytest.mark.parametrize( + "val, expr", + [ + ("min", "-999999999 days"), + ("1d", "1 day"), + ("-1d", "-1 day"), + ("1s", "1 s"), + ("-1s", "-1 s"), + ("-1m", "-0.000001 s"), + ("1m", "0.000001 s"), + ("max", "999999999 days 23:59:59.999999"), + ], +) +@pytest.mark.parametrize( + "intervalstyle", + ["sql_standard", "postgres", "postgres_verbose", "iso_8601"], +) +def test_dump_interval(conn, val, expr, intervalstyle): + cur = conn.cursor() + cur.execute(f"set IntervalStyle to '{intervalstyle}'") + cur.execute(f"select '{expr}'::interval = %s", (as_td(val),)) + assert cur.fetchone()[0] is True + + # # Support # @@ -444,3 +473,15 @@ def as_tzinfo(s): **dict(zip(("hours", "minutes", "seconds"), map(int, s.split(":")))) ) return dt.timezone(tzoff) + + +def as_td(s): + if s in ("min", "max"): + return getattr(dt.timedelta, s) + + suffixes = {"d": "days", "s": "seconds", "m": "microseconds"} + kwargs = {} + for part in s.split(","): + kwargs[suffixes[part[-1]]] = int(part[:-1]) + + return dt.timedelta(**kwargs)