git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Base date and timestamp loaders on regexp instead of strptime
author: Daniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 15 May 2021 10:47:58 +0000 (12:47 +0200)
committer: Daniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 15 May 2021 10:47:58 +0000 (12:47 +0200)
Not only is it 4x faster, but it's also more straightforward to port to C
and keep similar algorithms (even if we decide to drop regexp in C)

psycopg3/psycopg3/types/date.py
tests/types/test_date.py

index c959666a2c6bc8341d13c949195d15c8f5afe2fd..d9e7c737fcad2732f35301ddc794335f3c4faef1 100644 (file)
@@ -8,7 +8,7 @@ import re
 import sys
 import struct
 from datetime import date, datetime, time, timedelta, timezone
-from typing import Callable, cast, Optional, Tuple, Union
+from typing import Callable, cast, Optional, Pattern, Tuple, Union
 
 from ..pq import Format
 from ..oids import postgres_types as builtins
@@ -251,67 +251,59 @@ class TimeDeltaBinaryDumper(Dumper):
         return _pack_interval(micros, obj.days, 0)
 
 
-class DateLoader(Loader):
-
+class _DTTextLoader(Loader):
     format = Format.TEXT
+    _re_format: Pattern[bytes]
 
     def __init__(self, oid: int, context: Optional[AdaptContext] = None):
         super().__init__(oid, context)
-        self._format = self._format_from_context()
+        self._order = self._order_from_context()
+
+    def _order_from_context(self) -> Tuple[int, ...]:
+        raise NotImplementedError
+
+    def _get_datestyle(self) -> bytes:
+        if self.connection:
+            ds = self.connection.pgconn.parameter_status(b"DateStyle")
+            if ds:
+                return ds
+
+        return b"ISO, DMY"
+
+
+class DateLoader(_DTTextLoader):
+
+    _re_format = re.compile(rb"^(\d+)[^\d](\d+)[^\d](\d+)( BC)?")
 
     def load(self, data: Buffer) -> date:
-        if isinstance(data, memoryview):
-            data = bytes(data)
+        m = self._re_format.match(data)
+        if not m:
+            s = bytes(data).decode("utf8", "replace")
+            raise DataError(f"can't parse date {s!r}")
+
+        t = m.groups()
+        ye, mo, da = (t[i] for i in self._order)
         try:
-            return datetime.strptime(data.decode("utf8"), self._format).date()
+            if t[3]:
+                raise ValueError("BC dates not supported")
+            return date(int(ye), int(mo), int(da))
         except ValueError as e:
-            return self._raise_error(data, e)
+            s = bytes(data).decode("utf8", "replace")
+            raise DataError(f"can't manage date {s!r}: {e}")
 
-    def _format_from_context(self) -> str:
+    def _order_from_context(self) -> Tuple[int, int, int]:
         ds = self._get_datestyle()
         if ds.startswith(b"I"):  # ISO
-            return "%Y-%m-%d"
+            return (0, 1, 2)
         elif ds.startswith(b"G"):  # German
-            return "%d.%m.%Y"
+            return (2, 1, 0)
         elif ds.startswith(b"S"):  # SQL
-            return "%d/%m/%Y" if ds.endswith(b"DMY") else "%m/%d/%Y"
+            return (2, 1, 0) if ds.endswith(b"DMY") else (2, 0, 1)
         elif ds.startswith(b"P"):  # Postgres
-            return "%d-%m-%Y" if ds.endswith(b"DMY") else "%m-%d-%Y"
+            return (2, 1, 0) if ds.endswith(b"DMY") else (2, 0, 1)
         else:
             raise InterfaceError(f"unexpected DateStyle: {ds.decode('ascii')}")
 
-    def _get_datestyle(self) -> bytes:
-        rv = b"ISO, DMY"
-        if self.connection:
-            ds = self.connection.pgconn.parameter_status(b"DateStyle")
-            if ds:
-                rv = ds
-
-        return rv
-
-    def _raise_error(self, data: bytes, exc: ValueError) -> date:
-        # Most likely we received a BC date, which Python doesn't support
-        # Otherwise the unexpected value is displayed in the exception.
-        if data.endswith(b"BC"):
-            raise DataError(
-                "Python doesn't support BC date:"
-                f" got {data.decode('utf8', 'replace')}"
-            )
-
-        if self._get_year_digits(data) > 4:
-            raise DataError(
-                "Python date doesn't support years after 9999:"
-                f" got {data.decode('utf8', 'replace')}"
-            )
-
-        # We genuinely received something we cannot parse
-        raise exc
-
-    def _get_year_digits(self, data: bytes) -> int:
-        datesep = self._format[2].encode("ascii")
-        parts = data.split(b" ")[0].split(datesep)
-        return max(map(len, parts))
-
 
 class DateBinaryLoader(Loader):
 
@@ -442,62 +434,78 @@ if sys.version_info < (3, 7):
     )
 
 
-class TimestampLoader(DateLoader):
+class TimestampLoader(_DTTextLoader):
 
     format = Format.TEXT
