]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Move timezone attribute to the Connection.info interface
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 14 May 2021 16:37:24 +0000 (18:37 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 14 May 2021 17:43:40 +0000 (19:43 +0200)
Type changed to the abstract `datetime.tzinfo` so that we have the
option to return something else such as a `datetime.timezone` with fixed
offset.

Added function to retrieve the object internally from the PGconn.

psycopg3/psycopg3/_tz.py [new file with mode: 0644]
psycopg3/psycopg3/connection.py
psycopg3/psycopg3/conninfo.py
psycopg3/psycopg3/types/date.py

diff --git a/psycopg3/psycopg3/_tz.py b/psycopg3/psycopg3/_tz.py
new file mode 100644 (file)
index 0000000..a96b397
--- /dev/null
@@ -0,0 +1,38 @@
+"""
+Timezone utility functions.
+"""
+
+# Copyright (C) 2020-2021 The Psycopg Team
+
+import logging
+from typing import Dict, Optional, Union
+from datetime import timezone, tzinfo
+
+from .pq.proto import PGconn
+from .utils.compat import ZoneInfo
+
+logger = logging.getLogger("psycopg3")
+
+_timezones: Dict[Union[None, bytes], tzinfo] = {
+    None: timezone.utc,
+    b"UTC": timezone.utc,
+}
+
+
+def get_tzinfo(pgconn: Optional[PGconn]) -> tzinfo:
+    """Return the Python timezone info of the connection's timezone."""
+    tzname = pgconn.parameter_status(b"TimeZone") if pgconn else None
+    try:
+        return _timezones[tzname]
+    except KeyError:
+        sname = tzname.decode("utf8") if tzname else "UTC"
+        try:
+            zi: tzinfo = ZoneInfo(sname)
+        except KeyError:
+            logger.warning(
+                "unknown PostgreSQL timezone: %r; will use UTC", sname
+            )
+            zi = timezone.utc
+
+        _timezones[tzname] = zi
+        return zi
index c4bd0a9c2afd3894c9bc7d9af72afa59b64d9055..d5472437ac2122238f2c8fe5f9e88be8732cc515 100644 (file)
@@ -9,7 +9,7 @@ import logging
 import warnings
 import threading
 from types import TracebackType
-from typing import Any, AsyncIterator, Dict, Callable, Generic, Iterator, List
+from typing import Any, AsyncIterator, Callable, Generic, Iterator, List
 from typing import NamedTuple, Optional, Type, TypeVar, Union
 from typing import overload, TYPE_CHECKING
 from weakref import ref, ReferenceType
@@ -31,7 +31,7 @@ from .conninfo import make_conninfo, ConnectionInfo
 from .generators import notifies
 from ._preparing import PrepareManager
 from .transaction import Transaction, AsyncTransaction
-from .utils.compat import asynccontextmanager, ZoneInfo
+from .utils.compat import asynccontextmanager
 from .server_cursor import ServerCursor, AsyncServerCursor
 
 logger = logging.getLogger("psycopg3")
@@ -59,8 +59,6 @@ else:
     connect = generators.connect
     execute = generators.execute
 
-_timezones: Dict[Union[None, bytes], ZoneInfo] = {}
-
 
 class Notify(NamedTuple):
     """An asynchronous notification received from the database."""
@@ -221,25 +219,6 @@ class BaseConnection(AdaptContext, Generic[Row]):
         if result.status != ExecStatus.TUPLES_OK:
             raise e.error_from_result(result, encoding=self.client_encoding)
 
-    @property
-    def timezone(self) -> ZoneInfo:
-        """The Python timezone info of the connection's timezone."""
-        tzname = self.pgconn.parameter_status(b"TimeZone")
-        try:
-            return _timezones[tzname]
-        except KeyError:
-            sname = tzname.decode("utf8") if tzname else "UTC"
-            try:
-                zi = ZoneInfo(sname)
-            except KeyError:
-                logger.warning(
-                    "unknown PostgreSQL timezone: %r will use UTC", sname
-                )
-                zi = ZoneInfo("UTC")
-
-            _timezones[tzname] = zi
-            return zi
-
     @property
     def info(self) -> ConnectionInfo:
         """A `ConnectionInfo` attribute to inspect connection properties."""
index b2171b6834d80395e6e2f4edddd6ea1524cfd07f..e137f1dfe83c204c961cfed9aa23cb11b90d8a50 100644 (file)
@@ -7,10 +7,12 @@ Functions to manipulate conninfo strings
 import re
 from typing import Any, Dict, List, Optional
 from pathlib import Path
+from datetime import tzinfo
 
 from . import pq
 from . import errors as e
 from . import encodings
+from ._tz import get_tzinfo
 
 
 def make_conninfo(conninfo: str = "", **kwargs: Any) -> str:
@@ -224,6 +226,11 @@ class ConnectionInfo:
         """
         return self._get_pgconn_attr("error_message")
 
+    @property
+    def timezone(self) -> tzinfo:
+        """The Python timezone info of the connection's timezone."""
+        return get_tzinfo(self.pgconn)
+
     def _get_pgconn_attr(self, name: str) -> str:
         value: bytes = getattr(self.pgconn, name)
         return value.decode(self._pyenc)
index 844a8bef909e375276d828dde2c2fab9836d39fd..4e4e8f8bee5ef8271c8d0ed72e5751a1cbcc45b6 100644 (file)
@@ -15,7 +15,7 @@ from ..oids import postgres_types as builtins
 from ..adapt import Buffer, Dumper, Loader, Format as Pg3Format
 from ..proto import AdaptContext
 from ..errors import InterfaceError, DataError
-from ..utils.compat import ZoneInfo
+from .._tz import get_tzinfo
 
 _PackInt = Callable[[int], bytes]
 _UnpackInt = Callable[[bytes], Tuple[int]]
@@ -520,8 +520,8 @@ class TimestampTzLoader(TimestampLoader):
 
     def __init__(self, oid: int, context: Optional[AdaptContext] = None):
         super().__init__(oid, context)
-        self._timezone = (
-            self.connection.timezone if self.connection else ZoneInfo("UTC")
+        self._timezone = get_tzinfo(
+            self.connection.pgconn if self.connection else None
         )
 
     def _format_from_context(self) -> str:
@@ -607,8 +607,8 @@ class TimestampTzBinaryLoader(Loader):
 
     def __init__(self, oid: int, context: Optional[AdaptContext] = None):
         super().__init__(oid, context)
-        self._timezone = (
-            self.connection.timezone if self.connection else ZoneInfo("UTC")
+        self._timezone = get_tzinfo(
+            self.connection.pgconn if self.connection else None
         )
 
     def load(self, data: Buffer) -> datetime: