]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Base timestamptz loader on regexp
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 15 May 2021 13:41:26 +0000 (15:41 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 15 May 2021 14:04:48 +0000 (16:04 +0200)
About 4.5x faster than strptime.

psycopg3/psycopg3/types/date.py

index 8490cd8bcfb0f94f45ac7ac616db62827382cdb7..7eb59e8260cae08768450e4871e56d5b99518dec 100644 (file)
@@ -8,7 +8,7 @@ import re
 import sys
 import struct
 from datetime import date, datetime, time, timedelta, timezone
-from typing import Callable, cast, Optional, Pattern, Tuple, Union
+from typing import Any, Callable, cast, Optional, Tuple, Union, TYPE_CHECKING
 
 from ..pq import Format
 from ..oids import postgres_types as builtins
@@ -17,6 +17,9 @@ from ..proto import AdaptContext
 from ..errors import InterfaceError, DataError
 from .._tz import get_tzinfo
 
+if TYPE_CHECKING:
+    from ..connection import BaseConnection
+
 _PackInt = Callable[[int], bytes]
 _UnpackInt = Callable[[bytes], Tuple[int]]
 
@@ -36,9 +39,10 @@ _unpack_interval = cast(
     Callable[[bytes], Tuple[int, int, int]], struct.Struct("!qii").unpack
 )
 
+utc = timezone.utc
 _pg_date_epoch_days = date(2000, 1, 1).toordinal()
 _pg_datetime_epoch = datetime(2000, 1, 1)
-_pg_datetimetz_epoch = datetime(2000, 1, 1, tzinfo=timezone.utc)
+_pg_datetimetz_epoch = datetime(2000, 1, 1, tzinfo=utc)
 _py_date_min_days = date.min.toordinal()
 
 
@@ -251,30 +255,15 @@ class TimeDeltaBinaryDumper(Dumper):
         return _pack_interval(micros, obj.days, 0)
 
 
-class _DTTextLoader(Loader):
+class DateLoader(Loader):
+
     format = Format.TEXT
-    _re_format: Pattern[bytes]
+    _re_format = re.compile(rb"^(\d+)[^\d](\d+)[^\d](\d+)$")
 
     def __init__(self, oid: int, context: Optional[AdaptContext] = None):
         super().__init__(oid, context)
         self._order = self._order_from_context()
 
-    def _order_from_context(self) -> Tuple[int, ...]:
-        raise NotImplementedError
-
-    def _get_datestyle(self) -> bytes:
-        if self.connection:
-            ds = self.connection.pgconn.parameter_status(b"DateStyle")
-            if ds:
-                return ds
-
-        return b"ISO, DMY"
-
-
-class DateLoader(_DTTextLoader):
-
-    _re_format = re.compile(rb"^(\d+)[^\d](\d+)[^\d](\d+)$")
-
     def load(self, data: Buffer) -> date:
         m = self._re_format.match(data)
         if not m:
@@ -292,7 +281,7 @@ class DateLoader(_DTTextLoader):
             raise DataError(f"can't manage date {s!r}: {e}")
 
     def _order_from_context(self) -> Tuple[int, int, int]:
-        ds = self._get_datestyle()
+        ds = _get_datestyle(self.connection)
         if ds.startswith(b"I"):  # ISO
             return (0, 1, 2)
         elif ds.startswith(b"G"):  # German
@@ -440,7 +429,7 @@ if sys.version_info < (3, 7):
     )
 
 
