]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Return timestamptz in the connection timezone
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 14 May 2021 15:50:06 +0000 (17:50 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 14 May 2021 15:50:06 +0000 (17:50 +0200)
First cut, more tests to add. All current tests pass, except the
explicit checks for UTC tzinfo returned.

See https://github.com/psycopg/psycopg3/discussions/56

psycopg3/psycopg3/connection.py
psycopg3/psycopg3/types/date.py
psycopg3/psycopg3/utils/compat.py
psycopg3/setup.cfg
tests/types/test_date.py

index d5472437ac2122238f2c8fe5f9e88be8732cc515..7ba7595dae78ca4d67bd82030f2ecf6082f7df2d 100644 (file)
@@ -31,7 +31,7 @@ from .conninfo import make_conninfo, ConnectionInfo
 from .generators import notifies
 from ._preparing import PrepareManager
 from .transaction import Transaction, AsyncTransaction
-from .utils.compat import asynccontextmanager
+from .utils.compat import asynccontextmanager, ZoneInfo
 from .server_cursor import ServerCursor, AsyncServerCursor
 
 logger = logging.getLogger("psycopg3")
@@ -59,6 +59,8 @@ else:
     connect = generators.connect
     execute = generators.execute
 
+_UTC = ZoneInfo("UTC")
+
 
 class Notify(NamedTuple):
     """An asynchronous notification received from the database."""
@@ -219,6 +221,22 @@ class BaseConnection(AdaptContext, Generic[Row]):
         if result.status != ExecStatus.TUPLES_OK:
             raise e.error_from_result(result, encoding=self.client_encoding)
 
+    @property
+    def timezone(self) -> ZoneInfo:
+        """The Python timezone info of the connection's timezone."""
+        tzname = self.pgconn.parameter_status(b"TimeZone")
+        if tzname:
+            try:
+                return ZoneInfo(tzname.decode("utf8"))
+            except KeyError:
+                logger.warning(
+                    "unknown PostgreSQL timezone: %r will use UTC",
+                    tzname.decode("utf8"),
+                )
+                return _UTC
+        else:
+            return _UTC
+
     @property
     def info(self) -> ConnectionInfo:
         """A `ConnectionInfo` attribute to inspect connection properties."""
index e4811494e9de477d0313ea81b1d6ef2467f8fa36..892a14227a799a426b1fa0904755f80981d4f806 100644 (file)
@@ -15,6 +15,7 @@ from ..oids import postgres_types as builtins
 from ..adapt import Buffer, Dumper, Loader, Format as Pg3Format
 from ..proto import AdaptContext
 from ..errors import InterfaceError, DataError
+from ..utils.compat import ZoneInfo
 
 _PackInt = Callable[[int], bytes]
 _UnpackInt = Callable[[bytes], Tuple[int]]
@@ -40,6 +41,8 @@ _pg_datetime_epoch = datetime(2000, 1, 1)
 _pg_datetimetz_epoch = datetime(2000, 1, 1, tzinfo=timezone.utc)
 _py_date_min_days = date.min.toordinal()
 
+_UTC = ZoneInfo("UTC")
+
 
 class DateDumper(Dumper):
 
@@ -517,6 +520,10 @@ class TimestampTzLoader(TimestampLoader):
 
     format = Format.TEXT
 
+    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+        super().__init__(oid, context)
+        self._timezone = self.connection.timezone if self.connection else _UTC
+
     def _format_from_context(self) -> str:
         ds = self._get_datestyle()
         if ds.startswith(b"I"):  # ISO
@@ -557,7 +564,7 @@ class TimestampTzLoader(TimestampLoader):
         if data[-3] in (43, 45):
             data += b"00"
 
-        return super().load(data).astimezone(timezone.utc)
+        return super().load(data).astimezone(self._timezone)
 
     def _load_py36(self, data: Buffer) -> datetime:
         if isinstance(data, memoryview):
@@ -579,7 +586,7 @@ class TimestampTzLoader(TimestampLoader):
             tzoff = -tzoff
 
         rv = super().load(data[: m.start()])
-        return (rv - tzoff).replace(tzinfo=timezone.utc)
+        return (rv - tzoff).replace(tzinfo=self._timezone)
 
     def _load_notimpl(self, data: Buffer) -> datetime:
         if isinstance(data, memoryview):
@@ -598,10 +605,15 @@ class TimestampTzBinaryLoader(Loader):
 
     format = Format.BINARY
 
+    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+        super().__init__(oid, context)
+        self._timezone = self.connection.timezone if self.connection else _UTC
+
     def load(self, data: Buffer) -> datetime:
         micros = _unpack_int8(data)[0]
         try:
-            return _pg_datetimetz_epoch + timedelta(microseconds=micros)
+            ts = _pg_datetimetz_epoch + timedelta(microseconds=micros)
+            return ts.astimezone(self._timezone)
         except OverflowError:
             if micros <= 0:
                 raise DataError("timestamp too small (before year 1)")
index 7f97bc05a451899fae7976e9f94611c2eac503ab..5c1ed180dc978d0bafdf156262b7a5eadaf9799d 100644 (file)
@@ -49,9 +49,14 @@ else:
 
     Task = asyncio.Future
 
+if sys.version_info >= (3, 9):
+    from zoneinfo import ZoneInfo
+else:
+    from backports.zoneinfo import ZoneInfo
 
 __all__ = [
     "Protocol",
+    "ZoneInfo",
     "asynccontextmanager",
     "create_task",
     "get_running_loop",
index 85164391f9361b1ee8a478946d1aae878017c46a..1bc23b383c9d0a11f8efc8a807756979f811265e 100644 (file)
@@ -30,6 +30,7 @@ python_requires = >= 3.6
 packages = find:
 zip_safe = False
 install_requires =
+    backports.zoneinfo; python_version < "3.9"
     typing_extensions; python_version < "3.8"
 
 [options.package_data]
index 5041fdded93a95410e18f7b731a21c4ef37530f8..607eb45d9577558d7df34c6d15e05be4127173b4 100644 (file)
@@ -258,7 +258,6 @@ def test_load_datetimetz(conn, val, expr, timezone, datestyle_out):
     cur.execute(f"set timezone to '{timezone}'")
     got = cur.execute(f"select '{expr}'::timestamptz").fetchone()[0]
     assert got == as_dt(val)
-    assert got.tzinfo == dt.timezone.utc
 
 
 @pytest.mark.parametrize("val, expr, timezone", load_datetimetz_samples)
@@ -267,7 +266,6 @@ def test_load_datetimetz_binary(conn, val, expr, timezone):
     cur.execute(f"set timezone to '{timezone}'")
     got = cur.execute(f"select '{expr}'::timestamptz").fetchone()[0]
     assert got == as_dt(val)
-    assert got.tzinfo == dt.timezone.utc
 
 
 @pytest.mark.xfail  # parse timezone names