import re
import sys
from datetime import date, datetime, time, timedelta
-from typing import cast, Optional
+from typing import cast, Optional, Tuple, Union
from ..pq import Format
from ..oids import builtins
-from ..adapt import Buffer, Dumper, Loader
+from ..adapt import Buffer, Dumper, Loader, Format as Pg3Format
from ..proto import AdaptContext
from ..errors import InterfaceError, DataError
class TimeDumper(Dumper):
format = Format.TEXT
- _oid = builtins["timetz"].oid
+
+ # Can change to timetz type if the object dumped is naive
+ _oid = builtins["time"].oid
def dump(self, obj: time) -> bytes:
return str(obj).encode("utf8")
+ def get_key(
+ self, obj: time, format: Pg3Format
+ ) -> Union[type, Tuple[type]]:
+ # Use (cls,) to report the need to upgrade to a dumper for timetz (the
+ # Frankenstein of the data types).
+ if not obj.tzinfo:
+ return self.cls
+ else:
+ return (self.cls,)
+
+ def upgrade(self, obj: time, format: Pg3Format) -> "Dumper":
+ if not obj.tzinfo:
+ return self
+ else:
+ return TimeTzDumper(self.cls)
+
+
+class TimeTzDumper(TimeDumper):
+
+ _oid = builtins["timetz"].oid
+
-class DateTimeDumper(Dumper):
+class DateTimeTzDumper(Dumper):
format = Format.TEXT
+
+ # Can change to timestamp type if the object dumped is naive
_oid = builtins["timestamptz"].oid
- def dump(self, obj: date) -> bytes:
+ def dump(self, obj: datetime) -> bytes:
# NOTE: whatever the PostgreSQL DateStyle input format (DMY, MDY, YMD)
# the YYYY-MM-DD is always understood correctly.
return str(obj).encode("utf8")
+ def get_key(
+ self, obj: datetime, format: Pg3Format
+ ) -> Union[type, Tuple[type]]:
+ # Use (cls,) to report the need to upgrade (downgrade, actually) to a
+ # dumper for naive timestamp.
+ if obj.tzinfo:
+ return self.cls
+ else:
+ return (self.cls,)
+
+ def upgrade(self, obj: datetime, format: Pg3Format) -> "Dumper":
+ if obj.tzinfo:
+ return self
+ else:
+ return DateTimeDumper(self.cls)
+
+
+class DateTimeDumper(DateTimeTzDumper):
+ _oid = builtins["timestamp"].oid
+
class TimeDeltaDumper(Dumper):
assert cur.fetchone()[0] == as_dt(val)
+@pytest.mark.parametrize(
+ "val, type",
+ [
+ ("2000,1,2,3,4,5,6", "timestamp"),
+ ("2000,1,2,3,4,5,6~0", "timestamptz"),
+ ("2000,1,2,3,4,5,6~2", "timestamptz"),
+ ],
+)
+@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
+def test_dump_datetime_tz_or_not_tz(conn, val, type, fmt_in):
+ if fmt_in == Format.BINARY:
+ pytest.xfail("binary datetime not implemented")
+ val = as_dt(val)
+ cur = conn.cursor()
+ cur.execute(
+ f"select pg_typeof(%{fmt_in}) = %s::regtype, %{fmt_in}",
+ [val, type, val],
+ )
+ rec = cur.fetchone()
+ assert rec[0] is True, type
+ assert rec[1] == val
+
+
#
# time tests
#
cur.fetchone()[0]
+@pytest.mark.parametrize(
+ "val, type",
+ [
+ ("3,4,5,6", "time"),
+ ("3,4,5,6~0", "timetz"),
+ ("3,4,5,6~2", "timetz"),
+ ],
+)
+@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
+def test_dump_time_tz_or_not_tz(conn, val, type, fmt_in):
+ if fmt_in == Format.BINARY:
+ pytest.xfail("binary time not implemented")
+ val = as_time(val)
+ cur = conn.cursor()
+ cur.execute(
+ f"select pg_typeof(%{fmt_in}) = %s::regtype, %{fmt_in}",
+ [val, type, val],
+ )
+ rec = cur.fetchone()
+ assert rec[0] is True, type
+ assert rec[1] == val
+
+
#
# Interval
#