]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add date binary adapters
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 12 May 2021 18:55:40 +0000 (20:55 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 12 May 2021 18:55:40 +0000 (20:55 +0200)
psycopg3/psycopg3/types/__init__.py
psycopg3/psycopg3/types/date.py
tests/fix_faker.py
tests/types/test_date.py

index ab01af9146cc1921d085e1cc1d63ca18a30d9db9..2be28c4f0ff959904228f21ce3f43c03d2cb7c14 100644 (file)
@@ -81,12 +81,14 @@ from .singletons import (
 )
 from .date import (
     DateDumper as DateDumper,
+    DateBinaryDumper as DateBinaryDumper,
     TimeDumper as TimeDumper,
     TimeTzDumper as TimeTzDumper,
     DateTimeTzDumper as DateTimeTzDumper,
     DateTimeDumper as DateTimeDumper,
     TimeDeltaDumper as TimeDeltaDumper,
     DateLoader as DateLoader,
+    DateBinaryLoader as DateBinaryLoader,
     TimeLoader as TimeLoader,
     TimeTzLoader as TimeTzLoader,
     TimestampLoader as TimestampLoader,
@@ -204,10 +206,12 @@ def register_default_globals(ctx: AdaptContext) -> None:
     BoolBinaryLoader.register("bool", ctx)
 
     DateDumper.register("datetime.date", ctx)
+    DateBinaryDumper.register("datetime.date", ctx)
     TimeDumper.register("datetime.time", ctx)
     DateTimeTzDumper.register("datetime.datetime", ctx)
     TimeDeltaDumper.register("datetime.timedelta", ctx)
     DateLoader.register("date", ctx)
+    DateBinaryLoader.register("date", ctx)
     TimeLoader.register("time", ctx)
     TimeTzLoader.register("timetz", ctx)
     TimestampLoader.register("timestamp", ctx)
index 478e969d38fc0c0262ce00773fee9480ba5e1166..f348cab2a6d8beff68c412ba2aed941bf77d9ca7 100644 (file)
@@ -6,8 +6,9 @@ Adapters for date/time types.
 
 import re
 import sys
+import struct
 from datetime import date, datetime, time, timedelta
-from typing import cast, Optional, Tuple, Union
+from typing import Callable, cast, Optional, Tuple, Union
 
 from ..pq import Format
 from ..oids import postgres_types as builtins
@@ -15,6 +16,16 @@ from ..adapt import Buffer, Dumper, Loader, Format as Pg3Format
 from ..proto import AdaptContext
 from ..errors import InterfaceError, DataError
 
+_PackInt = Callable[[int], bytes]
+_UnpackInt = Callable[[bytes], Tuple[int]]
+
+_pack_int4 = cast(_PackInt, struct.Struct("!i").pack)
+_unpack_int4 = cast(_UnpackInt, struct.Struct("!i").unpack)
+
+_pg_date_epoch = date(2000, 1, 1).toordinal()
+_py_date_min = date.min.toordinal()
+_py_date_max = date.max.toordinal()
+
 
 class DateDumper(Dumper):
 
@@ -27,6 +38,16 @@ class DateDumper(Dumper):
         return str(obj).encode("utf8")
 
 
+class DateBinaryDumper(Dumper):
+
+    format = Format.BINARY
+    _oid = builtins["date"].oid
+
+    def dump(self, obj: date) -> bytes:
+        days = obj.toordinal() - _pg_date_epoch
+        return _pack_int4(days)
+
+
 class TimeDumper(Dumper):
 
     format = Format.TEXT
@@ -181,6 +202,21 @@ class DateLoader(Loader):
         return max(map(len, parts))
 
 
+class DateBinaryLoader(Loader):
+
+    format = Format.BINARY
+
+    def load(self, data: Buffer) -> date:
+        days = _unpack_int4(data)[0] + _pg_date_epoch
+        if _py_date_min <= days <= _py_date_max:
+            return date.fromordinal(days)
+        else:
+            if days < _py_date_min:
+                raise DataError("date too small (before year 1)")
+            else:
+                raise DataError("date too large (after year 10K)")
+
+
 class TimeLoader(Loader):
 
     format = Format.TEXT
index 6770b554bebd35c7e405323830c1280154635b40..9eb81578c3867580729ba00796b440945937cfdc 100644 (file)
@@ -1,3 +1,4 @@
+import datetime as dt
 import importlib
 from math import isnan
 from uuid import UUID
@@ -226,6 +227,10 @@ class Faker:
         length = randrange(self.str_max_length)
         return spec(bytes([randrange(256) for i in range(length)]))
 
+    def make_date(self, spec):
+        day = randrange(dt.date.max.toordinal())
+        return dt.date.fromordinal(day + 1)
+
     def make_Decimal(self, spec):
         if random() >= 0.99:
             if self.conn.info.server_version >= 140000:
index 0f1981f72d0810c65fd17f25a09a7f24aa3b135e..0f4251cf87f8bddad1e950966eeb157eea5e4974 100644 (file)
@@ -3,7 +3,7 @@ import datetime as dt
 
 import pytest
 
-from psycopg3 import DataError, sql
+from psycopg3 import DataError, pq, sql
 from psycopg3.adapt import Format
 
 
@@ -23,31 +23,27 @@ from psycopg3.adapt import Format
         ("max", "9999-12-31"),
     ],
 )
