From: Daniele Varrazzo Date: Tue, 27 Oct 2020 03:10:36 +0000 (+0100) Subject: Added interval loading X-Git-Tag: 3.0.dev0~429 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=532901553d7a10c8c5d27d07d07112f9b843ebbc;p=thirdparty%2Fpsycopg.git Added interval loading --- diff --git a/psycopg3/psycopg3/types/date.py b/psycopg3/psycopg3/types/date.py index 23b34f1d1..9f888e6d8 100644 --- a/psycopg3/psycopg3/types/date.py +++ b/psycopg3/psycopg3/types/date.py @@ -4,6 +4,7 @@ Adapters for date/time types. # Copyright (C) 2020 The Psycopg Team +import re import codecs from datetime import date, datetime, time, timedelta from typing import cast @@ -257,22 +258,27 @@ class TimestamptzLoader(TimestampLoader): ds = self._get_datestyle() if ds.startswith(b"I"): # ISO return "%Y-%m-%d %H:%M:%S.%f%z" - elif ds.startswith(b"G"): # German - return "%d.%m.%Y %H:%M:%S.%f %Z" - elif ds.startswith(b"S"): # SQL - return ( - "%d/%m/%Y %H:%M:%S.%f %Z" - if ds.endswith(b"DMY") - else "%m/%d/%Y %H:%M:%S.%f %Z" - ) - elif ds.startswith(b"P"): # Postgres - return ( - "%a %d %b %H:%M:%S.%f %Y %Z" - if ds.endswith(b"DMY") - else "%a %b %d %H:%M:%S.%f %Y %Z" - ) + + # These don't work: the timezone name is not always displayed + # elif ds.startswith(b"G"): # German + # return "%d.%m.%Y %H:%M:%S.%f %Z" + # elif ds.startswith(b"S"): # SQL + # return ( + # "%d/%m/%Y %H:%M:%S.%f %Z" + # if ds.endswith(b"DMY") + # else "%m/%d/%Y %H:%M:%S.%f %Z" + # ) + # elif ds.startswith(b"P"): # Postgres + # return ( + # "%a %d %b %H:%M:%S.%f %Y %Z" + # if ds.endswith(b"DMY") + # else "%a %b %d %H:%M:%S.%f %Y %Z" + # ) + # else: + # raise InterfaceError(f"unexpected DateStyle: {ds.decode('ascii')}") else: - raise InterfaceError(f"unexpected DateStyle: {ds.decode('ascii')}") + self.load = self._load_notimpl # type: ignore[assignment] + return "" def load(self, data: bytes) -> datetime: # Hack to convert +HH in +HHMM @@ -280,3 +286,80 @@ class TimestamptzLoader(TimestampLoader): data += b"00" return super().load(data) + + def _load_notimpl(self, data: bytes) -> datetime: + raise NotImplementedError( + "can't parse datetimetz with DateStyle" + f" {self._get_datestyle().decode('ascii')}: {data.decode('ascii')}" + ) + + +@Loader.text(builtins["interval"].oid) +class IntervalLoader(Loader): + + _decode = codecs.lookup("ascii").decode + _re_interval = re.compile( + br""" + (?: (?P [-+]?\d+) \s+ years? \s* )? + (?: (?P [-+]?\d+) \s+ mons? \s* )? + (?: (?P [-+]?\d+) \s+ days? \s* )? + (?: (?P [-+])? + (?P \d+ ) + : (?P \d+ ) + : (?P \d+ (?:\.\d+)? ) + )? + """, + re.VERBOSE, + ) + + def __init__(self, oid: int, context: AdaptContext): + super().__init__(oid, context) + if self.connection: + ints = self.connection.pgconn.parameter_status(b"IntervalStyle") + if ints != b"postgres": + self.load = self._load_notimpl # type: ignore[assignment] + + def load(self, data: bytes) -> timedelta: + m = self._re_interval.match(data) + if not m: + raise ValueError("can't parse interval: {data.decode('ascii')}") + + days = 0 + seconds = 0.0 + + tmp = m.group("years") + if tmp: + days += 365 * int(tmp) + + tmp = m.group("months") + if tmp: + days += 30 * int(tmp) + + tmp = m.group("days") + if tmp: + days += int(tmp) + + if m.group("hours"): + seconds = ( + 3600 * int(m.group("hours")) + + 60 * int(m.group("minutes")) + + float(m.group("seconds")) + ) + if m.group("hsign") == b"-": + seconds = -seconds + + try: + return timedelta(days=days, seconds=seconds) + except OverflowError as e: + raise DataError(str(e)) + + def _load_notimpl(self, data: bytes) -> timedelta: + ints = ( + self.connection + and self.connection.pgconn.parameter_status(b"IntervalStyle") + or b"unknown" + ) + raise NotImplementedError( + "can't parse interval with IntervalStyle" + f" {ints.decode('ascii')}: {data.decode('ascii')}" + ) diff --git a/tests/types/test_date.py b/tests/types/test_date.py index 9dc95e793..8810de925 100644 --- a/tests/types/test_date.py +++ b/tests/types/test_date.py @@ -28,10 +28,10 @@ def test_dump_date(conn, val, expr): @pytest.mark.xfail # TODO: binary dump -@pytest.mark.parametrize("val, expr", [(dt.date(2000, 1, 1), "2000-01-01")]) +@pytest.mark.parametrize("val, expr", [("2000,1,1", "2000-01-01")]) def test_dump_date_binary(conn, val, expr): cur = conn.cursor() - cur.execute(f"select '{expr}'::date = %b", (val,)) + cur.execute(f"select '{expr}'::date = %b", (as_date(val),)) assert cur.fetchone()[0] is True @@ -76,20 +76,14 @@ def test_load_date_datestyle(conn, datestyle_out): assert cur.fetchone()[0] == dt.date(2000, 1, 2) +@pytest.mark.parametrize("val", ["min", "max"]) @pytest.mark.parametrize("datestyle_out", ["ISO", "Postgres", "SQL", "German"]) -def test_load_date_bc(conn, datestyle_out): - cur = conn.cursor() - cur.execute(f"set datestyle = {datestyle_out}, YMD") - cur.execute("select %s - 1", (dt.date.min,)) - with pytest.raises(DataError): - cur.fetchone()[0] - - -@pytest.mark.parametrize("datestyle_out", ["ISO", "Postgres", "SQL", "German"]) -def test_load_date_too_large(conn, datestyle_out): +def test_load_date_overflow(conn, val, datestyle_out): cur = conn.cursor() cur.execute(f"set datestyle = {datestyle_out}, YMD") - cur.execute("select %s + 1", (dt.date.max,)) + cur.execute( + "select %s + %s::int", (as_date(val), -1 if val == "min" else 1) + ) with pytest.raises(DataError): cur.fetchone()[0] @@ -164,20 +158,15 @@ def test_load_datetime(conn, val, expr, datestyle_in, datestyle_out): assert cur.fetchone()[0] == as_dt(val) +@pytest.mark.parametrize("val", ["min", "max"]) @pytest.mark.parametrize("datestyle_out", ["ISO", "Postgres", "SQL", "German"]) -def test_load_datetime_bc(conn, datestyle_out): +def test_load_datetime_overflow(conn, val, datestyle_out): cur = conn.cursor() cur.execute(f"set datestyle = {datestyle_out}, YMD") - cur.execute("select %s - '1s'::interval", (dt.datetime.min,)) - with pytest.raises(DataError): - cur.fetchone()[0] - - -@pytest.mark.parametrize("datestyle_out", ["ISO", "SQL", "Postgres", "German"]) -def test_load_datetime_too_large(conn, datestyle_out): - cur = conn.cursor() - cur.execute(f"set datestyle = {datestyle_out}, YMD") - cur.execute("select %s + '1s'::interval", (dt.datetime.max,)) + cur.execute( + "select %s::timestamp + %s * '1s'::interval", + (as_dt(val), -1 if val == "min" else 1), + ) with pytest.raises(DataError): cur.fetchone()[0] @@ -421,6 +410,71 @@ def test_dump_interval(conn, val, expr, intervalstyle): assert cur.fetchone()[0] is True +@pytest.mark.xfail # TODO: binary dump +@pytest.mark.parametrize("val, expr", [("1s", "1s")]) +def test_dump_interval_binary(conn, val, expr): + cur = conn.cursor() + cur.execute(f"select '{expr}'::interval = %b", (as_td(val),)) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize( + "val, expr", + [ + ("1s", "1 sec"), + ("-1s", "-1 sec"), + ("60s", "1 min"), + ("3600s", "1 hour"), + ("1s,1000m", "1.001 sec"), + ("1s,1m", "1.000001 sec"), + ("1d", "1 day"), + ("-10d", "-10 day"), + ("1d,1s,1m", "1 day 1.000001 sec"), + ("-86399s,-999999m", "-23:59:59.999999"), + ("-3723s,-400000m", "-1:2:3.4"), + ("3723s,400000m", "1:2:3.4"), + ("86399s,999999m", "23:59:59.999999"), + ("30d", "30 day"), + ("365d", "1 year"), + ("-365d", "-1 year"), + ("-730d", "-2 years"), + ("1460d", "4 year"), + ("30d", "1 month"), + ("-30d", "-1 month"), + ("60d", "2 month"), + ("-90d", "-3 month"), + ], +) +def test_load_interval(conn, val, expr): + cur = conn.cursor() + cur.execute(f"select '{expr}'::interval") + assert cur.fetchone()[0] == as_td(val) + + +@pytest.mark.xfail # weird interval outputs +@pytest.mark.parametrize("val, expr", [("1d,1s", "1 day 1 sec")]) +@pytest.mark.parametrize( + "intervalstyle", + ["sql_standard", "postgres_verbose", "iso_8601"], +) +def test_load_interval_intervalstyle(conn, val, expr, intervalstyle): + cur = conn.cursor() + cur.execute(f"set IntervalStyle to '{intervalstyle}'") + cur.execute(f"select '{expr}'::interval") + assert cur.fetchone()[0] == as_td(val) + + +@pytest.mark.parametrize("val", ["min", "max"]) +def test_load_interval_overflow(conn, val): + cur = conn.cursor() + cur.execute( + "select %s + %s * '1s'::interval", + (as_td(val), -1 if val == "min" else 1), + ) + with pytest.raises(DataError): + cur.fetchone()[0] + + # # Support #