]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added interval loading
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 27 Oct 2020 03:10:36 +0000 (04:10 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 28 Oct 2020 03:19:24 +0000 (04:19 +0100)
psycopg3/psycopg3/types/date.py
tests/types/test_date.py

index 23b34f1d17f69debc4e2314e395845fb8eb41b31..9f888e6d876c3d12c403df36c4b34329bfbe7313 100644 (file)
@@ -4,6 +4,7 @@ Adapters for date/time types.
 
 # Copyright (C) 2020 The Psycopg Team
 
+import re
 import codecs
 from datetime import date, datetime, time, timedelta
 from typing import cast
@@ -257,22 +258,27 @@ class TimestamptzLoader(TimestampLoader):
         ds = self._get_datestyle()
         if ds.startswith(b"I"):  # ISO
             return "%Y-%m-%d %H:%M:%S.%f%z"
-        elif ds.startswith(b"G"):  # German
-            return "%d.%m.%Y %H:%M:%S.%f %Z"
-        elif ds.startswith(b"S"):  # SQL
-            return (
-                "%d/%m/%Y %H:%M:%S.%f %Z"
-                if ds.endswith(b"DMY")
-                else "%m/%d/%Y %H:%M:%S.%f %Z"
-            )
-        elif ds.startswith(b"P"):  # Postgres
-            return (
-                "%a %d %b %H:%M:%S.%f %Y %Z"
-                if ds.endswith(b"DMY")
-                else "%a %b %d %H:%M:%S.%f %Y %Z"
-            )
+
+        # These don't work: the timezone name is not always displayed
+        # elif ds.startswith(b"G"):  # German
+        #     return "%d.%m.%Y %H:%M:%S.%f %Z"
+        # elif ds.startswith(b"S"):  # SQL
+        #     return (
+        #         "%d/%m/%Y %H:%M:%S.%f %Z"
+        #         if ds.endswith(b"DMY")
+        #         else "%m/%d/%Y %H:%M:%S.%f %Z"
+        #     )
+        # elif ds.startswith(b"P"):  # Postgres
+        #     return (
+        #         "%a %d %b %H:%M:%S.%f %Y %Z"
+        #         if ds.endswith(b"DMY")
+        #         else "%a %b %d %H:%M:%S.%f %Y %Z"
+        #     )
+        # else:
+        #     raise InterfaceError(f"unexpected DateStyle: {ds.decode('ascii')}")
         else:
-            raise InterfaceError(f"unexpected DateStyle: {ds.decode('ascii')}")
+            self.load = self._load_notimpl  # type: ignore[assignment]
+            return ""
 
     def load(self, data: bytes) -> datetime:
         # Hack to convert +HH in +HHMM
@@ -280,3 +286,80 @@ class TimestamptzLoader(TimestampLoader):
             data += b"00"
 
         return super().load(data)
+
+    def _load_notimpl(self, data: bytes) -> datetime:
+        raise NotImplementedError(
+            "can't parse datetimetz with DateStyle"
+            f" {self._get_datestyle().decode('ascii')}: {data.decode('ascii')}"
+        )
+
+
+@Loader.text(builtins["interval"].oid)
+class IntervalLoader(Loader):
+
+    _decode = codecs.lookup("ascii").decode
+    _re_interval = re.compile(
+        br"""
+        (?: (?P<years> [-+]?\d+) \s+ years? \s* )?
+        (?: (?P<months> [-+]?\d+) \s+ mons? \s* )?
+        (?: (?P<days> [-+]?\d+) \s+ days? \s* )?
+        (?: (?P<hsign> [-+])?
+            (?P<hours> \d+ )
+          : (?P<minutes> \d+ )
+          : (?P<seconds> \d+ (?:\.\d+)? )
+        )?
+        """,
+        re.VERBOSE,
+    )
+
+    def __init__(self, oid: int, context: AdaptContext):
+        super().__init__(oid, context)
+        if self.connection:
+            ints = self.connection.pgconn.parameter_status(b"IntervalStyle")
+            if ints != b"postgres":
+                self.load = self._load_notimpl  # type: ignore[assignment]
+
+    def load(self, data: bytes) -> timedelta:
+        m = self._re_interval.match(data)
+        if not m:
+            raise ValueError("can't parse interval: {data.decode('ascii')}")
+
+        days = 0
+        seconds = 0.0
+
+        tmp = m.group("years")
+        if tmp:
+            days += 365 * int(tmp)
+
+        tmp = m.group("months")
+        if tmp:
+            days += 30 * int(tmp)
+
+        tmp = m.group("days")
+        if tmp:
+            days += int(tmp)
+
+        if m.group("hours"):
+            seconds = (
+                3600 * int(m.group("hours"))
+                + 60 * int(m.group("minutes"))
+                + float(m.group("seconds"))
+            )
+            if m.group("hsign") == b"-":
+                seconds = -seconds
+
+        try:
+            return timedelta(days=days, seconds=seconds)
+        except OverflowError as e:
+            raise DataError(str(e))
+
+    def _load_notimpl(self, data: bytes) -> timedelta:
+        ints = (
+            self.connection
+            and self.connection.pgconn.parameter_status(b"IntervalStyle")
+            or b"unknown"
+        )
+        raise NotImplementedError(
+            "can't parse interval with IntervalStyle"
+            f" {ints.decode('ascii')}: {data.decode('ascii')}"
+        )
index 9dc95e793d6c38ba5aa154dae242949b0fa46958..8810de92521b9ffe727ca21424084255030c2821 100644 (file)
@@ -28,10 +28,10 @@ def test_dump_date(conn, val, expr):
 
 
 @pytest.mark.xfail  # TODO: binary dump
-@pytest.mark.parametrize("val, expr", [(dt.date(2000, 1, 1), "2000-01-01")])
+@pytest.mark.parametrize("val, expr", [("2000,1,1", "2000-01-01")])
 def test_dump_date_binary(conn, val, expr):
     cur = conn.cursor()
-    cur.execute(f"select '{expr}'::date = %b", (val,))
+    cur.execute(f"select '{expr}'::date = %b", (as_date(val),))
     assert cur.fetchone()[0] is True
 
 
@@ -76,20 +76,14 @@ def test_load_date_datestyle(conn, datestyle_out):
     assert cur.fetchone()[0] == dt.date(2000, 1, 2)
 
 
+@pytest.mark.parametrize("val", ["min", "max"])
 @pytest.mark.parametrize("datestyle_out", ["ISO", "Postgres", "SQL", "German"])
-def test_load_date_bc(conn, datestyle_out):
-    cur = conn.cursor()
-    cur.execute(f"set datestyle = {datestyle_out}, YMD")
-    cur.execute("select %s - 1", (dt.date.min,))
-    with pytest.raises(DataError):
-        cur.fetchone()[0]
-
-
-@pytest.mark.parametrize("datestyle_out", ["ISO", "Postgres", "SQL", "German"])
-def test_load_date_too_large(conn, datestyle_out):
+def test_load_date_overflow(conn, val, datestyle_out):
     cur = conn.cursor()
     cur.execute(f"set datestyle = {datestyle_out}, YMD")
-    cur.execute("select %s + 1", (dt.date.max,))
+    cur.execute(
+        "select %s + %s::int", (as_date(val), -1 if val == "min" else 1)
+    )
     with pytest.raises(DataError):
         cur.fetchone()[0]
 
@@ -164,20 +158,15 @@ def test_load_datetime(conn, val, expr, datestyle_in, datestyle_out):
     assert cur.fetchone()[0] == as_dt(val)
 
 
+@pytest.mark.parametrize("val", ["min", "max"])
 @pytest.mark.parametrize("datestyle_out", ["ISO", "Postgres", "SQL", "German"])
-def test_load_datetime_bc(conn, datestyle_out):
+def test_load_datetime_overflow(conn, val, datestyle_out):
     cur = conn.cursor()
     cur.execute(f"set datestyle = {datestyle_out}, YMD")
-    cur.execute("select %s - '1s'::interval", (dt.datetime.min,))
-    with pytest.raises(DataError):
-        cur.fetchone()[0]
-
-
-@pytest.mark.parametrize("datestyle_out", ["ISO", "SQL", "Postgres", "German"])
-def test_load_datetime_too_large(conn, datestyle_out):
-    cur = conn.cursor()
-    cur.execute(f"set datestyle = {datestyle_out}, YMD")
-    cur.execute("select %s + '1s'::interval", (dt.datetime.max,))
+    cur.execute(
+        "select %s::timestamp + %s * '1s'::interval",
+        (as_dt(val), -1 if val == "min" else 1),
+    )
     with pytest.raises(DataError):
         cur.fetchone()[0]
 
@@ -421,6 +410,71 @@ def test_dump_interval(conn, val, expr, intervalstyle):
     assert cur.fetchone()[0] is True
 
 
+@pytest.mark.xfail  # TODO: binary dump
+@pytest.mark.parametrize("val, expr", [("1s", "1s")])
+def test_dump_interval_binary(conn, val, expr):
+    cur = conn.cursor()
+    cur.execute(f"select '{expr}'::interval = %b", (as_td(val),))
+    assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+    "val, expr",
+    [
+        ("1s", "1 sec"),
+        ("-1s", "-1 sec"),
+        ("60s", "1 min"),
+        ("3600s", "1 hour"),
+        ("1s,1000m", "1.001 sec"),
+        ("1s,1m", "1.000001 sec"),
+        ("1d", "1 day"),
+        ("-10d", "-10 day"),
+        ("1d,1s,1m", "1 day 1.000001 sec"),
+        ("-86399s,-999999m", "-23:59:59.999999"),
+        ("-3723s,-400000m", "-1:2:3.4"),
+        ("3723s,400000m", "1:2:3.4"),
+        ("86399s,999999m", "23:59:59.999999"),
+        ("30d", "30 day"),
+        ("365d", "1 year"),
+        ("-365d", "-1 year"),
+        ("-730d", "-2 years"),
+        ("1460d", "4 year"),
+        ("30d", "1 month"),
+        ("-30d", "-1 month"),
+        ("60d", "2 month"),
+        ("-90d", "-3 month"),
+    ],
+)
+def test_load_interval(conn, val, expr):
+    cur = conn.cursor()
+    cur.execute(f"select '{expr}'::interval")
+    assert cur.fetchone()[0] == as_td(val)
+
+
+@pytest.mark.xfail  # weird interval outputs
+@pytest.mark.parametrize("val, expr", [("1d,1s", "1 day 1 sec")])
+@pytest.mark.parametrize(
+    "intervalstyle",
+    ["sql_standard", "postgres_verbose", "iso_8601"],
+)
+def test_load_interval_intervalstyle(conn, val, expr, intervalstyle):
+    cur = conn.cursor()
+    cur.execute(f"set IntervalStyle to '{intervalstyle}'")
+    cur.execute(f"select '{expr}'::interval")
+    assert cur.fetchone()[0] == as_td(val)
+
+
+@pytest.mark.parametrize("val", ["min", "max"])
+def test_load_interval_overflow(conn, val):
+    cur = conn.cursor()
+    cur.execute(
+        "select %s + %s * '1s'::interval",
+        (as_td(val), -1 if val == "min" else 1),
+    )
+    with pytest.raises(DataError):
+        cur.fetchone()[0]
+
+
 #
 # Support
 #