]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add timestamptz adapter
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 13 May 2021 01:19:11 +0000 (03:19 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 13 May 2021 01:43:19 +0000 (03:43 +0200)
psycopg3/psycopg3/types/__init__.py
psycopg3/psycopg3/types/date.py
tests/fix_faker.py
tests/types/test_date.py

index 058fa46e2b76c774545cc9cd1e8758d983c9c17e..aecec425b6378fa96280517aaac758a12713f276 100644 (file)
@@ -99,7 +99,8 @@ from .date import (
     TimeTzBinaryLoader as TimeTzBinaryLoader,
     TimestampLoader as TimestampLoader,
     TimestampBinaryLoader as TimestampBinaryLoader,
-    TimestamptzLoader as TimestamptzLoader,
+    TimestampTzLoader as TimestampTzLoader,
+    TimestampTzBinaryLoader as TimestampTzBinaryLoader,
     IntervalLoader as IntervalLoader,
 )
 from .json import (
@@ -227,7 +228,8 @@ def register_default_globals(ctx: AdaptContext) -> None:
     TimeTzBinaryLoader.register("timetz", ctx)
     TimestampLoader.register("timestamp", ctx)
     TimestampBinaryLoader.register("timestamp", ctx)
-    TimestamptzLoader.register("timestamptz", ctx)
+    TimestampTzLoader.register("timestamptz", ctx)
+    TimestampTzBinaryLoader.register("timestamptz", ctx)
     IntervalLoader.register("interval", ctx)
 
     # Currently json binary format is nothing different than text, maybe with
index a10cfa40daa881a5d20ec8f4c74c5482741287c4..8dcef7bb46879bdf70720056ae4e0c6df2ce26ac 100644 (file)
@@ -167,8 +167,22 @@ class DateTimeTzBinaryDumper(_BaseDateTimeDumper):
 
     format = Format.BINARY
 
+    # Somewhere, between year 2270 and 2275, float rounding in total_seconds
+    # cause us errors: switch to an algorithm without rounding before then.
+    _delta_prec_loss = (
+        datetime(2250, 1, 1) - _pg_datetime_epoch
+    ).total_seconds()
+
     def dump(self, obj: datetime) -> bytes:
-        raise NotImplementedError
+        delta = obj - _pg_datetimetz_epoch
+        secs = delta.total_seconds()
+        if secs < self._delta_prec_loss:
+            micros = int(1_000_000 * secs)
+        else:
+            micros = delta.microseconds + 1_000_000 * (
+                86_400 * delta.days + delta.seconds
+            )
+        return _pack_int8(micros)
 
     def upgrade(self, obj: datetime, format: Pg3Format) -> "Dumper":
         if obj.tzinfo:
@@ -180,12 +194,6 @@ class DateTimeTzBinaryDumper(_BaseDateTimeDumper):
 class DateTimeBinaryDumper(DateTimeTzBinaryDumper):
     _oid = builtins["timestamp"].oid
 
-    # Somewhere, between year 2270 and 2275, float rounding in total_seconds
-    # cause us errors: switch to an algorithm without rounding before then.
-    _delta_prec_loss = (
-        datetime(2250, 1, 1) - _pg_datetime_epoch
-    ).total_seconds()
-
     def dump(self, obj: datetime) -> bytes:
         delta = obj - _pg_datetime_epoch
         secs = delta.total_seconds()
@@ -489,16 +497,10 @@ class TimestampBinaryLoader(Loader):
                 raise DataError("timestamp too large (after year 10K)")
 
 
-class TimestamptzLoader(TimestampLoader):
+class TimestampTzLoader(TimestampLoader):
 
     format = Format.TEXT
 
-    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
-        if sys.version_info < (3, 7):
-            setattr(self, "load", self._load_py36)
-
-        super().__init__(oid, context)
-
     def _format_from_context(self) -> str:
         ds = self._get_datestyle()
         if ds.startswith(b"I"):  # ISO
@@ -559,6 +561,25 @@ class TimestamptzLoader(TimestampLoader):
         )
 
 
+if sys.version_info < (3, 7):
+    setattr(TimestampTzLoader, "load", TimestampTzLoader._load_py36)
+
+
+class TimestampTzBinaryLoader(Loader):
+
+    format = Format.BINARY
+
+    def load(self, data: Buffer) -> datetime:
+        micros = _unpack_int8(data)[0]
+        try:
+            return _pg_datetimetz_epoch + timedelta(microseconds=micros)
+        except OverflowError:
+            if micros <= 0:
+                raise DataError("timestamp too small (before year 1)")
+            else:
+                raise DataError("timestamp too large (after year 10K)")
+
+
 class IntervalLoader(Loader):
 
     format = Format.TEXT
index db159fdb734bf25bfc7c39406eb84310fd9815f7..73dd1025ee5ead0dd7239f52c1494bf2ce202435 100644 (file)
@@ -134,10 +134,13 @@ class Faker:
                 schema[i] = [scls]
             elif cls is tuple:
                 schema[i] = tuple(self.choose_schema(types=types, ncols=ncols))
+            # Pick timezone yes/no
             elif cls is dt.time:
-                # Pick timezone yes/no
                 if choice([True, False]):
                     schema[i] = TimeTz
+            elif cls is dt.datetime:
+                if choice([True, False]):
+                    schema[i] = DateTimeTz
 
         return schema
 
@@ -240,6 +243,10 @@ class Faker:
         micros = randrange((delta.days + 1) * 24 * 60 * 60 * 1_000_000)
         return dt.datetime.min + dt.timedelta(microseconds=micros)
 
+    def make_DateTimeTz(self, spec):
+        rv = self.make_datetime(spec)
+        return rv.replace(tzinfo=self._make_tz(spec))
+
     def make_Decimal(self, spec):
         if random() >= 0.99:
             if self.conn.info.server_version >= 140000:
@@ -398,12 +405,18 @@ class JsonFloat:
     pass
 
 
-class TimeTz(dt.time):
+class TimeTz:
     """
     Placeholder to create time objects with tzinfo.
     """
 
 
+class DateTimeTz:
+    """
+    Placeholder to create datetime objects with tzinfo.
+    """
+
+
 def deep_import(name):
     parts = deque(name.split("."))
     seen = []
index b89fa12bf6b449375b81b67b13dc2877efad3747..b95e932d42053ab9de79145219b6c2e568f90783 100644 (file)
@@ -215,25 +215,17 @@ def test_load_datetime_overflow_binary(conn, val):
         ("max~2", "9999-12-31 23:59:59.999999"),
     ],
 )