-def test_dump_date(conn, val, expr):
+@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
+def test_dump_date(conn, val, expr, fmt_in):
     val = as_date(val)
     cur = conn.cursor()
-    cur.execute(f"select '{expr}'::date = %s", (val,))
+    cur.execute(f"select '{expr}'::date = %{fmt_in}", (val,))
     assert cur.fetchone()[0] is True
 
     cur.execute(
-        sql.SQL("select {val}::date = %s").format(val=sql.Literal(val)), (val,)
+        sql.SQL("select {}::date = {}").format(
+            sql.Literal(val), sql.Placeholder(format=fmt_in)
+        ),
+        (val,),
     )
     assert cur.fetchone()[0] is True
 
 
-@pytest.mark.xfail  # TODO: binary dump
-@pytest.mark.parametrize("val, expr", [("2000,1,1", "2000-01-01")])
-def test_dump_date_binary(conn, val, expr):
-    cur = conn.cursor()
-    cur.execute(f"select '{expr}'::date = %b", (as_date(val),))
-    assert cur.fetchone()[0] is True
-
-
 @pytest.mark.parametrize("datestyle_in", ["DMY", "MDY", "YMD"])
 def test_dump_date_datestyle(conn, datestyle_in):
     cur = conn.cursor()
     cur.execute(f"set datestyle = ISO, {datestyle_in}")
-    cur.execute("select 'epoch'::date + 1 = %s", (dt.date(1970, 1, 2),))
+    cur.execute("select 'epoch'::date + 1 = %t", (dt.date(1970, 1, 2),))
     assert cur.fetchone()[0] is True
 
 
@@ -62,16 +58,9 @@ def test_dump_date_datestyle(conn, datestyle_in):
         ("max", "9999-12-31"),
     ],
 )
-def test_load_date(conn, val, expr):
-    cur = conn.cursor()
-    cur.execute(f"select '{expr}'::date")
-    assert cur.fetchone()[0] == as_date(val)
-
-
-@pytest.mark.xfail  # TODO: binary load
-@pytest.mark.parametrize("val, expr", [("2000,1,1", "2000-01-01")])
-def test_load_date_binary(conn, val, expr):
-    cur = conn.cursor(binary=Format.BINARY)
+@pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY])
+def test_load_date(conn, val, expr, fmt_out):
+    cur = conn.cursor(binary=fmt_out)
     cur.execute(f"select '{expr}'::date")
     assert cur.fetchone()[0] == as_date(val)
 
@@ -96,6 +85,16 @@ def test_load_date_overflow(conn, val, datestyle_out):
         cur.fetchone()[0]
 
 
+@pytest.mark.parametrize("val", ["min", "max"])
+def test_load_date_overflow_binary(conn, val):
+    cur = conn.cursor(binary=True)
+    cur.execute(
+        "select %s + %s::int", (as_date(val), -1 if val == "min" else 1)
+    )
+    with pytest.raises(DataError):
+        cur.fetchone()[0]
+
+
 #
 # datetime tests
 #