]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added time load/dump, and tests refactoring
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 26 Oct 2020 22:32:43 +0000 (23:32 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 28 Oct 2020 03:19:24 +0000 (04:19 +0100)
psycopg3/psycopg3/types/date.py
tests/types/test_date.py

index 576bcf4e84694497d13df40957e2c1afec7c6d1c..afbbc2ff0ee275468f60ba7c04ca0acc1afea907 100644 (file)
@@ -5,7 +5,7 @@ Adapters for date/time types.
 # Copyright (C) 2020 The Psycopg Team
 
 import codecs
-from datetime import date, datetime
+from datetime import date, datetime, time
 from typing import cast
 
 from ..adapt import Dumper, Loader
@@ -30,6 +30,20 @@ class DateDumper(Dumper):
         return self.DATE_OID
 
 
+@Dumper.text(time)
+class TimeDumper(Dumper):
+
+    _encode = codecs.lookup("ascii").encode
+    TIMETZ_OID = builtins["timetz"].oid
+
+    def dump(self, obj: time) -> bytes:
+        return self._encode(str(obj))[0]
+
+    @property
+    def oid(self) -> int:
+        return self.TIMETZ_OID
+
+
 @Dumper.text(datetime)
 class DateTimeDumper(Dumper):
 
@@ -53,12 +67,12 @@ class DateLoader(Loader):
 
     def __init__(self, oid: int, context: AdaptContext):
         super().__init__(oid, context)
-        self._date_format = self._format_from_context()
+        self._format = self._format_from_context()
 
     def load(self, data: bytes) -> date:
         try:
             return datetime.strptime(
-                self._decode(data)[0], self._date_format
+                self._decode(data)[0], self._format
             ).date()
         except ValueError as e:
             return self._raise_error(data, e)
@@ -109,15 +123,41 @@ class DateLoader(Loader):
         raise exc
 
 
+@Loader.text(builtins["time"].oid)
+class TimeLoader(Loader):
+
+    _decode = codecs.lookup("ascii").decode
+    _format = "%H:%M:%S.%f"
+    _format_no_micro = _format.replace(".%f", "")
+
+    def load(self, data: bytes) -> time:
+        # check if the data contains microseconds
+        fmt = self._format if b"." in data else self._format_no_micro
+        try:
+            return datetime.strptime(self._decode(data)[0], fmt).time()
+        except ValueError as e:
+            return self._raise_error(data, e)
+
+    def _raise_error(self, data: bytes, exc: ValueError) -> time:
+        # Most likely, time 24:00
+        if data.startswith(b"24"):
+            raise ValueError("time with hour 24 not supported by Python")
+
+        # We genuinely received something we cannot parse
+        raise exc
+
+
 @Loader.text(builtins["timestamp"].oid)
 class TimestampLoader(DateLoader):
     def __init__(self, oid: int, context: AdaptContext):
         super().__init__(oid, context)
-        self._no_micro_format = self._date_format.replace(".%f", "")
+        self._format_no_micro = self._format.replace(".%f", "")
 
     def load(self, data: bytes) -> datetime:
         # check if the data contains microseconds
-        fmt = self._date_format if b"." in data[19:] else self._no_micro_format
+        fmt = (
+            self._format if data.find(b".", 19) >= 0 else self._format_no_micro
+        )
         try:
             return datetime.strptime(self._decode(data)[0], fmt)
         except ValueError as e:
index b3d4a14b43d405c8e507f96d415f98d66e69a730..eaddb51e41ee00112ad2b9a3de4f52a18782880d 100644 (file)
@@ -9,20 +9,26 @@ from psycopg3.adapt import Format
 #
 
 
+def as_date(s):
+    return (
+        dt.date(*map(int, s.split(","))) if "," in s else getattr(dt.date, s)
+    )
+
+
 @pytest.mark.parametrize(
     "val, expr",
     [
-        (dt.date.min, "0001-01-01"),
-        (dt.date(1000, 1, 1), "1000-01-01"),
-        (dt.date(2000, 1, 1), "2000-01-01"),
-        (dt.date(2000, 12, 31), "2000-12-31"),
-        (dt.date(3000, 1, 1), "3000-01-01"),
-        (dt.date.max, "9999-12-31"),
+        ("min", "0001-01-01"),
+        ("1000,1,1", "1000-01-01"),
+        ("2000,1,1", "2000-01-01"),
+        ("2000,12,31", "2000-12-31"),
+        ("3000,1,1", "3000-01-01"),
+        ("max", "9999-12-31"),
     ],
 )
 def test_dump_date(conn, val, expr):
     cur = conn.cursor()
-    cur.execute(f"select '{expr}'::date = %s", (val,))
+    cur.execute(f"select '{expr}'::date = %s", (as_date(val),))
     assert cur.fetchone()[0] is True
 
 
@@ -45,26 +51,26 @@ def test_dump_date_datestyle(conn, datestyle_in):
 @pytest.mark.parametrize(
     "val, expr",
     [
-        (dt.date.min, "0001-01-01"),
-        (dt.date(1000, 1, 1), "1000-01-01"),
-        (dt.date(2000, 1, 1), "2000-01-01"),
-        (dt.date(2000, 12, 31), "2000-12-31"),
-        (dt.date(3000, 1, 1), "3000-01-01"),
-        (dt.date.max, "9999-12-31"),
+        ("min", "0001-01-01"),
+        ("1000,1,1", "1000-01-01"),
+        ("2000,1,1", "2000-01-01"),
+        ("2000,12,31", "2000-12-31"),
+        ("3000,1,1", "3000-01-01"),
+        ("max", "9999-12-31"),
     ],
 )
 def test_load_date(conn, val, expr):
     cur = conn.cursor()
     cur.execute(f"select '{expr}'::date")
-    assert cur.fetchone()[0] == val
+    assert cur.fetchone()[0] == as_date(val)
 
 
 @pytest.mark.xfail  # TODO: binary load
-@pytest.mark.parametrize("val, expr", [(dt.date(2000, 1, 1), "2000-01-01")])
+@pytest.mark.parametrize("val, expr", [("2000,1,1", "2000-01-01")])
 def test_load_date_binary(conn, val, expr):
     cur = conn.cursor(format=Format.BINARY)
     cur.execute("select '{expr}'::date" % expr)
-    assert cur.fetchone()[0] == val
+    assert cur.fetchone()[0] == as_date(val)
 
 
 @pytest.mark.parametrize("datestyle_out", ["ISO", "Postgres", "SQL", "German"])
@@ -98,36 +104,55 @@ def test_load_date_too_large(conn, datestyle_out):
 #
 
 
+def as_dt(s):
+    if "~" in s:
+        s, off = s.split("~")
+    else:
+        off = None
+
+    if "," in s:
+        rv = dt.datetime(*map(int, s.split(",")))
+    else:
+        rv = getattr(dt.datetime, s)
+
+    if off:
+        tzoff = dt.timedelta(
+            **dict(
+                zip(("hours", "minutes", "seconds"), map(int, off.split(":")))
+            )
+        )
+        rv = rv.replace(tzinfo=dt.timezone(tzoff))
+
+    return rv
+
+
 @pytest.mark.parametrize(
     "val, expr",
     [
-        (dt.datetime.min, "0001-01-01 00:00"),
-        (dt.datetime(1000, 1, 1, 0, 0), "1000-01-01 00:00"),
-        (dt.datetime(2000, 1, 1, 0, 0), "2000-01-01 00:00"),
-        (
-            dt.datetime(2000, 12, 31, 23, 59, 59, 999999),
-            "2000-12-31 23:59:59.999999",
-        ),
-        (dt.datetime(3000, 1, 1, 0, 0), "3000-01-01 00:00"),
-        (dt.datetime.max, "9999-12-31 23:59:59.999999"),
+        ("min", "0001-01-01 00:00"),
+        ("1000,1,1,0,0", "1000-01-01 00:00"),
+        ("2000,1,1,0,0", "2000-01-01 00:00"),
+        ("2000,12,31,23,59,59,999999", "2000-12-31 23:59:59.999999"),
+        ("3000,1,1,0,0", "3000-01-01 00:00"),
+        ("max", "9999-12-31 23:59:59.999999"),
     ],
 )
 def test_dump_datetime(conn, val, expr):
     cur = conn.cursor()
     cur.execute("set timezone to '+02:00'")
-    cur.execute(f"select '{expr}'::timestamp = %s", (val,))
+    cur.execute(f"select '{expr}'::timestamp = %s", (as_dt(val),))
     assert cur.fetchone()[0] is True
 
 
 @pytest.mark.xfail  # TODO: binary dump
 @pytest.mark.parametrize(
     "val, expr",
-    [(dt.datetime(2000, 1, 1, 0, 0), "'2000-01-01 00:00'::timestamp")],
+    [("2000,1,1,0,0", "'2000-01-01 00:00'::timestamp")],
 )
 def test_dump_datetime_binary(conn, val, expr):
     cur = conn.cursor()
     cur.execute("set timezone to '+02:00'")
-    cur.execute("select %s = %%b" % expr, (val,))
+    cur.execute("select %s = %%b" % expr, (as_dt(val),))
     assert cur.fetchone()[0] is True
 
 
@@ -161,14 +186,9 @@ def test_dump_datetime_datestyle(conn, datestyle_in):
 def test_load_datetime(conn, val, expr, datestyle_in, datestyle_out):
     cur = conn.cursor()
     cur.execute(f"set datestyle = {datestyle_out}, {datestyle_in}")
-    val = (
-        dt.datetime(*map(int, val.split(",")))
-        if "," in val
-        else getattr(dt.datetime, val)
-    )
     cur.execute("set timezone to '+02:00'")
     cur.execute(f"select '{expr}'::timestamp")
-    assert cur.fetchone()[0] == val
+    assert cur.fetchone()[0] == as_dt(val)
 
 
 #
@@ -179,35 +199,27 @@ def test_load_datetime(conn, val, expr, datestyle_in, datestyle_out):
 @pytest.mark.parametrize(
     "val, expr",
     [
-        (dt.datetime.min, "0001-01-01 00:00"),
-        (dt.datetime(1000, 1, 1, 0, 0), "1000-01-01 00:00+2"),
-        (dt.datetime(2000, 1, 1, 0, 0), "2000-01-01 00:00+2"),
-        (
-            dt.datetime(2000, 12, 31, 23, 59, 59, 999999),
-            "2000-12-31 23:59:59.999999+2",
-        ),
-        (dt.datetime(3000, 1, 1, 0, 0), "3000-01-01 00:00+2"),
-        (dt.datetime.max, "9999-12-31 23:59:59.999999"),
+        ("min~2", "0001-01-01 00:00"),
+        ("1000,1,1,0,0~2", "1000-01-01 00:00+2"),
+        ("2000,1,1,0,0~2", "2000-01-01 00:00+2"),
+        ("2000,12,31,23,59,59,999999~2", "2000-12-31 23:59:59.999999+2"),
+        ("3000,1,1,0,0~2", "3000-01-01 00:00+2"),
+        ("max~2", "9999-12-31 23:59:59.999999"),
     ],
 )
 def test_dump_datetimetz(conn, val, expr):
-    val = val.replace(tzinfo=dt.timezone(dt.timedelta(hours=2)))
     cur = conn.cursor()
     cur.execute("set timezone to '-02:00'")
-    cur.execute(f"select '{expr}'::timestamptz = %s", (val,))
+    cur.execute(f"select '{expr}'::timestamptz = %s", (as_dt(val),))
     assert cur.fetchone()[0] is True
 
 
 @pytest.mark.xfail  # TODO: binary dump
-@pytest.mark.parametrize(
-    "val, expr",
-    [(dt.datetime(2000, 1, 1, 0, 0), "2000-01-01 00:00")],
-)
+@pytest.mark.parametrize("val, expr", [("2000,1,1,0,0~2", "2000-01-01 00:00")])
 def test_dump_datetimetz_binary(conn, val, expr):
-    val = val.replace(tzinfo=dt.timezone(dt.timedelta(hours=2)))
     cur = conn.cursor()
     cur.execute("set timezone to '-02:00'")
-    cur.execute(f"select '{expr}'::timestamptz = %b", (val,))
+    cur.execute(f"select '{expr}'::timestamptz = %b", (as_dt(val),))
     assert cur.fetchone()[0] is True
 
 
@@ -225,61 +237,98 @@ def test_dump_datetimetz_datestyle(conn, datestyle_in):
 
 
 @pytest.mark.parametrize(
-    "val, offset, expr, timezone",
+    "val, expr, timezone",
     [
-        ("2000,1,1", "02:00", "2000-01-01", "-02:00"),
-        ("2000,1,2,3,4,5,6", "02:00", "2000-01-02 03:04:05.000006", "-02:00"),
-        (
-            "2000,1,2,3,4,5,678",
-            "01:00",
-            "2000-01-02 03:04:05.000678",
-            "Europe/Rome",
-        ),
-        (
-            "2000,7,2,3,4,5,678",
-            "02:00",
-            "2000-07-02 03:04:05.000678",
-            "Europe/Rome",
-        ),
-        (
-            "2000,1,2,3,0,0,456789",
-            "02:00",
-            "2000-01-02 03:00:00.456789",
-            "-02:00",
-        ),
-        ("2000,12,31", "02:00", "2000-12-31", "-02:00"),
-        ("1900,1,1", "05:21:10", "1900-01-01", "Asia/Calcutta"),
+        ("2000,1,1~2", "2000-01-01", "-02:00"),
+        ("2000,1,2,3,4,5,6~2", "2000-01-02 03:04:05.000006", "-02:00"),
+        ("2000,1,2,3,4,5,678~1", "2000-01-02 03:04:05.000678", "Europe/Rome"),
+        ("2000,7,2,3,4,5,678~2", "2000-07-02 03:04:05.000678", "Europe/Rome"),
+        ("2000,1,2,3,0,0,456789~2", "2000-01-02 03:00:00.456789", "-02:00"),
+        ("2000,12,31~2", "2000-12-31", "-02:00"),
+        ("1900,1,1~05:21:10", "1900-01-01", "Asia/Calcutta"),
     ],
 )
 @pytest.mark.parametrize("datestyle_out", ["ISO"])
-def test_load_datetimetz(conn, val, offset, expr, timezone, datestyle_out):
+def test_load_datetimetz(conn, val, expr, timezone, datestyle_out):
     cur = conn.cursor()
     cur.execute(f"set datestyle = {datestyle_out}, DMY")
-    val = dt.datetime(*map(int, val.split(",")))
-    tzoff = dt.timedelta(
-        **dict(
-            zip(("hours", "minutes", "seconds"), map(int, offset.split(":")))
-        )
-    )
-    val = val.replace(tzinfo=dt.timezone(tzoff))
     cur.execute(f"set timezone to '{timezone}'")
     cur.execute(f"select '{expr}'::timestamptz")
-    assert cur.fetchone()[0] == val
+    assert cur.fetchone()[0] == as_dt(val)
 
 
 @pytest.mark.xfail  # parse timezone names
-@pytest.mark.parametrize("val, expr", [("2000,1,1", "2000-01-01")])
+@pytest.mark.parametrize("val, expr", [("2000,1,1~2", "2000-01-01")])
 @pytest.mark.parametrize("datestyle_out", ["SQL", "Postgres", "German"])
 @pytest.mark.parametrize("datestyle_in", ["DMY", "MDY", "YMD"])
 def test_load_datetimetz_tzname(conn, val, expr, datestyle_in, datestyle_out):
     cur = conn.cursor()
     cur.execute(f"set datestyle = {datestyle_out}, {datestyle_in}")
-    val = (
-        dt.datetime(*map(int, val.split(",")))
-        if "," in val
-        else getattr(dt.datetime, val)
-    )
-    val = val.replace(tzinfo=dt.timezone(dt.timedelta(hours=2)))
     cur.execute("set timezone to '-02:00'")
     cur.execute(f"select '{expr}'::timestamptz")
-    assert cur.fetchone()[0] == val
+    assert cur.fetchone()[0] == as_dt(val)
+
+
+#
+# time tests
+#
+
+
+def as_time(s):
+    return (
+        dt.time(*map(int, s.split(","))) if "," in s else getattr(dt.time, s)
+    )
+
+
+@pytest.mark.parametrize(
+    "val, expr",
+    [
+        ("min", "00:00"),
+        ("10,20,30,40", "10:20:30.000040"),
+        ("max", "23:59:59.999999"),
+    ],
+)
+def test_dump_time(conn, val, expr):
+    cur = conn.cursor()
+    cur.execute(f"select '{expr}'::time = %s", (as_time(val),))
+    assert cur.fetchone()[0] is True
+
+
+@pytest.mark.xfail  # TODO: binary dump
+@pytest.mark.parametrize("val, expr", [(dt.time(0, 0), "00:00")])
+def test_dump_time_binary(conn, val, expr):
+    cur = conn.cursor()
+    cur.execute(f"select '{expr}'::time = %b", (val,))
+    assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+    "val, expr",
+    [
+        ("min", "00:00"),
+        ("1,2", "01:02"),
+        ("10,20", "10:20"),
+        ("10,20,30", "10:20:30"),
+        ("10,20,30,40", "10:20:30.000040"),
+        ("max", "23:59:59.999999"),
+    ],
+)
+def test_load_time(conn, val, expr):
+    cur = conn.cursor()
+    cur.execute(f"select '{expr}'::time")
+    assert cur.fetchone()[0] == as_time(val)
+
+
+@pytest.mark.xfail  # TODO: binary load
+@pytest.mark.parametrize("val, expr", [("0,0", "00:00")])
+def test_load_time_binary(conn, val, expr):
+    cur = conn.cursor(format=Format.BINARY)
+    cur.execute("select '{expr}'::time" % expr)
+    assert cur.fetchone()[0] == as_time(val)
+
+
+def test_load_time_24(conn):
+    cur = conn.cursor()
+    cur.execute("select '24:00'::time")
+    with pytest.raises(ValueError):
+        cur.fetchone()[0]