-def test_dump_datetimetz(conn, val, expr):
+@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
+def test_dump_datetimetz(conn, val, expr, fmt_in):
     # adjust for Python 3.6 missing seconds in tzinfo
-    if val.count(":") > 1:
+    if sys.version_info < (3, 7) and val.count(":") > 1:
         expr = expr.rsplit(":", 1)[0]
         val, rest = val.rsplit(":", 1)
         val += rest[3:]  # skip tz seconds, but include micros
 
     cur = conn.cursor()
     cur.execute("set timezone to '-02:00'")
-    cur.execute(f"select '{expr}'::timestamptz = %s", (as_dt(val),))
-    assert cur.fetchone()[0] is True
-
-
-@pytest.mark.xfail  # TODO: binary dump
-@pytest.mark.parametrize("val, expr", [("2000,1,1,0,0~2", "2000-01-01 00:00")])
-def test_dump_datetimetz_binary(conn, val, expr):
-    cur = conn.cursor()
-    cur.execute("set timezone to '-02:00'")
-    cur.execute(f"select '{expr}'::timestamptz = %b", (as_dt(val),))
+    cur.execute(f"select '{expr}'::timestamptz = %{fmt_in}", (as_dt(val),))
     assert cur.fetchone()[0] is True
 
 
@@ -250,19 +242,19 @@ def test_dump_datetimetz_datestyle(conn, datestyle_in):
     assert cur.fetchone()[0] is True
 
 
-@pytest.mark.parametrize(
-    "val, expr, timezone",
-    [
-        ("2000,1,1~2", "2000-01-01", "-02:00"),
-        ("2000,1,2,3,4,5,6~2", "2000-01-02 03:04:05.000006", "-02:00"),
-        ("2000,1,2,3,4,5,678~1", "2000-01-02 03:04:05.000678", "Europe/Rome"),
-        ("2000,7,2,3,4,5,678~2", "2000-07-02 03:04:05.000678", "Europe/Rome"),
-        ("2000,1,2,3,0,0,456789~2", "2000-01-02 03:00:00.456789", "-02:00"),
-        ("2000,1,2,3,0,0,456789~-2", "2000-01-02 03:00:00.456789", "+02:00"),
-        ("2000,12,31~2", "2000-12-31", "-02:00"),
-        ("1900,1,1~05:21:10", "1900-01-01", "Asia/Calcutta"),
-    ],
-)
+load_datetimetz_samples = [
+    ("2000,1,1~2", "2000-01-01", "-02:00"),
+    ("2000,1,2,3,4,5,6~2", "2000-01-02 03:04:05.000006", "-02:00"),
+    ("2000,1,2,3,4,5,678~1", "2000-01-02 03:04:05.000678", "Europe/Rome"),
+    ("2000,7,2,3,4,5,678~2", "2000-07-02 03:04:05.000678", "Europe/Rome"),
+    ("2000,1,2,3,0,0,456789~2", "2000-01-02 03:00:00.456789", "-02:00"),
+    ("2000,1,2,3,0,0,456789~-2", "2000-01-02 03:00:00.456789", "+02:00"),
+    ("2000,12,31~2", "2000-12-31", "-02:00"),
+    ("1900,1,1~05:21:10", "1900-01-01", "Asia/Calcutta"),
+]
+
+
+@pytest.mark.parametrize("val, expr, timezone", load_datetimetz_samples)
 @pytest.mark.parametrize("datestyle_out", ["ISO"])
 def test_load_datetimetz(conn, val, expr, timezone, datestyle_out):
     cur = conn.cursor(binary=False)
@@ -272,6 +264,14 @@ def test_load_datetimetz(conn, val, expr, timezone, datestyle_out):
     assert cur.fetchone()[0] == as_dt(val)
 
 
+@pytest.mark.parametrize("val, expr, timezone", load_datetimetz_samples)
+def test_load_datetimetz_binary(conn, val, expr, timezone):
+    cur = conn.cursor(binary=True)
+    cur.execute(f"set timezone to '{timezone}'")
+    cur.execute(f"select '{expr}'::timestamptz")
+    assert cur.fetchone()[0] == as_dt(val)
+
+
 @pytest.mark.xfail  # parse timezone names
 @pytest.mark.parametrize("val, expr", [("2000,1,1~2", "2000-01-01")])
 @pytest.mark.parametrize("datestyle_out", ["SQL", "Postgres", "German"])
@@ -294,8 +294,6 @@ def test_load_datetimetz_tzname(conn, val, expr, datestyle_in, datestyle_out):
 )
 @pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
 def test_dump_datetime_tz_or_not_tz(conn, val, type, fmt_in):
-    if fmt_in == Format.BINARY:
-        pytest.xfail("binary datetime not implemented")
     val = as_dt(val)
     cur = conn.cursor()
     cur.execute(