-
-    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
-        super().__init__(oid, context)
-        self._format_no_micro = self._format.replace(".%f", "")
+    _re_format = re.compile(
+        rb"""(?ix)
+        ^
+        (?:(\d+)|[a-z]+)    [^a-z0-9]   # DoW or first number, separator
+        (\d+|[a-z]+)        [^a-z0-9]   # Month name or second number, separator
+        (\d+|[a-z]+)                    # Month name or thrid number
+                    (?: T | [^a-z0-9] ) # Separator, including T
+        (\d+)               [^a-z0-9]   # Other 3 numbers
+        (\d+)               [^a-z0-9]
+        (\d+)
+        (?: \.(\d+) )?                  # micros
+        (?: [^a-z0-9] (\d+) )?          # year in PG format
+        ( \s* BC)?                      # BC
+        $
+        """
+    )
 
     def load(self, data: Buffer) -> datetime:
-        if isinstance(data, memoryview):
-            data = bytes(data)
+        m = self._re_format.match(data)
+        if not m:
+            s = bytes(data).decode("utf8", "replace")
+            raise DataError(f"can't parse timestamp {s!r}")
+
+        t = m.groups()
+        ye, mo, da, ho, mi, se, ms = (t[i] for i in self._order)
+
+        if ms is None:
+            ms = b"0"
+        elif len(ms) < 6:
+            ms += b"0" * (6 - len(ms))
 
-        # check if the data contains microseconds
-        fmt = (
-            self._format if data.find(b".", 19) >= 0 else self._format_no_micro
-        )
         try:
-            return datetime.strptime(data.decode("utf8"), fmt)
+            if not b"0" <= mo[0:1] <= b"9":
+                try:
+                    mo = _month_abbr[mo]
+                except KeyError:
+                    s = mo.decode("utf8", "replace")
+                    raise DataError(f"unexpected month: {s!r}")
+
+            if t[8]:
+                raise ValueError("BC dates not supported")
+            return datetime(
+                int(ye), int(mo), int(da), int(ho), int(mi), int(se), int(ms)
+            )
         except ValueError as e:
-            return self._raise_error(data, e)
+            s = bytes(data).decode("utf8", "replace")
+            raise DataError(f"can't manage timestamp {s!r}: {e}")
 
-    def _format_from_context(self) -> str:
+    def _order_from_context(self) -> Tuple[int, int, int, int, int, int, int]:
         ds = self._get_datestyle()
         if ds.startswith(b"I"):  # ISO
-            return "%Y-%m-%d %H:%M:%S.%f"
+            return (0, 1, 2, 3, 4, 5, 6)
         elif ds.startswith(b"G"):  # German
-            return "%d.%m.%Y %H:%M:%S.%f"
+            return (2, 1, 0, 3, 4, 5, 6)
         elif ds.startswith(b"S"):  # SQL
             return (
-                "%d/%m/%Y %H:%M:%S.%f"
+                (2, 1, 0, 3, 4, 5, 6)
                 if ds.endswith(b"DMY")
-                else "%m/%d/%Y %H:%M:%S.%f"
+                else (2, 0, 1, 3, 4, 5, 6)
             )
         elif ds.startswith(b"P"):  # Postgres
             return (
-                "%a %d %b %H:%M:%S.%f %Y"
+                (7, 2, 1, 3, 4, 5, 6)
                 if ds.endswith(b"DMY")
-                else "%a %b %d %H:%M:%S.%f %Y"
+                else (7, 1, 2, 3, 4, 5, 6)
             )
         else:
             raise InterfaceError(f"unexpected DateStyle: {ds.decode('ascii')}")
 
-    def _raise_error(self, data: bytes, exc: ValueError) -> datetime:
-        return cast(datetime, super()._raise_error(data, exc))
-
-    def _get_year_digits(self, data: bytes) -> int:
-        # Find the year from the date.
-        if not self._get_datestyle().startswith(b"P"):  # Postgres
-            return super()._get_year_digits(data)
-        else:
-            parts = data.split()
-            if len(parts) > 4:
-                return len(parts[4])
-            else:
-                return 0
-
 
 class TimestampBinaryLoader(Loader):
 
@@ -713,3 +721,11 @@ class IntervalBinaryLoader(Loader):
             years, months = divmod(-months, 12)
             days = days - 30 * months - 365 * years
         return timedelta(days=days, microseconds=micros)
+
+
+_month_abbr = {
+    n: str(i).encode("utf8")
+    for i, n in enumerate(
+        b"Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec".split(), 1
+    )
+}
index 71567fb96aaf57f5813ea21af5025651d6e48448..f22c70bffca93c0fdc36af1976261c879b1594f5 100644 (file)
@@ -184,6 +184,14 @@ class TestDatetime:
         with pytest.raises(DataError):
             cur.fetchone()[0]
 
+    def test_load_all_month_names(self, conn):
+        cur = conn.cursor(binary=False)
+        cur.execute("set datestyle = 'Postgres'")
+        for i in range(12):
+            d = dt.datetime(2000, i + 1, 15)
+            cur.execute("select %s", [d])
+            assert cur.fetchone()[0] == d
+
 
 class TestDateTimeTz:
     @pytest.mark.parametrize(