-class TimestampLoader(_DTTextLoader):
+class TimestampLoader(Loader):
 
     format = Format.TEXT
     _re_format = re.compile(
@@ -459,6 +448,10 @@ class TimestampLoader(_DTTextLoader):
         """
     )
 
+    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+        super().__init__(oid, context)
+        self._order = self._order_from_context()
+
     def load(self, data: Buffer) -> datetime:
         m = self._re_format.match(data)
         if not m:
@@ -495,7 +488,7 @@ class TimestampLoader(_DTTextLoader):
             raise DataError(f"can't manage timestamp {s!r}: {e}")
 
     def _order_from_context(self) -> Tuple[int, int, int, int, int, int, int]:
-        ds = self._get_datestyle()
+        ds = _get_datestyle(self.connection)
         if ds.startswith(b"I"):  # ISO
             return (0, 1, 2, 3, 4, 5, 6)
         elif ds.startswith(b"G"):  # German
@@ -531,9 +524,20 @@ class TimestampBinaryLoader(Loader):
                 raise DataError("timestamp too large (after year 10K)")
 
 
-class TimestampTzLoader(TimestampLoader):
+class TimestampTzLoader(Loader):
 
     format = Format.TEXT
+    _re_format = re.compile(
+        rb"""(?ix)
+        ^
+        (\d+) [^a-z0-9] (\d+) [^a-z0-9] (\d+)       # Date
+        (?: T | [^a-z0-9] )                         # Separator, including T
+        (\d+) [^a-z0-9] (\d+) [^a-z0-9] (\d+)       # Time
+        (?: \.(\d+) )?                              # Micros
+        (-|\+) (\d+) (?: : (\d+) )? (?: : (\d+) )?  # Timezone
+        $
+        """
+    )
 
     def __init__(self, oid: int, context: Optional[AdaptContext] = None):
         super().__init__(oid, context)
@@ -541,87 +545,54 @@ class TimestampTzLoader(TimestampLoader):
             self.connection.pgconn if self.connection else None
         )
 
-    def _format_from_context(self) -> str:
-        ds = self._get_datestyle()
-        if ds.startswith(b"I"):  # ISO
-            if sys.version_info >= (3, 7):
-                return "%Y-%m-%d %H:%M:%S.%f%z"
-            else:
-                # No tz parsing: it will be handles separately.
-                return "%Y-%m-%d %H:%M:%S.%f"
-
-        # These don't work: the timezone name is not always displayed
-        # elif ds.startswith(b"G"):  # German
-        #     return "%d.%m.%Y %H:%M:%S.%f %Z"
-        # elif ds.startswith(b"S"):  # SQL
-        #     return (
-        #         "%d/%m/%Y %H:%M:%S.%f %Z"
-        #         if ds.endswith(b"DMY")
-        #         else "%m/%d/%Y %H:%M:%S.%f %Z"
-        #     )
-        # elif ds.startswith(b"P"):  # Postgres
-        #     return (
-        #         "%a %d %b %H:%M:%S.%f %Y %Z"
-        #         if ds.endswith(b"DMY")
-        #         else "%a %b %d %H:%M:%S.%f %Y %Z"
-        #     )
-        # else:
-        #     raise InterfaceError(f"unexpected DateStyle: {ds.decode('ascii')}")
-        else:
+        ds = _get_datestyle(self.connection)
+        if not ds.startswith(b"I"):  # not ISO
             setattr(self, "load", self._load_notimpl)
-            return ""
-
-    _re_tz = re.compile(br"([-+])(\d+)(?::(\d+)(?::(\d+))?)?$")
 
     def load(self, data: Buffer) -> datetime:
-        if isinstance(data, memoryview):
-            data = bytes(data)
+        m = self._re_format.match(data)
+        if not m:
+            s = bytes(data).decode("utf8", "replace")
+            if s.endswith("BC"):
+                raise DataError(f"BC timestamps not supported, got {s!r}")
+            raise DataError(f"can't parse timestamp {s!r}")
 
-        # Hack to convert +HH in +HHMM
-        if data[-3] in (43, 45):
-            data += b"00"
+        ye, mo, da, ho, mi, se, ms, sgn, oh, om, os = m.groups()
 
-        return super().load(data).astimezone(self._timezone)
+        # Pad the fraction of second to get millis
+        if ms:
+            if len(ms) == 6:
+                ims = int(ms)
+            else:
+                ims = int(ms + _ms_trail[len(ms)])
+        else:
+            ims = 0
 
-    def _load_py36(self, data: Buffer) -> datetime:
-        if isinstance(data, memoryview):
-            data = bytes(data)
+        # Calculate timezone offset
+        soff = 60 * 60 * int(oh)
+        if om:
+            soff += 60 * int(om)
+        if os:
+            soff += int(os)
+        tzoff = timedelta(0, soff if sgn == b"+" else -soff)
 
-        # Separate the timezone from the rest
-        m = self._re_tz.search(data)
-        if not m:
-            raise DataError(
-                "failed to parse timezone from '{data.decode('ascii')}'"
+        try:
+            dt = datetime(
+                int(ye), int(mo), int(da), int(ho), int(mi), int(se), ims, utc
             )
-
-        sign, hour, min, sec = m.groups()
-        tzoff = timedelta(
-            seconds=(int(sec) if sec else 0)
-            + 60 * ((int(min) if min else 0) + 60 * int(hour))
-        )
-        if sign == b"-":
-            tzoff = -tzoff
-
-        rv = super().load(data[: m.start()])
-        return (
-            (rv - tzoff)
-            .replace(tzinfo=timezone.utc)
-            .astimezone(self._timezone)
-        )
+            return (dt - tzoff).astimezone(self._timezone)
+        except ValueError as e:
+            s = bytes(data).decode("utf8", "replace")
+            raise DataError(f"can't manage timestamp {s!r}: {e}")
 
     def _load_notimpl(self, data: Buffer) -> datetime:
-        if isinstance(data, memoryview):
-            data = bytes(data)
+        s = bytes(data).decode("utf8", "replace")
+        ds = _get_datestyle(self.connection).decode("ascii")
         raise NotImplementedError(
-            "can't parse datetimetz with DateStyle"
-            f" {self._get_datestyle().decode('ascii')}: {data.decode('ascii')}"
+            f"can't parse datetimetz with DateStyle {ds!r}: {s!r}"
         )
 
 
-if sys.version_info < (3, 7):
-    setattr(TimestampTzLoader, "load", TimestampTzLoader._load_py36)
-
-
 class TimestampTzBinaryLoader(Loader):
 
     format = Format.BINARY
@@ -732,6 +703,15 @@ class IntervalBinaryLoader(Loader):
         return timedelta(days=days, microseconds=micros)
 
 
+def _get_datestyle(conn: Optional["BaseConnection[Any]"]) -> bytes:
+    if conn:
+        ds = conn.pgconn.parameter_status(b"DateStyle")
+        if ds:
+            return ds
+
+    return b"ISO, DMY"
+
+
 _month_abbr = {
     n: str(i).encode("utf8")
     for i, n in enumerate(